From 99ba5346e7705eb3de1e81b7e64c7733dd7fff40 Mon Sep 17 00:00:00 2001
From: Michael Williams <michael.williams@ligo.org>
Date: Tue, 16 Aug 2022 13:46:22 +0000
Subject: [PATCH] Improve how sampling seed is handled

---
 bilby/bilby_mcmc/sampler.py            |  1 +
 bilby/core/sampler/base_sampler.py     | 18 ++++++++++++++++--
 bilby/core/sampler/cpnest.py           |  2 ++
 bilby/core/sampler/dnest4.py           |  2 ++
 bilby/core/sampler/dynesty.py          |  1 +
 bilby/core/sampler/emcee.py            |  1 +
 bilby/core/sampler/nessai.py           |  7 ++-----
 bilby/core/sampler/nestle.py           |  1 +
 bilby/core/sampler/polychord.py        |  2 ++
 bilby/core/sampler/ptemcee.py          |  1 +
 bilby/core/sampler/ptmcmc.py           |  1 +
 bilby/core/sampler/pymc3.py            |  2 ++
 bilby/core/sampler/pymultinest.py      |  2 ++
 bilby/core/sampler/ultranest.py        |  1 +
 test/core/sampler/base_sampler_test.py |  9 +++++++++
 test/core/sampler/nessai_test.py       | 11 +++--------
 16 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
index 555e32149..3decaf74a 100644
--- a/bilby/bilby_mcmc/sampler.py
+++ b/bilby/bilby_mcmc/sampler.py
@@ -202,6 +202,7 @@ class Bilby_MCMC(MCMCSampler):
             logger.warning("Burn-in inefficiency fraction greater than 10%")
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "printdt" not in kwargs:
             for equiv in ["print_dt", "print_update"]:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index c30f76045..47eee4ab6 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -196,7 +196,13 @@ class Sampler(object):
         "cores",
         "n_pool",
     ]
+    sampling_seed_equiv_kwargs = ["sampling_seed", "seed", "random_seed"]
     hard_exit = False
+    sampling_seed_key = None
+    """Name of keyword argument for setting the sampling for the specific sampler.
+    If a specific sampler does not have a sampling seed option, then it should be
+    left as None.
+    """
 
     def __init__(
         self,
@@ -289,8 +295,16 @@ class Sampler(object):
         self._verify_kwargs_against_default_kwargs()
 
     def _translate_kwargs(self, kwargs):
-        """Template for child classes"""
-        pass
+        """Translate keyword arguments.
+
+        Default only translates the sampling seed if the sampler has
+        :code:`sampling_seed_key` set.
+        """
+        if self.sampling_seed_key and self.sampling_seed_key not in kwargs:
+            for equiv in self.sampling_seed_equiv_kwargs:
+                if equiv in kwargs:
+                    kwargs[self.sampling_seed_key] = kwargs.pop(equiv)
+        return kwargs
 
     @property
     def external_sampler_name(self):
diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py
index b12464375..bc3b36465 100644
--- a/bilby/core/sampler/cpnest.py
+++ b/bilby/core/sampler/cpnest.py
@@ -54,8 +54,10 @@ class Cpnest(NestedSampler):
         n_periodic_checkpoint=8000,
     )
     hard_exit = True
+    sampling_seed_key = "seed"
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/dnest4.py b/bilby/core/sampler/dnest4.py
index 7d5b97092..5c3d7566e 100644
--- a/bilby/core/sampler/dnest4.py
+++ b/bilby/core/sampler/dnest4.py
@@ -114,6 +114,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
     )
     short_name = "dn4"
     hard_exit = True
+    sampling_seed_key = "seed"
 
     def __init__(
         self,
@@ -254,6 +255,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
         return self.result
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "num_steps" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index ab0af61be..82da609e7 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -239,6 +239,7 @@ class Dynesty(NestedSampler):
         }
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 5afa169b8..18a36fd13 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -113,6 +113,7 @@ class Emcee(MCMCSampler):
         return emcee
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "nwalkers" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py
index fdee87b05..d0d050370 100644
--- a/bilby/core/sampler/nessai.py
+++ b/bilby/core/sampler/nessai.py
@@ -19,7 +19,7 @@ class Nessai(NestedSampler):
     """
 
     _default_kwargs = None
-    seed_equiv_kwargs = ["sampling_seed"]
+    sampling_seed_key = "seed"
 
     @property
     def default_kwargs(self):
@@ -165,6 +165,7 @@ class Nessai(NestedSampler):
         return self.result
 
     def _translate_kwargs(self, kwargs):
+        super()._translate_kwargs(kwargs)
         if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
@@ -175,10 +176,6 @@ class Nessai(NestedSampler):
                     kwargs["n_pool"] = kwargs.pop(equiv)
             if "n_pool" not in kwargs:
                 kwargs["n_pool"] = self._npool
-        if "seed" not in kwargs:
-            for equiv in self.seed_equiv_kwargs:
-                if equiv in kwargs:
-                    kwargs["seed"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
         """
diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py
index 2ea8787a6..41318e962 100644
--- a/bilby/core/sampler/nestle.py
+++ b/bilby/core/sampler/nestle.py
@@ -42,6 +42,7 @@ class Nestle(NestedSampler):
     )
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "npoints" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/polychord.py b/bilby/core/sampler/polychord.py
index 617d6c7d1..e43c5d50b 100644
--- a/bilby/core/sampler/polychord.py
+++ b/bilby/core/sampler/polychord.py
@@ -50,6 +50,7 @@ class PyPolyChord(NestedSampler):
         nlives={},
     )
     hard_exit = True
+    sampling_seed_key = "seed"
 
     @signal_wrapper
     def run_sampler(self):
@@ -100,6 +101,7 @@ class PyPolyChord(NestedSampler):
             self.kwargs["num_repeats"] = self.ndim * 5
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 063e2af7e..2534b0369 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -286,6 +286,7 @@ class Ptemcee(MCMCSampler):
 
     def _translate_kwargs(self, kwargs):
         """Translate kwargs"""
+        kwargs = super()._translate_kwargs(kwargs)
         if "nwalkers" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/ptmcmc.py b/bilby/core/sampler/ptmcmc.py
index 6b9c3c96e..42279e018 100644
--- a/bilby/core/sampler/ptmcmc.py
+++ b/bilby/core/sampler/ptmcmc.py
@@ -116,6 +116,7 @@ class PTMCMCSampler(MCMCSampler):
             )
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "Niter" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index 1c6a2790a..a67094eb0 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -91,6 +91,8 @@ class Pymc3(MCMCSampler):
 
     default_kwargs.update(default_nuts_kwargs)
 
+    sampling_seed_key = "random_seed"
+
     def __init__(
         self,
         likelihood,
diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py
index da6e7a977..6f0349fe3 100644
--- a/bilby/core/sampler/pymultinest.py
+++ b/bilby/core/sampler/pymultinest.py
@@ -62,6 +62,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
     )
     short_name = "pm"
     hard_exit = True
+    sampling_seed_key = "seed"
 
     def __init__(
         self,
@@ -104,6 +105,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
         self.use_temporary_directory = temporary_directory and not using_mpi
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "n_live_points" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py
index fc70b38ad..4cc14a9fa 100644
--- a/bilby/core/sampler/ultranest.py
+++ b/bilby/core/sampler/ultranest.py
@@ -104,6 +104,7 @@ class Ultranest(_TemporaryFileSamplerMixin, NestedSampler):
             self.callback_interval = callback_interval
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "num_live_points" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py
index 3a1059e0d..4856a9e7d 100644
--- a/test/core/sampler/base_sampler_test.py
+++ b/test/core/sampler/base_sampler_test.py
@@ -62,6 +62,15 @@ class TestSampler(unittest.TestCase):
     def test_label(self):
         self.assertEqual(self.sampler.label, "label")
 
+    @parameterized.expand(["sampling_seed", "seed", "random_seed"])
+    def test_translate_kwargs(self, key):
+        self.sampler.sampling_seed_key = key
+        for k in self.sampler.sampling_seed_equiv_kwargs:
+            kwargs = {k: 1234}
+            updated_kwargs = self.sampler._translate_kwargs(kwargs)
+            self.assertDictEqual(updated_kwargs, {key: 1234})
+        self.sampler.sampling_seed_key = None
+
     def test_prior_transform_transforms_search_parameter_keys(self):
         self.sampler.prior_transform([0])
         expected_prior = prior.Uniform(0, 1)
diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py
index 86b03fb38..7f6ec21a8 100644
--- a/test/core/sampler/nessai_test.py
+++ b/test/core/sampler/nessai_test.py
@@ -20,9 +20,11 @@ class TestNessai(unittest.TestCase):
             use_ratio=False,
             plot=False,
             skip_import_verification=True,
+            sampling_seed=150914,
         )
         self.expected = self.sampler.default_kwargs
         self.expected['output'] = 'outdir/label_nessai/'
+        self.expected['seed'] = 150914
 
     def tearDown(self):
         del self.likelihood
@@ -54,14 +56,7 @@ class TestNessai(unittest.TestCase):
             self.assertDictEqual(expected, self.sampler.kwargs)
 
     def test_translate_kwargs_seed(self):
-        expected = self.expected.copy()
-        expected["seed"] = 150914
-        for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs:
-            new_kwargs = self.sampler.kwargs.copy()
-            del new_kwargs["seed"]
-            new_kwargs[equiv] = 150914
-            self.sampler.kwargs = new_kwargs
-            self.assertDictEqual(expected, self.sampler.kwargs)
+        assert self.expected["seed"] == 150914
 
     def test_npool_max_threads(self):
         expected = self.expected.copy()
-- 
GitLab