From a9be56e1c940d735e54631afb797ab146b9bec46 Mon Sep 17 00:00:00 2001 From: Moritz <email@moritz-huebner.de> Date: Thu, 24 Oct 2019 11:57:43 +1100 Subject: [PATCH] Created a generic conditional prior example --- bilby/core/prior.py | 11 ++++++++++- test/prior_test.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 33737c91c..e30f6b02b 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -546,7 +546,16 @@ class ConditionalPriorDict(PriorDict): raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") res = dict() for key in subset_dict.sorted_keys: - res[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key)) + try: + res[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key)) + except ValueError: + # Some prior classes can not handle an array of reference parameters (e.g. alpha for PowerLaw + # If that is the case, we sample each sample individually. + required_variables = subset_dict.get_required_variables(key) + res[key] = np.zeros(size) + for i in range(size): + rvars = {key: value[i] for key, value in required_variables.items()} + res[key][i] = subset_dict[key].sample(**rvars) return res def get_required_variables(self, key): diff --git a/test/prior_test.py b/test/prior_test.py index a637fe07c..13ce19a29 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -147,6 +147,7 @@ class TestPriorBoundary(unittest.TestCase): with self.assertRaises(ValueError): self.prior.boundary = 'else' + class TestPriorClasses(unittest.TestCase): def setUp(self): @@ -1042,6 +1043,17 @@ class TestConditionalPriorDict(unittest.TestCase): with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.sample_subset(keys=['var_1']) + def test_sample_multiple(self): + def condition_func(reference_params, a): + return dict(minimum=reference_params['minimum'], + maximum=reference_params['maximum'], + alpha=reference_params['alpha'] * a) + priors = bilby.core.prior.ConditionalPriorDict() + priors['a'] = bilby.core.prior.Uniform(minimum=0, maximum=1) + priors['b'] = bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, minimum=1, maximum=2, + alpha=-2) + print(priors.sample(2)) + def test_rescale(self): def condition_func_1_rescale(reference_parameters, var_0): -- GitLab