From 21a337a46b85b428345738c5abd0a8fce570a8f0 Mon Sep 17 00:00:00 2001 From: Moritz Huebner <email@moritz-huebner.de> Date: Tue, 6 Oct 2020 14:25:38 +1100 Subject: [PATCH] Added CDF explicitly to conditional priors --- bilby/core/prior/conditional.py | 32 +++++++++++++++++++++++++++++ test/core/prior/conditional_test.py | 17 ++++++++++++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index c96c1a05d..80cdbb91a 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -110,9 +110,41 @@ def conditional_prior_factory(prior_class): return super(ConditionalPrior, self).prob(val) def ln_prob(self, val, **required_variables): + """Return the natural log prior probability of val. + + Parameters + ---------- + val: Union[float, int, array_like] + See superclass + required_variables: + Any required variables that this prior depends on + + + Returns + ------- + float: Natural log prior probability of val + """ self.update_conditions(**required_variables) return super(ConditionalPrior, self).ln_prob(val) + def cdf(self, val, **required_variables): + """Return the cdf of val. + + Parameters + ---------- + val: Union[float, int, array_like] + See superclass + required_variables: + Any required variables that this prior depends on + + + Returns + ------- + float: CDF of val + """ + self.update_conditions(**required_variables) + return super(ConditionalPrior, self).cdf(val) + def update_conditions(self, **required_variables): """ This method updates the conditional parameters (depending on the parent class diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 0e44fda36..a76f10a4f 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -110,7 +110,7 @@ class TestConditionalPrior(unittest.TestCase): test_parameter_2=self.test_variable_2, ) - def test_rescale_prob_update_conditions(self): + def test_prob_calls_update_conditions(self): with mock.patch.object(self.prior, "update_conditions") as m: self.prior.prob( 1, @@ -138,6 +138,21 @@ class TestConditionalPrior(unittest.TestCase): ] m.assert_has_calls(calls) + def test_cdf_calls_update_conditions(self): + self.prior = bilby.core.prior.ConditionalUniform( + condition_func=self.condition_func, minimum=self.minimum, maximum=self.maximum + ) + with mock.patch.object(self.prior, "update_conditions") as m: + self.prior.cdf( + 1, + test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2, + ) + m.assert_called_with( + test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2, + ) + def test_reset_to_reference_parameters(self): self.prior.minimum = 10 self.prior.maximum = 20 -- GitLab