diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index c96c1a05da9b4e9aee7f2159b868c5e1fd1f7ded..80cdbb91acfb3a49342aad479ef8881c27cca71d 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -110,9 +110,41 @@ def conditional_prior_factory(prior_class): return super(ConditionalPrior, self).prob(val) def ln_prob(self, val, **required_variables): + """Return the natural log prior probability of val. + + Parameters + ---------- + val: Union[float, int, array_like] + See superclass + required_variables: + Any required variables that this prior depends on + + + Returns + ------- + float: Natural log prior probability of val + """ self.update_conditions(**required_variables) return super(ConditionalPrior, self).ln_prob(val) + def cdf(self, val, **required_variables): + """Return the cdf of val. + + Parameters + ---------- + val: Union[float, int, array_like] + See superclass + required_variables: + Any required variables that this prior depends on + + + Returns + ------- + float: CDF of val + """ + self.update_conditions(**required_variables) + return super(ConditionalPrior, self).cdf(val) + def update_conditions(self, **required_variables): """ This method updates the conditional parameters (depending on the parent class diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 0e44fda3691172e1ca8618d7ace3a172de991912..a76f10a4fa379d34b4fe35499a76fb4005506741 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -110,7 +110,7 @@ class TestConditionalPrior(unittest.TestCase): test_parameter_2=self.test_variable_2, ) - def test_rescale_prob_update_conditions(self): + def test_prob_calls_update_conditions(self): with mock.patch.object(self.prior, "update_conditions") as m: self.prior.prob( 1, @@ -138,6 +138,21 @@ class TestConditionalPrior(unittest.TestCase): ] m.assert_has_calls(calls) + def test_cdf_calls_update_conditions(self): + self.prior = bilby.core.prior.ConditionalUniform( + condition_func=self.condition_func, minimum=self.minimum, maximum=self.maximum + ) + with mock.patch.object(self.prior, "update_conditions") as m: + self.prior.cdf( + 1, + test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2, + ) + m.assert_called_with( + test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2, + ) + def test_reset_to_reference_parameters(self): self.prior.minimum = 10 self.prior.maximum = 20