From b3de448152ed0d18945bcfe34b9ea59416d33228 Mon Sep 17 00:00:00 2001 From: Moritz Huebner <moritz.huebner@ligo.org> Date: Wed, 11 Dec 2019 23:50:53 -0600 Subject: [PATCH] Merge remote-tracking branch 'origin/master' into 442-scheduled-tests-failing-gw_example_test-py # Conflicts: # bilby/core/prior.py --- bilby/core/prior.py | 14 ++++++-------- test/prior_test.py | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 650c47fd9..9e338a70b 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -641,22 +641,16 @@ class ConditionalPriorDict(PriorDict): self._check_resolved() self._update_rescale_keys(keys) result = dict() - for key, index in zip(self._rescale_keys, self._rescale_indexes): + for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])} result[key] = self[key].rescale(theta[index], **required_variables) return [result[key] for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: - self._set_rescale_keys_and_indexes(keys) + self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters] self._least_recently_rescaled_keys = keys - def _set_rescale_keys_and_indexes(self, keys): - unconditional_keys, unconditional_idxs, _ = np.intersect1d(keys, self.unconditional_keys, return_indices=True) - conditional_keys, conditional_idxs, _ = np.intersect1d(keys, self.conditional_keys, return_indices=True) - self._rescale_keys = np.append(unconditional_keys, conditional_keys) - self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs) - def _check_resolved(self): if not self._resolved: raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") @@ -673,6 +667,10 @@ class ConditionalPriorDict(PriorDict): def sorted_keys(self): return self.unconditional_keys + self.conditional_keys + @property + def sorted_keys_without_fixed_parameters(self): + return [key for key in self.sorted_keys if not isinstance(self[key], (DeltaFunction, Constraint))] + def __setitem__(self, key, value): super(ConditionalPriorDict, self).__setitem__(key, value) self._resolve_conditions() diff --git a/test/prior_test.py b/test/prior_test.py index b6f554d44..77c939b8e 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -1182,6 +1182,28 @@ class TestConditionalPriorDict(unittest.TestCase): with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values())) + def test_combined_conditions(self): + def d_condition_func(reference_params, a, b, c): + return dict(minimum=reference_params['minimum'], maximum=reference_params['maximum']) + + def a_condition_func(reference_params, b, c): + return dict(minimum=reference_params['minimum'], maximum=reference_params['maximum']) + + priors = bilby.core.prior.ConditionalPriorDict() + + priors['a'] = bilby.core.prior.ConditionalUniform(condition_func=a_condition_func, + minimum=0, maximum=1) + + priors['b'] = bilby.core.prior.LogUniform(minimum=1, maximum=10) + + priors['d'] = bilby.core.prior.ConditionalUniform(condition_func=d_condition_func, + minimum=0.0, maximum=1.0) + + priors['c'] = bilby.core.prior.LogUniform(minimum=1, maximum=10) + sample = priors.sample() + res = priors.rescale(['a', 'b', 'd', 'c'], [0.5, 0.5, 0.5, 0.5]) + print(res) + class TestJsonIO(unittest.TestCase): -- GitLab