From 6c6201d705e03fbede2cd845757e0489893a49b8 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Fri, 10 Nov 2023 15:38:41 +0000 Subject: [PATCH] ConditionalPriorDict: fix subset sampling for external 'DeltaFunction' dependencies --- bilby/core/prior/analytical.py | 1 + bilby/core/prior/dict.py | 12 ++++++++--- test/core/prior/conditional_test.py | 32 +++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index c7018657f..5e7b3099f 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -28,6 +28,7 @@ class DeltaFunction(Prior): minimum=peak, maximum=peak, check_range_nonzero=False) self.peak = peak self._is_fixed = True + self.least_recently_sampled = peak def rescale(self, val): """Rescale everything to the peak with the correct shape. diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index d25ca6487..d888e4b8f 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -711,16 +711,22 @@ class ConditionalPriorDict(PriorDict): def sample_subset(self, keys=iter([]), size=None): self.convert_floats_to_delta_functions() - subset_dict = ConditionalPriorDict({key: self[key] for key in keys}) + add_delta_keys = [ + key + for key in self.keys() + if key not in keys and isinstance(self[key], DeltaFunction) + ] + use_keys = add_delta_keys + list(keys) + subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys}) if not subset_dict._resolved: raise IllegalConditionsException( "The current set of priors contains unresolvable conditions." ) samples = dict() for key in subset_dict.sorted_keys: - if isinstance(self[key], Constraint): + if key not in keys or isinstance(self[key], Constraint): continue - elif isinstance(self[key], Prior): + if isinstance(self[key], Prior): try: samples[key] = subset_dict[key].sample( size=size, **subset_dict.get_required_variables(key) diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 5ee4efd60..fbfa45cc0 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -412,6 +412,38 @@ class TestConditionalPriorDict(unittest.TestCase): res = priors.rescale(["a", "b", "d", "c"], [0.5, 0.5, 0.5, 0.5]) print(res) + def test_subset_sampling(self): + def _tp_conditional_uniform(ref_params, period): + min_ref, max_ref = ref_params["minimum"], ref_params["maximum"] + max_ref = np.minimum(max_ref, min_ref + period) + return {"minimum": min_ref, "maximum": max_ref} + + p0 = 68400.0 + prior = bilby.core.prior.ConditionalPriorDict( + { + "tp": bilby.core.prior.ConditionalUniform( + condition_func=_tp_conditional_uniform, minimum=0, maximum=2 * p0 + ) + } + ) + + # ---------- 0. Sanity check: sample full prior + prior["period"] = p0 + samples2d = prior.sample(1000) + assert samples2d["tp"].max() < p0 + + # ---------- 1. Subset sampling with external delta-prior + print("Test 1: Subset-sampling conditionals for fixed 'externals':") + prior["period"] = p0 + samples1d = prior.sample_subset(["tp"], 1000) + self.assertLess(samples1d["tp"].max(), p0) + + # ---------- 2. Subset sampling with external uniform prior + prior["period"] = bilby.core.prior.Uniform(minimum=p0, maximum=2 * p0) + print("Test 2: Subset-sampling conditionals for 'external' uncertainties:") + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + prior.sample_subset(["tp"], 1000) + class TestDirichletPrior(unittest.TestCase): -- GitLab