diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index dc717ba36128bec960698ee1dedb27a4c21d8f4b..05df00a7bad5a46021f6597617d833c6ee525816 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -63,21 +63,52 @@ class Prior(object): return self.sample() def __eq__(self, other): + """ + Test equality of two prior objects. + + Returns true iff: + + - The class of the two priors are the same + - Both priors have the same keys in the __dict__ attribute + - The instantiation arguments match + + We don't check that all entries the the __dict__ attribute + are equal as some attributes are variable for conditional + priors. + + Parameters + ========== + other: Prior + The prior to compare with + + Returns + ======= + bool + Whether the priors are equivalent + + Notes + ===== + A special case is made for :code `scipy.stats.beta`: instances. + It may be possible to remove this as we now only check instantiation + arguments. + + """ if self.__class__ != other.__class__: return False if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()): return False - for key in self.__dict__: + this_dict = self.get_instantiation_dict() + other_dict = other.get_instantiation_dict() + for key in this_dict: if key == "least_recently_sampled": - # ignore sample drawn from prior in comparison continue - if type(self.__dict__[key]) is np.ndarray: - if not np.array_equal(self.__dict__[key], other.__dict__[key]): + if isinstance(this_dict[key], np.ndarray): + if not np.array_equal(this_dict[key], other_dict[key]): return False - elif isinstance(self.__dict__[key], type(scipy.stats.beta(1., 1.))): + elif isinstance(this_dict[key], type(scipy.stats.beta(1., 1.))): continue else: - if not self.__dict__[key] == other.__dict__[key]: + if not this_dict[key] == other_dict[key]: return False return True diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index af29c5997d89947d2c9394c17cedb44ea4e0aa47..7c8d7e8c34cdac1f11fea2bdd0e5770085b41f13 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -250,22 +250,6 @@ class Sampler(object): return result - def _check_if_priors_can_be_sampled(self): - """Check if all priors can be sampled properly. - - Raises - ====== - AttributeError - prior can't be sampled. - """ - for key in self.priors: - if isinstance(self.priors[key], Constraint): - continue - try: - self.priors[key].sample() - except AttributeError as e: - logger.warning('Cannot sample from {}, {}'.format(key, e)) - def _verify_parameters(self): """ Evaluate a set of parameters drawn from the prior @@ -322,7 +306,13 @@ class Sampler(object): Checks if use_ratio is set. Prints a warning if use_ratio is set but not properly implemented. """ - self._check_if_priors_can_be_sampled() + try: + self.priors.sample_subset(self.search_parameter_keys) + except (KeyError, AttributeError): + logger.error("Cannot sample from priors with keys: {}.".format( + self.search_parameter_keys + )) + raise if self.use_ratio is False: logger.debug("use_ratio set to False") return diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 093e2eda8db4324985b10aed7d62460ec2ee85a5..caf68fd260036213e6a4c30a022a860b7fdc98c2 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -234,12 +234,26 @@ def convert_to_lal_binary_black_hole_parameters(parameters): for idx in ['1', '2']: key = 'chi_{}'.format(idx) if key in original_keys: - converted_parameters['a_{}'.format(idx)] = abs( - converted_parameters[key]) - converted_parameters['cos_tilt_{}'.format(idx)] = \ - np.sign(converted_parameters[key]) - converted_parameters['phi_jl'] = 0.0 - converted_parameters['phi_12'] = 0.0 + if "chi_{}_in_plane".format(idx) in original_keys: + converted_parameters["a_{}".format(idx)] = ( + converted_parameters[f"chi_{idx}"] ** 2 + + converted_parameters[f"chi_{idx}_in_plane"] ** 2 + ) ** 0.5 + converted_parameters[f"cos_tilt_{idx}"] = ( + converted_parameters[f"chi_{idx}"] + / converted_parameters[f"a_{idx}"] + ) + elif "a_{}".format(idx) not in original_keys: + converted_parameters['a_{}'.format(idx)] = abs( + converted_parameters[key]) + converted_parameters['cos_tilt_{}'.format(idx)] = \ + np.sign(converted_parameters[key]) + converted_parameters['phi_jl'] = 0.0 + converted_parameters['phi_12'] = 0.0 + else: + converted_parameters[f"cos_tilt_{idx}"] = ( + converted_parameters[key] / converted_parameters[f"a_{idx}"] + ) for angle in ['tilt_1', 'tilt_2', 'theta_jn']: cos_angle = str('cos_' + angle) diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 594f6f71126692539a6a66b6e863e3a7886d230a..899e98060cf7023eb616ec4eeda6ee14089c6e30 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -7,10 +7,12 @@ from scipy.integrate import cumtrapz from scipy.special import hyp2f1 from scipy.stats import norm -from ..core.prior import (PriorDict, Uniform, Prior, DeltaFunction, Gaussian, - Interped, Constraint, conditional_prior_factory, - BaseJointPriorDist, JointPrior, JointPriorDistError, - PowerLaw) +from ..core.prior import ( + PriorDict, Uniform, Prior, DeltaFunction, Gaussian, Interped, Constraint, + conditional_prior_factory, PowerLaw, ConditionalLogUniform, + ConditionalPriorDict, ConditionalBasePrior, BaseJointPriorDist, JointPrior, + JointPriorDistError, +) from ..core.utils import infer_args_from_method, logger from .conversion import ( convert_to_lal_binary_black_hole_parameters, @@ -417,6 +419,141 @@ class AlignedSpin(Interped): maximum=maximum) +class ConditionalChiUniformSpinMagnitude(ConditionalLogUniform): + r""" + This prior characterizes the conditional prior on the spin magnitude given + the aligned component of the spin such that the marginal prior is uniform + if the distribution of spin orientations is isotropic. + + .. math:: + p(a) &= \frac{1}{a_{\max}} + p(\chi) &= - \frac{1}{2 a_{\max}} \ln(|\chi|) + p(a | \chi) &\propto \frac{1}{a} + """ + + def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary=None): + super(ConditionalChiUniformSpinMagnitude, self).__init__( + minimum=minimum, maximum=maximum, name=name, latex_label=latex_label, unit=unit, boundary=boundary, + condition_func=self._condition_function) + self._required_variables = [name.replace("a", "chi")] + self.__class__.__name__ = "ConditionalChiUniformSpinMagnitude" + self.__class__.__qualname__ = "ConditionalChiUniformSpinMagnitude" + + def _condition_function(self, reference_params, **kwargs): + return dict(minimum=np.abs(kwargs[self._required_variables[0]]), maximum=reference_params["maximum"]) + + def __repr__(self): + return Prior.__repr__(self) + + def get_instantiation_dict(self): + instantiation_dict = Prior.get_instantiation_dict(self) + for key, value in self.reference_params.items(): + if key in instantiation_dict: + instantiation_dict[key] = value + return instantiation_dict + + +class ConditionalChiInPlane(ConditionalBasePrior): + r""" + This prior characterizes the conditional prior on the in-plane spin magnitude + given the aligned component of the spin such that the marginal prior is uniform + if the distribution of spin orientations is isotropic. + + .. math:: + p(a) &= \frac{1}{a_{\max}} + p(\chi_\perp) = 2 N \chi_\perp / (\chi ** 2 + \chi_\perp ** 2) + N^{-1} &= 2 \ln(a_\max / |\chi|) + """ + + def __init__(self, minimum, maximum, name, latex_label=None, unit=None, boundary=None): + super(ConditionalChiInPlane, self).__init__( + minimum=minimum, maximum=maximum, + name=name, latex_label=latex_label, + unit=unit, boundary=boundary, + condition_func=self._condition_function + ) + self._required_variables = [name[:5]] + self._reference_maximum = maximum + self.__class__.__name__ = "ConditionalChiInPlane" + self.__class__.__qualname__ = "ConditionalChiInPlane" + + def prob(self, val, **required_variables): + self.update_conditions(**required_variables) + chi_aligned = abs(required_variables[self._required_variables[0]]) + return ( + (val >= self.minimum) * (val <= self.maximum) + * val + / (chi_aligned ** 2 + val ** 2) + / np.log(self._reference_maximum / chi_aligned) + ) + + def ln_prob(self, val, **required_variables): + return np.log(self.prob(val, **required_variables)) + + def cdf(self, val, **required_variables): + r""" + .. math:: + \text{CDF}(\chi_\per) = N ln(1 + (\chi_\perp / \chi) ** 2) + + Parameters + ---------- + val: (float, array-like) + The value at which to evaluate the CDF + required_variables: dict + A dictionary containing the aligned component of the spin + + Returns + ------- + (float, array-like) + The value of the CDF + + """ + self.update_conditions(**required_variables) + chi_aligned = abs(required_variables[self._required_variables[0]]) + return np.maximum(np.minimum( + (val >= self.minimum) * (val <= self.maximum) + * np.log(1 + (val / chi_aligned) ** 2) + / 2 / np.log(self._reference_maximum / chi_aligned) + , 1 + ), 0) + + def rescale(self, val, **required_variables): + r""" + .. math:: + \text{PPF}(\chi_\perp) = ((a_\max / \chi) ** (2x) - 1) ** 0.5 * \chi + + Parameters + ---------- + val: (float, array-like) + The value to rescale + required_variables: dict + Dictionary containing the aligned spin component + + Returns + ------- + (float, array-like) + The in-plane component of the spin + """ + self.update_conditions(**required_variables) + chi_aligned = abs(required_variables[self._required_variables[0]]) + return chi_aligned * ((self._reference_maximum / chi_aligned) ** (2 * val) - 1) ** 0.5 + + def _condition_function(self, reference_params, **kwargs): + return dict(minimum=0, maximum=( + self._reference_maximum ** 2 - kwargs[self._required_variables[0]] ** 2 + ) ** 0.5) + + def __repr__(self): + return Prior.__repr__(self) + + def get_instantiation_dict(self): + instantiation_dict = Prior.get_instantiation_dict(self) + for key, value in self.reference_params.items(): + if key in instantiation_dict: + instantiation_dict[key] = value + return instantiation_dict + + class EOSCheck(Constraint): def __init__(self, minimum=-np.inf, maximum=np.inf): """ @@ -443,7 +580,7 @@ class EOSCheck(Constraint): return result -class CBCPriorDict(PriorDict): +class CBCPriorDict(ConditionalPriorDict): @property def minimum_chirp_mass(self): mass_1 = None diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py index 2bbbc2bf1e87535ed56a8d47bff35ec57c2c9b4b..6afdc9933ad782fd5cc5eab603f92bf4287a7d5f 100644 --- a/test/core/prior/prior_test.py +++ b/test/core/prior/prior_test.py @@ -667,17 +667,6 @@ class TestPriorClasses(unittest.TestCase): prior_2.other_key = 5 self.assertNotEqual(prior_1, prior_2) - def test_np_array_eq(self): - prior_1 = bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=0, minimum=0, maximum=1 - ) - prior_2 = bilby.core.prior.PowerLaw( - name="test", unit="unit", alpha=0, minimum=0, maximum=1 - ) - prior_1.array_attribute = np.array([1, 2, 3]) - prior_2.array_attribute = np.array([2, 2, 3]) - self.assertNotEqual(prior_1, prior_2) - def test_repr(self): for prior in self.priors: if isinstance(prior, bilby.core.prior.Interped): diff --git a/test/gw/prior_test.py b/test/gw/prior_test.py index f192b645fadd99d3dc965da533863c15b7573bff..832f2a9d5213ed9f967c8b1ddbd30d6a6c3f21ea 100644 --- a/test/gw/prior_test.py +++ b/test/gw/prior_test.py @@ -523,5 +523,20 @@ class TestAlignedSpin(unittest.TestCase): self.assertAlmostEqual(max_difference, 0, 2) +class TestConditionalChiUniformSpinMagnitude(unittest.TestCase): + + def setUp(self): + pass + + def test_marginalized_prior_is_uniform(self): + priors = bilby.gw.prior.BBHPriorDict(aligned_spin=True) + priors["a_1"] = bilby.gw.prior.ConditionalChiUniformSpinMagnitude( + minimum=0.1, maximum=priors["chi_1"].maximum, name="a_1" + ) + samples = priors.sample(100000)["a_1"] + ks = ks_2samp(samples, np.random.uniform(0, priors["chi_1"].maximum, 100000)) + self.assertTrue(ks.pvalue > 0.001) + + if __name__ == "__main__": unittest.main() diff --git a/test/integration/sample_from_the_prior_test.py b/test/integration/sample_from_the_prior_test.py index bb54aa98bcb8b6613e4464d1d0fe911fac164b5d..e48c3574dc79c44b6b73089f2b3a71d24d4ea6c1 100644 --- a/test/integration/sample_from_the_prior_test.py +++ b/test/integration/sample_from_the_prior_test.py @@ -40,7 +40,7 @@ class Test(unittest.TestCase): duration = 4.0 sampling_frequency = 2048.0 label = "full_15_parameters" - np.random.seed(8817020) + np.random.seed(8817021) waveform_arguments = dict( waveform_approximant="IMRPhenomPv2",