From b3de448152ed0d18945bcfe34b9ea59416d33228 Mon Sep 17 00:00:00 2001
From: Moritz Huebner <moritz.huebner@ligo.org>
Date: Wed, 11 Dec 2019 23:50:53 -0600
Subject: [PATCH] Merge remote-tracking branch 'origin/master' into
 442-scheduled-tests-failing-gw_example_test-py

# Conflicts:
#	bilby/core/prior.py
---
 bilby/core/prior.py | 14 ++++++--------
 test/prior_test.py  | 22 ++++++++++++++++++++++
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 650c47fd9..9e338a70b 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -641,22 +641,16 @@ class ConditionalPriorDict(PriorDict):
         self._check_resolved()
         self._update_rescale_keys(keys)
         result = dict()
-        for key, index in zip(self._rescale_keys, self._rescale_indexes):
+        for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes):
             required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])}
             result[key] = self[key].rescale(theta[index], **required_variables)
         return [result[key] for key in keys]
 
     def _update_rescale_keys(self, keys):
         if not keys == self._least_recently_rescaled_keys:
-            self._set_rescale_keys_and_indexes(keys)
+            self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters]
             self._least_recently_rescaled_keys = keys
 
-    def _set_rescale_keys_and_indexes(self, keys):
-        unconditional_keys, unconditional_idxs, _ = np.intersect1d(keys, self.unconditional_keys, return_indices=True)
-        conditional_keys, conditional_idxs, _ = np.intersect1d(keys, self.conditional_keys, return_indices=True)
-        self._rescale_keys = np.append(unconditional_keys, conditional_keys)
-        self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs)
-
     def _check_resolved(self):
         if not self._resolved:
             raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
@@ -673,6 +667,10 @@ class ConditionalPriorDict(PriorDict):
     def sorted_keys(self):
         return self.unconditional_keys + self.conditional_keys
 
+    @property
+    def sorted_keys_without_fixed_parameters(self):
+        return [key for key in self.sorted_keys if not isinstance(self[key], (DeltaFunction, Constraint))]
+
     def __setitem__(self, key, value):
         super(ConditionalPriorDict, self).__setitem__(key, value)
         self._resolve_conditions()
diff --git a/test/prior_test.py b/test/prior_test.py
index b6f554d44..77c939b8e 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -1182,6 +1182,28 @@ class TestConditionalPriorDict(unittest.TestCase):
         with self.assertRaises(bilby.core.prior.IllegalConditionsException):
             self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values()))
 
+    def test_combined_conditions(self):
+        def d_condition_func(reference_params, a, b, c):
+            return dict(minimum=reference_params['minimum'], maximum=reference_params['maximum'])
+
+        def a_condition_func(reference_params, b, c):
+            return dict(minimum=reference_params['minimum'], maximum=reference_params['maximum'])
+
+        priors = bilby.core.prior.ConditionalPriorDict()
+
+        priors['a'] = bilby.core.prior.ConditionalUniform(condition_func=a_condition_func,
+                                                          minimum=0, maximum=1)
+
+        priors['b'] = bilby.core.prior.LogUniform(minimum=1, maximum=10)
+
+        priors['d'] = bilby.core.prior.ConditionalUniform(condition_func=d_condition_func,
+                                                          minimum=0.0, maximum=1.0)
+
+        priors['c'] = bilby.core.prior.LogUniform(minimum=1, maximum=10)
+        sample = priors.sample()
+        res = priors.rescale(['a', 'b', 'd', 'c'], [0.5, 0.5, 0.5, 0.5])
+        print(res)
+
 
 class TestJsonIO(unittest.TestCase):
 
-- 
GitLab