From f67b4961dcb2307496dc9da156c80704ee2c5f42 Mon Sep 17 00:00:00 2001
From: Moritz <email@moritz-huebner.de>
Date: Wed, 16 Oct 2019 14:28:57 +1100
Subject: [PATCH] Some modifications to ConditionalPriorDict after testing

---
 bilby/core/prior.py |  20 +++---
 test/prior_test.py  | 148 ++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 155 insertions(+), 13 deletions(-)

diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index de41bc431..b94927107 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -501,10 +501,10 @@ class ConditionalPriorDict(PriorDict):
         self._rescale_indexes = []
         self._least_recently_rescaled_keys = []
         super(ConditionalPriorDict, self).__init__(dictionary=dictionary, filename=filename)
-        self.resolved = False
+        self._resolved = False
         self._resolve_conditions()
 
-    def _resolve_conditions(self, disable_log=False):
+    def _resolve_conditions(self):
         """ Resolves how variables depend on each other and automatically sorts them into the right order """
         conditioned_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')]
         self._unconditional_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')]
@@ -519,12 +519,10 @@ class ConditionalPriorDict(PriorDict):
 
         self._sorted_keys = self._unconditional_keys.copy()
         self._sorted_keys.extend(self.conditional_keys)
-        self.resolved = True
+        self._resolved = True
 
         if len(conditioned_keys_unsorted) != 0:
-            self.resolved = False
-            if not disable_log:
-                logger.warning('This set contains unresolvable conditions')
+            self._resolved = False
 
     def _check_conditions_resolved(self, key, sampled_keys):
         """ Checks if all required variables have already been sampled so we can sample this key """
@@ -537,7 +535,7 @@ class ConditionalPriorDict(PriorDict):
     def sample_subset(self, keys=iter([]), size=None):
         self.convert_floats_to_delta_functions()
         subset_dict = ConditionalPriorDict({key: self[key] for key in keys})
-        if not subset_dict.resolved:
+        if not subset_dict._resolved:
             raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
         res = dict()
         for key in subset_dict.sorted_keys:
@@ -634,7 +632,7 @@ class ConditionalPriorDict(PriorDict):
         self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs)
 
     def _check_resolved(self):
-        if not self.resolved:
+        if not self._resolved:
             raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
 
     @property
@@ -651,7 +649,11 @@ class ConditionalPriorDict(PriorDict):
 
     def __setitem__(self, key, value):
         super(ConditionalPriorDict, self).__setitem__(key, value)
-        self._resolve_conditions(disable_log=True)
+        self._resolve_conditions()
+
+    def __delitem__(self, key):
+        super(ConditionalPriorDict, self).__delitem__(key)
+        self._resolve_conditions()
 
 
 def create_default_prior(name, default_priors_file=None):
diff --git a/test/prior_test.py b/test/prior_test.py
index 9ec91dd94..c66fc3943 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -294,6 +294,11 @@ class TestPriorClasses(unittest.TestCase):
                 outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000)
                 self.assertTrue(all(prior.prob(outside_domain) == 0))
 
+    def test_least_recently_sampled(self):
+        for prior in self.priors:
+            lrs = prior.sample()
+            self.assertEqual(lrs, prior.least_recently_sampled)
+
     def test_prob_and_ln_prob(self):
         for prior in self.priors:
             sample = prior.sample()
@@ -852,7 +857,7 @@ class TestConditionalPrior(unittest.TestCase):
     def setUp(self):
         self.condition_func_call_counter = 0
 
-        def condition_func(reference_parameters, test_parameter_1, test_parameter_2):
+        def condition_func(reference_parameters, test_variable_1, test_variable_2):
             self.condition_func_call_counter += 1
             return {key: value + 1 for key, value in reference_parameters.items()}
         self.condition_func = condition_func
@@ -860,14 +865,17 @@ class TestConditionalPrior(unittest.TestCase):
         self.maximum = 5
         self.test_parameter_1 = 0
         self.test_parameter_2 = 1
-        self.prior = bilby.core.prior.ConditionalPrior(condition_func=condition_func,
-                                                       minimum=self.minimum,
-                                                       maximum=self.maximum)
+        self.prior = bilby.core.prior.ConditionalBasePrior(condition_func=condition_func,
+                                                           minimum=self.minimum,
+                                                           maximum=self.maximum)
 
     def tearDown(self):
         del self.condition_func
+        del self.condition_func_call_counter
         del self.minimum
         del self.maximum
+        del self.test_parameter_1
+        del self.test_parameter_2
         del self.prior
 
     def test_reference_params(self):
@@ -876,6 +884,11 @@ class TestConditionalPrior(unittest.TestCase):
     def test_required_variables(self):
         self.assertListEqual(['test_parameter_1', 'test_parameter_2'], sorted(self.prior.required_variables))
 
+    def test_required_variables_no_condition_func(self):
+        self.prior = bilby.core.prior.ConditionalBasePrior(minimum=self.minimum,
+                                                           maximum=self.maximum)
+        self.assertListEqual([], self.prior.required_variables)
+
     def test_get_instantiation_dict(self):
         expected = dict(minimum=0, maximum=5, name=None, latex_label=None, unit=None,
                         boundary=None, condition_func=self.condition_func)
