From c451695baa9dd90902a871a9de7032c15e5bfa9d Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 11 Dec 2019 19:07:20 -0600 Subject: [PATCH] Revert "Merge branch '439-conditional-priors-not-working-reliably-with-nested-conditions' into 'master'" This reverts merge request !674 --- bilby/core/prior.py | 14 ++++++++------ test/prior_test.py | 22 ---------------------- 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 64ed03cbb..650c47fd9 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -641,16 +641,22 @@ class ConditionalPriorDict(PriorDict): self._check_resolved() self._update_rescale_keys(keys) result = dict() - for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): + for key, index in zip(self._rescale_keys, self._rescale_indexes): required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])} result[key] = self[key].rescale(theta[index], **required_variables) return [result[key] for key in keys] def _update_rescale_keys(self, keys): if not keys == self._least_recently_rescaled_keys: - self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters] + self._set_rescale_keys_and_indexes(keys) self._least_recently_rescaled_keys = keys + def _set_rescale_keys_and_indexes(self, keys): + unconditional_keys, unconditional_idxs, _ = np.intersect1d(keys, self.unconditional_keys, return_indices=True) + conditional_keys, conditional_idxs, _ = np.intersect1d(keys, self.conditional_keys, return_indices=True) + self._rescale_keys = np.append(unconditional_keys, conditional_keys) + self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs) + def _check_resolved(self): if not self._resolved: raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") @@ -667,10 +673,6 @@ class ConditionalPriorDict(PriorDict): def sorted_keys(self): return self.unconditional_keys + self.conditional_keys - @property - def sorted_keys_without_fixed_parameters(self): - return [key for key in self.sorted_keys if not isinstance(self[key], DeltaFunction)] - def __setitem__(self, key, value): super(ConditionalPriorDict, self).__setitem__(key, value) self._resolve_conditions() diff --git a/test/prior_test.py b/test/prior_test.py index 0a1906697..b6f554d44 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -1182,28 +1182,6 @@ class TestConditionalPriorDict(unittest.TestCase): with self.assertRaises(bilby.core.prior.IllegalConditionsException): self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values())) - def test_what_broke(self): - def d_condition_func(reference_params, a, b, c): - return dict(minimum=reference_params['minimum'], maximum=reference_params['maximum']) - - def a_condition_func(reference_params, b, c): - return dict(minimum=reference_params['minimum'], maximum=reference_params['maximum']) - - priors = bilby.core.prior.ConditionalPriorDict() - - priors['a'] = bilby.core.prior.ConditionalUniform(condition_func=a_condition_func, - minimum=0, maximum=1) - - priors['b'] = bilby.core.prior.LogUniform(minimum=1, maximum=10) - - priors['d'] = bilby.core.prior.ConditionalUniform(condition_func=d_condition_func, - minimum=0.0, maximum=1.0) - - priors['c'] = bilby.core.prior.LogUniform(minimum=1, maximum=10) - sample = priors.sample() - res = priors.rescale(['a', 'b', 'd', 'c'], [0.5, 0.5, 0.5, 0.5]) - print(res) - class TestJsonIO(unittest.TestCase): -- GitLab