From a9be56e1c940d735e54631afb797ab146b9bec46 Mon Sep 17 00:00:00 2001
From: Moritz <email@moritz-huebner.de>
Date: Thu, 24 Oct 2019 11:57:43 +1100
Subject: [PATCH] Created a generic conditional prior example

---
 bilby/core/prior.py | 11 ++++++++++-
 test/prior_test.py  | 12 ++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 33737c91c..e30f6b02b 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -546,7 +546,16 @@ class ConditionalPriorDict(PriorDict):
             raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
         res = dict()
         for key in subset_dict.sorted_keys:
-            res[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key))
+            try:
+                res[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key))
+            except ValueError:
+                # Some prior classes can not handle an array of reference parameters (e.g. alpha for PowerLaw
+                # If that is the case, we sample each sample individually.
+                required_variables = subset_dict.get_required_variables(key)
+                res[key] = np.zeros(size)
+                for i in range(size):
+                    rvars = {key: value[i] for key, value in required_variables.items()}
+                    res[key][i] = subset_dict[key].sample(**rvars)
         return res
 
     def get_required_variables(self, key):
diff --git a/test/prior_test.py b/test/prior_test.py
index a637fe07c..13ce19a29 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -147,6 +147,7 @@ class TestPriorBoundary(unittest.TestCase):
         with self.assertRaises(ValueError):
             self.prior.boundary = 'else'
 
+
 class TestPriorClasses(unittest.TestCase):
 
     def setUp(self):
@@ -1042,6 +1043,17 @@ class TestConditionalPriorDict(unittest.TestCase):
             with self.assertRaises(bilby.core.prior.IllegalConditionsException):
                 self.conditional_priors.sample_subset(keys=['var_1'])
 
+    def test_sample_multiple(self):
+        def condition_func(reference_params, a):
+            return dict(minimum=reference_params['minimum'],
+                        maximum=reference_params['maximum'],
+                        alpha=reference_params['alpha'] * a)
+        priors = bilby.core.prior.ConditionalPriorDict()
+        priors['a'] = bilby.core.prior.Uniform(minimum=0, maximum=1)
+        priors['b'] = bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, minimum=1, maximum=2,
+                                                           alpha=-2)
+        print(priors.sample(2))
+
     def test_rescale(self):
 
         def condition_func_1_rescale(reference_parameters, var_0):
-- 
GitLab