@@ -937,6 +950,133 @@ class TestConditionalPrior(unittest.TestCase):
         self.assertEqual(self.prior.reference_params['minimum'], self.prior.minimum)
         self.assertEqual(self.prior.reference_params['maximum'], self.prior.maximum)
 
+    def test_cond_prior_instantiation_no_boundary_prior(self):
+        prior = bilby.core.prior.ConditionalFermiDirac(sigma=1)
+        self.assertIsNone(prior.boundary)
+
+
+class TestConditionalPriorDict(unittest.TestCase):
+
+    def setUp(self):
+        def condition_func_1(reference_parameters, var_0):
+            return reference_parameters
+
+        def condition_func_2(reference_parameters, var_0, var_1):
+            return reference_parameters
+
+        def condition_func_3(reference_parameters, var_1, var_2):
+            return reference_parameters
+
+        self.minimum = 0
+        self.maximum = 1
+        self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=self.maximum)
+        self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1,
+                                                           minimum=self.minimum, maximum=self.maximum)
+        self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2,
+                                                           minimum=self.minimum, maximum=self.maximum)
+        self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3,
+                                                           minimum=self.minimum, maximum=self.maximum)
+        self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2,
+                                                                             var_0=self.prior_0, var_1=self.prior_1))
+        self.conditional_priors_manually_set_items = bilby.core.prior.ConditionalPriorDict()
+        self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4)
+        for key, value in dict(var_0=self.prior_0, var_1=self.prior_1, var_2=self.prior_2, var_3=self.prior_3).items():
+            self.conditional_priors_manually_set_items[key] = value
+
+    def tearDown(self):
+        del self.minimum
+        del self.maximum
+        del self.prior_0
+        del self.prior_1
+        del self.prior_2
+        del self.prior_3
+        del self.conditional_priors
+        del self.conditional_priors_manually_set_items
+        del self.test_sample
+
+    def test_conditions_resolved_upon_instantiation(self):
+        self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'], self.conditional_priors.sorted_keys)
+
+    def test_conditions_resolved_setting_items(self):
+        self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'],
+                             self.conditional_priors_manually_set_items.sorted_keys)
+
+    def test_unconditional_keys_upon_instantiation(self):
+        self.assertListEqual(['var_0'], self.conditional_priors.unconditional_keys)
+
+    def test_unconditional_keys_setting_items(self):
+        self.assertListEqual(['var_0'], self.conditional_priors_manually_set_items.unconditional_keys)
+
+    def test_conditional_keys_upon_instantiation(self):
+        self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors.conditional_keys)
+
+    def test_conditional_keys_setting_items(self):
+        self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors_manually_set_items.conditional_keys)
+
+    def test_prob(self):
+        self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample))
+
+    def test_prob_illegal_conditions(self):
+        del self.conditional_priors['var_0']
+        with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+            self.conditional_priors.prob(sample=self.test_sample)
+
+    def test_ln_prob(self):
+        self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample))
+
+    def test_ln_prob_illegal_conditions(self):
+        del self.conditional_priors['var_0']
+        with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+            self.conditional_priors.ln_prob(sample=self.test_sample)
+
+    def test_sample_subset_all_keys(self):
+        with mock.patch("numpy.random.uniform") as m:
+            m.return_value = 0.5
+            self.assertDictEqual(dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5),
+                                 self.conditional_priors.sample_subset(keys=['var_0', 'var_1', 'var_2', 'var_3']))
+
+    def test_sample_illegal_subset(self):
+        with mock.patch("numpy.random.uniform") as m:
+            m.return_value = 0.5
+            with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+                self.conditional_priors.sample_subset(keys=['var_1'])
+
+    def test_rescale(self):
+
+        def condition_func_1_rescale(reference_parameters, var_0):
+            if var_0 == 0.5:
+                return dict(minimum=reference_parameters['minimum'], maximum=1)
+            return reference_parameters
+
+        def condition_func_2_rescale(reference_parameters, var_0, var_1):
+            if var_0 == 0.5 and var_1 == 0.5:
+                return dict(minimum=reference_parameters['minimum'], maximum=1)
+            return reference_parameters
+
+        def condition_func_3_rescale(reference_parameters, var_1, var_2):
+            if var_1 == 0.5 and var_2 == 0.5:
+                return dict(minimum=reference_parameters['minimum'], maximum=1)
+            return reference_parameters
+
+        self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1)
+        self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1_rescale,
+                                                           minimum=self.minimum, maximum=2)
+        self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2_rescale,
+                                                           minimum=self.minimum, maximum=2)
+        self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3_rescale,
+                                                           minimum=self.minimum, maximum=2)
+        self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2,
+                                                                             var_0=self.prior_0, var_1=self.prior_1))
+        ref_variables = [0.5, 0.5, 0.5, 0.5]
+        res = self.conditional_priors.rescale(keys=list(self.test_sample.keys()),
+                                              theta=ref_variables)
+        self.assertListEqual(ref_variables, res)
+
+    def test_rescale_illegal_conditions(self):
+        del self.conditional_priors['var_0']
+        with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+            self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values()))
+
 
 class TestJsonIO(unittest.TestCase):
 
-- 
GitLab