Skip to content
Snippets Groups Projects
Commit 21a337a4 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Added CDF explicitly to conditional priors

parent 82872fe5
No related branches found
No related tags found
1 merge request!882Added CDF explicitly to conditional priors
Pipeline #160070 passed
......@@ -110,9 +110,41 @@ def conditional_prior_factory(prior_class):
return super(ConditionalPrior, self).prob(val)
def ln_prob(self, val, **required_variables):
"""Return the natural log prior probability of val.
Parameters
----------
val: Union[float, int, array_like]
See superclass
required_variables:
Any required variables that this prior depends on
Returns
-------
float: Natural log prior probability of val
"""
self.update_conditions(**required_variables)
return super(ConditionalPrior, self).ln_prob(val)
def cdf(self, val, **required_variables):
"""Return the cdf of val.
Parameters
----------
val: Union[float, int, array_like]
See superclass
required_variables:
Any required variables that this prior depends on
Returns
-------
float: CDF of val
"""
self.update_conditions(**required_variables)
return super(ConditionalPrior, self).cdf(val)
def update_conditions(self, **required_variables):
"""
This method updates the conditional parameters (depending on the parent class
......
......@@ -110,7 +110,7 @@ class TestConditionalPrior(unittest.TestCase):
test_parameter_2=self.test_variable_2,
)
def test_rescale_prob_update_conditions(self):
def test_prob_calls_update_conditions(self):
with mock.patch.object(self.prior, "update_conditions") as m:
self.prior.prob(
1,
......@@ -138,6 +138,21 @@ class TestConditionalPrior(unittest.TestCase):
]
m.assert_has_calls(calls)
def test_cdf_calls_update_conditions(self):
self.prior = bilby.core.prior.ConditionalUniform(
condition_func=self.condition_func, minimum=self.minimum, maximum=self.maximum
)
with mock.patch.object(self.prior, "update_conditions") as m:
self.prior.cdf(
1,
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
m.assert_called_with(
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
def test_reset_to_reference_parameters(self):
self.prior.minimum = 10
self.prior.maximum = 20
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment