From 6c6201d705e03fbede2cd845757e0489893a49b8 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Fri, 10 Nov 2023 15:38:41 +0000
Subject: [PATCH] ConditionalPriorDict: fix subset sampling for external
 'DeltaFunction' dependencies

---
 bilby/core/prior/analytical.py      |  1 +
 bilby/core/prior/dict.py            | 12 ++++++++---
 test/core/prior/conditional_test.py | 32 +++++++++++++++++++++++++++++
 3 files changed, 42 insertions(+), 3 deletions(-)

diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py
index c7018657f..5e7b3099f 100644
--- a/bilby/core/prior/analytical.py
+++ b/bilby/core/prior/analytical.py
@@ -28,6 +28,7 @@ class DeltaFunction(Prior):
                                             minimum=peak, maximum=peak, check_range_nonzero=False)
         self.peak = peak
         self._is_fixed = True
+        self.least_recently_sampled = peak
 
     def rescale(self, val):
         """Rescale everything to the peak with the correct shape.
diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index d25ca6487..d888e4b8f 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -711,16 +711,22 @@ class ConditionalPriorDict(PriorDict):
 
     def sample_subset(self, keys=iter([]), size=None):
         self.convert_floats_to_delta_functions()
-        subset_dict = ConditionalPriorDict({key: self[key] for key in keys})
+        add_delta_keys = [
+            key
+            for key in self.keys()
+            if key not in keys and isinstance(self[key], DeltaFunction)
+        ]
+        use_keys = add_delta_keys + list(keys)
+        subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys})
         if not subset_dict._resolved:
             raise IllegalConditionsException(
                 "The current set of priors contains unresolvable conditions."
             )
         samples = dict()
         for key in subset_dict.sorted_keys:
-            if isinstance(self[key], Constraint):
+            if key not in keys or isinstance(self[key], Constraint):
                 continue
-            elif isinstance(self[key], Prior):
+            if isinstance(self[key], Prior):
                 try:
                     samples[key] = subset_dict[key].sample(
                         size=size, **subset_dict.get_required_variables(key)
diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py
index 5ee4efd60..fbfa45cc0 100644
--- a/test/core/prior/conditional_test.py
+++ b/test/core/prior/conditional_test.py
@@ -412,6 +412,38 @@ class TestConditionalPriorDict(unittest.TestCase):
         res = priors.rescale(["a", "b", "d", "c"], [0.5, 0.5, 0.5, 0.5])
         print(res)
 
+    def test_subset_sampling(self):
+        def _tp_conditional_uniform(ref_params, period):
+            min_ref, max_ref = ref_params["minimum"], ref_params["maximum"]
+            max_ref = np.minimum(max_ref, min_ref + period)
+            return {"minimum": min_ref, "maximum": max_ref}
+
+        p0 = 68400.0
+        prior = bilby.core.prior.ConditionalPriorDict(
+            {
+                "tp": bilby.core.prior.ConditionalUniform(
+                    condition_func=_tp_conditional_uniform, minimum=0, maximum=2 * p0
+                )
+            }
+        )
+
+        # ---------- 0. Sanity check: sample full prior
+        prior["period"] = p0
+        samples2d = prior.sample(1000)
+        assert samples2d["tp"].max() < p0
+
+        # ---------- 1. Subset sampling with external delta-prior
+        print("Test 1: Subset-sampling conditionals for fixed 'externals':")
+        prior["period"] = p0
+        samples1d = prior.sample_subset(["tp"], 1000)
+        self.assertLess(samples1d["tp"].max(), p0)
+
+        # ---------- 2. Subset sampling with external uniform prior
+        prior["period"] = bilby.core.prior.Uniform(minimum=p0, maximum=2 * p0)
+        print("Test 2: Subset-sampling conditionals for 'external' uncertainties:")
+        with self.assertRaises(bilby.core.prior.IllegalConditionsException):
+            prior.sample_subset(["tp"], 1000)
+
 
 class TestDirichletPrior(unittest.TestCase):
 
-- 
GitLab