From c451695baa9dd90902a871a9de7032c15e5bfa9d Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 11 Dec 2019 19:07:20 -0600
Subject: [PATCH] Revert "Merge branch
 '439-conditional-priors-not-working-reliably-with-nested-conditions' into
 'master'"

This reverts merge request !674
---
 bilby/core/prior.py | 14 ++++++++------
 test/prior_test.py  | 22 ----------------------
 2 files changed, 8 insertions(+), 28 deletions(-)

diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 64ed03cbb..650c47fd9 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -641,16 +641,22 @@ class ConditionalPriorDict(PriorDict):
         self._check_resolved()
         self._update_rescale_keys(keys)
         result = dict()
-        for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes):
+        for key, index in zip(self._rescale_keys, self._rescale_indexes):
             required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])}
             result[key] = self[key].rescale(theta[index], **required_variables)
         return [result[key] for key in keys]
 
     def _update_rescale_keys(self, keys):
         if not keys == self._least_recently_rescaled_keys:
-            self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters]
+            self._set_rescale_keys_and_indexes(keys)
             self._least_recently_rescaled_keys = keys
 
+    def _set_rescale_keys_and_indexes(self, keys):
+        unconditional_keys, unconditional_idxs, _ = np.intersect1d(keys, self.unconditional_keys, return_indices=True)
+        conditional_keys, conditional_idxs, _ = np.intersect1d(keys, self.conditional_keys, return_indices=True)
+        self._rescale_keys = np.append(unconditional_keys, conditional_keys)
+        self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs)
+
     def _check_resolved(self):
         if not self._resolved:
             raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
@@ -667,10 +673,6 @@ class ConditionalPriorDict(PriorDict):
     def sorted_keys(self):
         return self.unconditional_keys + self.conditional_keys
 
-    @property
-    def sorted_keys_without_fixed_parameters(self):
-        return [key for key in self.sorted_keys if not isinstance(self[key], DeltaFunction)]
-
     def __setitem__(self, key, value):
         super(ConditionalPriorDict, self).__setitem__(key, value)
         self._resolve_conditions()
diff --git a/test/prior_test.py b/test/prior_test.py
index 0a1906697..b6f554d44 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -1182,28 +1182,6 @@ class TestConditionalPriorDict(unittest.TestCase):
         with self.assertRaises(bilby.core.prior.IllegalConditionsException):
             self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values()))
 
-    def test_what_broke(self):
-        def d_condition_func(reference_params, a, b, c):
-            return dict(minimum=reference_params['minimum'], maximum=reference_params['maximum'])
-
-        def a_condition_func(reference_params, b, c):
-            return dict(minimum=reference_params['minimum'], maximum=reference_params['maximum'])
-
-        priors = bilby.core.prior.ConditionalPriorDict()
-
-        priors['a'] = bilby.core.prior.ConditionalUniform(condition_func=a_condition_func,
-                                                          minimum=0, maximum=1)
-
-        priors['b'] = bilby.core.prior.LogUniform(minimum=1, maximum=10)
-
-        priors['d'] = bilby.core.prior.ConditionalUniform(condition_func=d_condition_func,
-                                                          minimum=0.0, maximum=1.0)
-
-        priors['c'] = bilby.core.prior.LogUniform(minimum=1, maximum=10)
-        sample = priors.sample()
-        res = priors.rescale(['a', 'b', 'd', 'c'], [0.5, 0.5, 0.5, 0.5])
-        print(res)
-
 
 class TestJsonIO(unittest.TestCase):
 
-- 
GitLab