From f67b4961dcb2307496dc9da156c80704ee2c5f42 Mon Sep 17 00:00:00 2001 From: Moritz <email@moritz-huebner.de> Date: Wed, 16 Oct 2019 14:28:57 +1100 Subject: [PATCH] Some modifications to ConditionalPriorDict after testing --- bilby/core/prior.py | 20 +++--- test/prior_test.py | 148 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 155 insertions(+), 13 deletions(-) diff --git a/bilby/core/prior.py b/bilby/core/prior.py index de41bc431..b94927107 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -501,10 +501,10 @@ class ConditionalPriorDict(PriorDict): self._rescale_indexes = [] self._least_recently_rescaled_keys = [] super(ConditionalPriorDict, self).__init__(dictionary=dictionary, filename=filename) - self.resolved = False + self._resolved = False self._resolve_conditions() - def _resolve_conditions(self, disable_log=False): + def _resolve_conditions(self): """ Resolves how variables depend on each other and automatically sorts them into the right order """ conditioned_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')] self._unconditional_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')] @@ -519,12 +519,10 @@ class ConditionalPriorDict(PriorDict): self._sorted_keys = self._unconditional_keys.copy() self._sorted_keys.extend(self.conditional_keys) - self.resolved = True + self._resolved = True if len(conditioned_keys_unsorted) != 0: - self.resolved = False - if not disable_log: - logger.warning('This set contains unresolvable conditions') + self._resolved = False def _check_conditions_resolved(self, key, sampled_keys): """ Checks if all required variables have already been sampled so we can sample this key """ @@ -537,7 +535,7 @@ class ConditionalPriorDict(PriorDict): def sample_subset(self, keys=iter([]), size=None): self.convert_floats_to_delta_functions() subset_dict = ConditionalPriorDict({key: self[key] for key in keys}) - if not subset_dict.resolved: + if not subset_dict._resolved: raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") res = dict() for key in subset_dict.sorted_keys: @@ -634,7 +632,7 @@ class ConditionalPriorDict(PriorDict): self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs) def _check_resolved(self): - if not self.resolved: + if not self._resolved: raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") @property @@ -651,7 +649,11 @@ class ConditionalPriorDict(PriorDict): def __setitem__(self, key, value): super(ConditionalPriorDict, self).__setitem__(key, value) - self._resolve_conditions(disable_log=True) + self._resolve_conditions() + + def __delitem__(self, key): + super(ConditionalPriorDict, self).__delitem__(key) + self._resolve_conditions() def create_default_prior(name, default_priors_file=None): diff --git a/test/prior_test.py b/test/prior_test.py index 9ec91dd94..c66fc3943 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -294,6 +294,11 @@ class TestPriorClasses(unittest.TestCase): outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000) self.assertTrue(all(prior.prob(outside_domain) == 0)) + def test_least_recently_sampled(self): + for prior in self.priors: + lrs = prior.sample() + self.assertEqual(lrs, prior.least_recently_sampled) + def test_prob_and_ln_prob(self): for prior in self.priors: sample = prior.sample() @@ -852,7 +857,7 @@ class TestConditionalPrior(unittest.TestCase): def setUp(self): self.condition_func_call_counter = 0 - def condition_func(reference_parameters, test_parameter_1, test_parameter_2): + def condition_func(reference_parameters, test_variable_1, test_variable_2): self.condition_func_call_counter += 1 return {key: value + 1 for key, value in reference_parameters.items()} self.condition_func = condition_func @@ -860,14 +865,17 @@ class TestConditionalPrior(unittest.TestCase): self.maximum = 5 self.test_parameter_1 = 0 self.test_parameter_2 = 1 - self.prior = bilby.core.prior.ConditionalPrior(condition_func=condition_func, - minimum=self.minimum, - maximum=self.maximum) + self.prior = bilby.core.prior.ConditionalBasePrior(condition_func=condition_func, + minimum=self.minimum, + maximum=self.maximum) def tearDown(self): del self.condition_func + del self.condition_func_call_counter del self.minimum del self.maximum + del self.test_parameter_1 + del self.test_parameter_2 del self.prior def test_reference_params(self): @@ -876,6 +884,11 @@ class TestConditionalPrior(unittest.TestCase): def test_required_variables(self): self.assertListEqual(['test_parameter_1', 'test_parameter_2'], sorted(self.prior.required_variables)) + def test_required_variables_no_condition_func(self): + self.prior = bilby.core.prior.ConditionalBasePrior(minimum=self.minimum, + maximum=self.maximum) + self.assertListEqual([], self.prior.required_variables) + def test_get_instantiation_dict(self): expected = dict(minimum=0, maximum=5, name=None, latex_label=None, unit=None, boundary=None, condition_func=self.condition_func) @@ -937,6 +950,133 @@ class TestConditionalPrior(unittest.TestCase): self.assertEqual(self.prior.reference_params['minimum'], self.prior.minimum) self.assertEqual(self.prior.reference_params['maximum'], self.prior.maximum) + def test_cond_prior_instantiation_no_boundary_prior(self): + prior = bilby.core.prior.ConditionalFermiDirac(sigma=1) + self.assertIsNone(prior.boundary) + + +class TestConditionalPriorDict(unittest.TestCase): + + def setUp(self): + def condition_func_1(reference_parameters, var_0): + return reference_parameters + + def condition_func_2(reference_parameters, var_0, var_1): + return reference_parameters + + def condition_func_3(reference_parameters, var_1, var_2): + return reference_parameters + + self.minimum = 0 + self.maximum = 1 + self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=self.maximum) + self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1, + minimum=self.minimum, maximum=self.maximum) + self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2, + minimum=self.minimum, maximum=self.maximum) + self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3, + minimum=self.minimum, maximum=self.maximum) + self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2, + var_0=self.prior_0, var_1=self.prior_1)) + self.conditional_priors_manually_set_items = bilby.core.prior.ConditionalPriorDict() + self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4) + for key, value in dict(var_0=self.prior_0, var_1=self.prior_1, var_2=self.prior_2, var_3=self.prior_3).items(): + self.conditional_priors_manually_set_items[key] = value + + def tearDown(self): + del self.minimum + del self.maximum + del self.prior_0 + del self.prior_1 + del self.prior_2 + del self.prior_3 + del self.conditional_priors + del self.conditional_priors_manually_set_items + del self.test_sample + + def test_conditions_resolved_upon_instantiation(self): + self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'], self.conditional_priors.sorted_keys) + + def test_conditions_resolved_setting_items(self): + self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'], + self.conditional_priors_manually_set_items.sorted_keys) + + def test_unconditional_keys_upon_instantiation(self): + self.assertListEqual(['var_0'], self.conditional_priors.unconditional_keys) + + def test_unconditional_keys_setting_items(self): + self.assertListEqual(['var_0'], self.conditional_priors_manually_set_items.unconditional_keys) + + def test_conditional_keys_upon_instantiation(self): + self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors.conditional_keys) + + def test_conditional_keys_setting_items(self): + self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors_manually_set_items.conditional_keys) + + def test_prob(self): + self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample)) + + def test_prob_illegal_conditions(self): + del self.conditional_priors['var_0'] + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors.prob(sample=self.test_sample) + + def test_ln_prob(self): + self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample)) + + def test_ln_prob_illegal_conditions(self): + del self.conditional_priors['var_0'] + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors.ln_prob(sample=self.test_sample) + + def test_sample_subset_all_keys(self): + with mock.patch("numpy.random.uniform") as m: + m.return_value = 0.5 + self.assertDictEqual(dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5), + self.conditional_priors.sample_subset(keys=['var_0', 'var_1', 'var_2', 'var_3'])) + + def test_sample_illegal_subset(self): + with mock.patch("numpy.random.uniform") as m: + m.return_value = 0.5 + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors.sample_subset(keys=['var_1']) + + def test_rescale(self): + + def condition_func_1_rescale(reference_parameters, var_0): + if var_0 == 0.5: + return dict(minimum=reference_parameters['minimum'], maximum=1) + return reference_parameters + + def condition_func_2_rescale(reference_parameters, var_0, var_1): + if var_0 == 0.5 and var_1 == 0.5: + return dict(minimum=reference_parameters['minimum'], maximum=1) + return reference_parameters + + def condition_func_3_rescale(reference_parameters, var_1, var_2): + if var_1 == 0.5 and var_2 == 0.5: + return dict(minimum=reference_parameters['minimum'], maximum=1) + return reference_parameters + + self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1) + self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1_rescale, + minimum=self.minimum, maximum=2) + self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2_rescale, + minimum=self.minimum, maximum=2) + self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3_rescale, + minimum=self.minimum, maximum=2) + self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2, + var_0=self.prior_0, var_1=self.prior_1)) + ref_variables = [0.5, 0.5, 0.5, 0.5] + res = self.conditional_priors.rescale(keys=list(self.test_sample.keys()), + theta=ref_variables) + self.assertListEqual(ref_variables, res) + + def test_rescale_illegal_conditions(self): + del self.conditional_priors['var_0'] + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values())) + class TestJsonIO(unittest.TestCase): -- GitLab