diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 67bdb813abc2aaad7518cd9b07721067d50455cd..7a9a8e5dcb88176b9c5622ab2082e20597974ef9 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -491,6 +491,21 @@ class PriorDict(dict): constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio) return constrained_ln_prob + def cdf(self, sample): + """Evaluate the cumulative distribution function at the provided points + + Parameters + ---------- + sample: dict, pandas.DataFrame + Dictionary of the samples of which to calculate the CDF + + Returns + ------- + dict, pandas.DataFrame: Dictionary containing the CDF values + + """ + return sample.__class__({key: self[key].cdf(sample) for key, sample in sample.items()}) + def rescale(self, keys, theta): """Rescale samples from unit cube to prior @@ -654,9 +669,7 @@ class ConditionalPriorDict(PriorDict): float: Joint probability of all individual sample probabilities """ - self._check_resolved() - for key, value in sample.items(): - self[key].least_recently_sampled = value + self._prepare_evaluation(*zip(*sample.items())) res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample] return np.product(res, **kwargs) @@ -675,12 +688,15 @@ class ConditionalPriorDict(PriorDict): float: Joint log probability of all the individual sample probabilities """ - self._check_resolved() - for key, value in sample.items(): - self[key].least_recently_sampled = value + self._prepare_evaluation(*zip(*sample.items())) res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample] return np.sum(res, axis=axis) + def cdf(self, sample): + self._prepare_evaluation(*zip(*sample.items())) + res = {key: self[key].cdf(sample[key], **self.get_required_variables(key)) for key in sample} + return sample.__class__(res) + def rescale(self, keys, theta): """Rescale samples from unit cube to prior @@ -695,12 +711,14 @@ class ConditionalPriorDict(PriorDict): ======= list: List of floats containing the rescaled sample """ + keys = list(keys) + theta = list(theta) self._check_resolved() self._update_rescale_keys(keys) result = dict() for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): - required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])} - result[key] = self[key].rescale(theta[index], **required_variables) + result[key] = self[key].rescale(theta[index], **self.get_required_variables(key)) + self[key].least_recently_sampled = result[key] return [result[key] for key in keys] def _update_rescale_keys(self, keys): @@ -708,6 +726,11 @@ class ConditionalPriorDict(PriorDict): self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters] self._least_recently_rescaled_keys = keys + def _prepare_evaluation(self, keys, theta): + self._check_resolved() + for key, value in zip(keys, theta): + self[key].least_recently_sampled = value + def _check_resolved(self): if not self._resolved: raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index a76f10a4fa379d34b4fe35499a76fb4005506741..3e2c5d051cedfe93ee401424020ee5958e98d7a9 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -1,6 +1,7 @@ import unittest import mock +import numpy as np import bilby @@ -170,13 +171,13 @@ class TestConditionalPrior(unittest.TestCase): class TestConditionalPriorDict(unittest.TestCase): def setUp(self): def condition_func_1(reference_parameters, var_0): - return reference_parameters + return dict(minimum=reference_parameters["minimum"], maximum=var_0) def condition_func_2(reference_parameters, var_0, var_1): - return reference_parameters + return dict(minimum=reference_parameters["minimum"], maximum=var_1) def condition_func_3(reference_parameters, var_1, var_2): - return reference_parameters + return dict(minimum=reference_parameters["minimum"], maximum=var_2) self.minimum = 0 self.maximum = 1 @@ -203,7 +204,8 @@ class TestConditionalPriorDict(unittest.TestCase): self.conditional_priors_manually_set_items = ( bilby.core.prior.ConditionalPriorDict() ) - self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4) + self.test_sample = dict(var_0=0.7, var_1=0.6, var_2=0.5, var_3=0.4) + self.test_value = 1 / np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)]) for key, value in dict( var_0=self.prior_0, var_1=self.prior_1, @@ -254,7 +256,7 @@ class TestConditionalPriorDict(unittest.TestCase): ) def test_prob(self): - self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample)) + self.assertEqual(self.test_value, self.conditional_priors.prob(sample=self.test_sample)) def test_prob_illegal_conditions(self): del self.conditional_priors["var_0"] @@ -262,7 +264,7 @@ class TestConditionalPriorDict(unittest.TestCase): self.conditional_priors.prob(sample=self.test_sample) def test_ln_prob(self): - self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample)) + self.assertEqual(np.log(self.test_value), self.conditional_priors.ln_prob(sample=self.test_sample)) def test_ln_prob_illegal_conditions(self): del self.conditional_priors["var_0"] @@ -273,7 +275,7 @@ class TestConditionalPriorDict(unittest.TestCase): with mock.patch("numpy.random.uniform") as m: m.return_value = 0.5 self.assertDictEqual( - dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5), + dict(var_0=0.5, var_1=0.5 ** 2, var_2=0.5 ** 3, var_3=0.5 ** 4), self.conditional_priors.sample_subset( keys=["var_0", "var_1", "var_2", "var_3"] ), @@ -301,31 +303,6 @@ class TestConditionalPriorDict(unittest.TestCase): print(priors.sample(2)) def test_rescale(self): - def condition_func_1_rescale(reference_parameters, var_0): - if var_0 == 0.5: - return dict(minimum=reference_parameters["minimum"], maximum=1) - return reference_parameters - - def condition_func_2_rescale(reference_parameters, var_0, var_1): - if var_0 == 0.5 and var_1 == 0.5: - return dict(minimum=reference_parameters["minimum"], maximum=1) - return reference_parameters - - def condition_func_3_rescale(reference_parameters, var_1, var_2): - if var_1 == 0.5 and var_2 == 0.5: - return dict(minimum=reference_parameters["minimum"], maximum=1) - return reference_parameters - - self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1) - self.prior_1 = bilby.core.prior.ConditionalUniform( - condition_func=condition_func_1_rescale, minimum=self.minimum, maximum=2 - ) - self.prior_2 = bilby.core.prior.ConditionalUniform( - condition_func=condition_func_2_rescale, minimum=self.minimum, maximum=2 - ) - self.prior_3 = bilby.core.prior.ConditionalUniform( - condition_func=condition_func_3_rescale, minimum=self.minimum, maximum=2 - ) self.conditional_priors = bilby.core.prior.ConditionalPriorDict( dict( var_3=self.prior_3, @@ -334,11 +311,28 @@ class TestConditionalPriorDict(unittest.TestCase): var_1=self.prior_1, ) ) - ref_variables = [0.5, 0.5, 0.5, 0.5] + ref_variables = self.test_sample.values() res = self.conditional_priors.rescale( - keys=list(self.test_sample.keys()), theta=ref_variables + keys=self.test_sample.keys(), theta=ref_variables + ) + expected = [self.test_sample["var_0"]] + for ii in range(1, 4): + expected.append(expected[-1] * self.test_sample[f"var_{ii}"]) + self.assertListEqual(expected, res) + + def test_cdf(self): + """ + Test that the CDF method is the inverse of the rescale method. + + Note that the format of inputs/outputs is different between the two methods. + """ + sample = self.conditional_priors.sample() + self.assertEqual( + self.conditional_priors.rescale( + sample.keys(), + self.conditional_priors.cdf(sample=sample).values() + ), list(sample.values()) ) - self.assertListEqual(ref_variables, res) def test_rescale_illegal_conditions(self): del self.conditional_priors["var_0"] diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 546bcd4652cb592755630a4d356cda1517f1203f..3b8487fa1f49ea9c25459c738cc16e58252a5f76 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -298,6 +298,20 @@ class TestPriorDict(unittest.TestCase): ), ) + def test_cdf(self): + """ + Test that the CDF method is the inverse of the rescale method. + + Note that the format of inputs/outputs is different between the two methods. + """ + sample = self.prior_set_from_dict.sample() + self.assertEqual( + self.prior_set_from_dict.rescale( + sample.keys(), + self.prior_set_from_dict.cdf(sample=sample).values() + ), list(sample.values()) + ) + def test_redundancy(self): for key in self.prior_set_from_dict.keys(): self.assertFalse(self.prior_set_from_dict.test_redundancy(key=key))