Skip to content
Snippets Groups Projects
Commit 09820b62 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'add_cdf_to_conditional_prior' into 'master'

Added CDF explicitly to conditional priors

See merge request !882
parents 145091c9 21a337a4
No related branches found
No related tags found
1 merge request!882Added CDF explicitly to conditional priors
Pipeline #162514 passed
......@@ -110,9 +110,41 @@ def conditional_prior_factory(prior_class):
return super(ConditionalPrior, self).prob(val)
def ln_prob(self, val, **required_variables):
"""Return the natural log prior probability of val.
Parameters
----------
val: Union[float, int, array_like]
See superclass
required_variables:
Any required variables that this prior depends on
Returns
-------
float: Natural log prior probability of val
"""
self.update_conditions(**required_variables)
return super(ConditionalPrior, self).ln_prob(val)
def cdf(self, val, **required_variables):
"""Return the cdf of val.
Parameters
----------
val: Union[float, int, array_like]
See superclass
required_variables:
Any required variables that this prior depends on
Returns
-------
float: CDF of val
"""
self.update_conditions(**required_variables)
return super(ConditionalPrior, self).cdf(val)
def update_conditions(self, **required_variables):
"""
This method updates the conditional parameters (depending on the parent class
......
......@@ -110,7 +110,7 @@ class TestConditionalPrior(unittest.TestCase):
test_parameter_2=self.test_variable_2,
)
def test_rescale_prob_update_conditions(self):
def test_prob_calls_update_conditions(self):
with mock.patch.object(self.prior, "update_conditions") as m:
self.prior.prob(
1,
......@@ -138,6 +138,21 @@ class TestConditionalPrior(unittest.TestCase):
]
m.assert_has_calls(calls)
def test_cdf_calls_update_conditions(self):
self.prior = bilby.core.prior.ConditionalUniform(
condition_func=self.condition_func, minimum=self.minimum, maximum=self.maximum
)
with mock.patch.object(self.prior, "update_conditions") as m:
self.prior.cdf(
1,
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
m.assert_called_with(
test_parameter_1=self.test_variable_1,
test_parameter_2=self.test_variable_2,
)
def test_reset_to_reference_parameters(self):
self.prior.minimum = 10
self.prior.maximum = 20
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment