From 99ba5346e7705eb3de1e81b7e64c7733dd7fff40 Mon Sep 17 00:00:00 2001 From: Michael Williams <michael.williams@ligo.org> Date: Tue, 16 Aug 2022 13:46:22 +0000 Subject: [PATCH] Improve how sampling seed is handled --- bilby/bilby_mcmc/sampler.py | 1 + bilby/core/sampler/base_sampler.py | 18 ++++++++++++++++-- bilby/core/sampler/cpnest.py | 2 ++ bilby/core/sampler/dnest4.py | 2 ++ bilby/core/sampler/dynesty.py | 1 + bilby/core/sampler/emcee.py | 1 + bilby/core/sampler/nessai.py | 7 ++----- bilby/core/sampler/nestle.py | 1 + bilby/core/sampler/polychord.py | 2 ++ bilby/core/sampler/ptemcee.py | 1 + bilby/core/sampler/ptmcmc.py | 1 + bilby/core/sampler/pymc3.py | 2 ++ bilby/core/sampler/pymultinest.py | 2 ++ bilby/core/sampler/ultranest.py | 1 + test/core/sampler/base_sampler_test.py | 9 +++++++++ test/core/sampler/nessai_test.py | 11 +++-------- 16 files changed, 47 insertions(+), 15 deletions(-) diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py index 555e32149..3decaf74a 100644 --- a/bilby/bilby_mcmc/sampler.py +++ b/bilby/bilby_mcmc/sampler.py @@ -202,6 +202,7 @@ class Bilby_MCMC(MCMCSampler): logger.warning("Burn-in inefficiency fraction greater than 10%") def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "printdt" not in kwargs: for equiv in ["print_dt", "print_update"]: if equiv in kwargs: diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index c30f76045..47eee4ab6 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -196,7 +196,13 @@ class Sampler(object): "cores", "n_pool", ] + sampling_seed_equiv_kwargs = ["sampling_seed", "seed", "random_seed"] hard_exit = False + sampling_seed_key = None + """Name of keyword argument for setting the sampling for the specific sampler. + If a specific sampler does not have a sampling seed option, then it should be + left as None. + """ def __init__( self, @@ -289,8 +295,16 @@ class Sampler(object): self._verify_kwargs_against_default_kwargs() def _translate_kwargs(self, kwargs): - """Template for child classes""" - pass + """Translate keyword arguments. + + Default only translates the sampling seed if the sampler has + :code:`sampling_seed_key` set. + """ + if self.sampling_seed_key and self.sampling_seed_key not in kwargs: + for equiv in self.sampling_seed_equiv_kwargs: + if equiv in kwargs: + kwargs[self.sampling_seed_key] = kwargs.pop(equiv) + return kwargs @property def external_sampler_name(self): diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index b12464375..bc3b36465 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -54,8 +54,10 @@ class Cpnest(NestedSampler): n_periodic_checkpoint=8000, ) hard_exit = True + sampling_seed_key = "seed" def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/dnest4.py b/bilby/core/sampler/dnest4.py index 7d5b97092..5c3d7566e 100644 --- a/bilby/core/sampler/dnest4.py +++ b/bilby/core/sampler/dnest4.py @@ -114,6 +114,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler): ) short_name = "dn4" hard_exit = True + sampling_seed_key = "seed" def __init__( self, @@ -254,6 +255,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler): return self.result def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "num_steps" not in kwargs: for equiv in self.walks_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index ab0af61be..82da609e7 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -239,6 +239,7 @@ class Dynesty(NestedSampler): } def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 5afa169b8..18a36fd13 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -113,6 +113,7 @@ class Emcee(MCMCSampler): return emcee def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "nwalkers" not in kwargs: for equiv in self.nwalkers_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py index fdee87b05..d0d050370 100644 --- a/bilby/core/sampler/nessai.py +++ b/bilby/core/sampler/nessai.py @@ -19,7 +19,7 @@ class Nessai(NestedSampler): """ _default_kwargs = None - seed_equiv_kwargs = ["sampling_seed"] + sampling_seed_key = "seed" @property def default_kwargs(self): @@ -165,6 +165,7 @@ class Nessai(NestedSampler): return self.result def _translate_kwargs(self, kwargs): + super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: @@ -175,10 +176,6 @@ class Nessai(NestedSampler): kwargs["n_pool"] = kwargs.pop(equiv) if "n_pool" not in kwargs: kwargs["n_pool"] = self._npool - if "seed" not in kwargs: - for equiv in self.seed_equiv_kwargs: - if equiv in kwargs: - kwargs["seed"] = kwargs.pop(equiv) def _verify_kwargs_against_default_kwargs(self): """ diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py index 2ea8787a6..41318e962 100644 --- a/bilby/core/sampler/nestle.py +++ b/bilby/core/sampler/nestle.py @@ -42,6 +42,7 @@ class Nestle(NestedSampler): ) def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "npoints" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/polychord.py b/bilby/core/sampler/polychord.py index 617d6c7d1..e43c5d50b 100644 --- a/bilby/core/sampler/polychord.py +++ b/bilby/core/sampler/polychord.py @@ -50,6 +50,7 @@ class PyPolyChord(NestedSampler): nlives={}, ) hard_exit = True + sampling_seed_key = "seed" @signal_wrapper def run_sampler(self): @@ -100,6 +101,7 @@ class PyPolyChord(NestedSampler): self.kwargs["num_repeats"] = self.ndim * 5 def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 063e2af7e..2534b0369 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -286,6 +286,7 @@ class Ptemcee(MCMCSampler): def _translate_kwargs(self, kwargs): """Translate kwargs""" + kwargs = super()._translate_kwargs(kwargs) if "nwalkers" not in kwargs: for equiv in self.nwalkers_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/ptmcmc.py b/bilby/core/sampler/ptmcmc.py index 6b9c3c96e..42279e018 100644 --- a/bilby/core/sampler/ptmcmc.py +++ b/bilby/core/sampler/ptmcmc.py @@ -116,6 +116,7 @@ class PTMCMCSampler(MCMCSampler): ) def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "Niter" not in kwargs: for equiv in self.nwalkers_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index 1c6a2790a..a67094eb0 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -91,6 +91,8 @@ class Pymc3(MCMCSampler): default_kwargs.update(default_nuts_kwargs) + sampling_seed_key = "random_seed" + def __init__( self, likelihood, diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py index da6e7a977..6f0349fe3 100644 --- a/bilby/core/sampler/pymultinest.py +++ b/bilby/core/sampler/pymultinest.py @@ -62,6 +62,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): ) short_name = "pm" hard_exit = True + sampling_seed_key = "seed" def __init__( self, @@ -104,6 +105,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): self.use_temporary_directory = temporary_directory and not using_mpi def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "n_live_points" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py index fc70b38ad..4cc14a9fa 100644 --- a/bilby/core/sampler/ultranest.py +++ b/bilby/core/sampler/ultranest.py @@ -104,6 +104,7 @@ class Ultranest(_TemporaryFileSamplerMixin, NestedSampler): self.callback_interval = callback_interval def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "num_live_points" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index 3a1059e0d..4856a9e7d 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -62,6 +62,15 @@ class TestSampler(unittest.TestCase): def test_label(self): self.assertEqual(self.sampler.label, "label") + @parameterized.expand(["sampling_seed", "seed", "random_seed"]) + def test_translate_kwargs(self, key): + self.sampler.sampling_seed_key = key + for k in self.sampler.sampling_seed_equiv_kwargs: + kwargs = {k: 1234} + updated_kwargs = self.sampler._translate_kwargs(kwargs) + self.assertDictEqual(updated_kwargs, {key: 1234}) + self.sampler.sampling_seed_key = None + def test_prior_transform_transforms_search_parameter_keys(self): self.sampler.prior_transform([0]) expected_prior = prior.Uniform(0, 1) diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py index 86b03fb38..7f6ec21a8 100644 --- a/test/core/sampler/nessai_test.py +++ b/test/core/sampler/nessai_test.py @@ -20,9 +20,11 @@ class TestNessai(unittest.TestCase): use_ratio=False, plot=False, skip_import_verification=True, + sampling_seed=150914, ) self.expected = self.sampler.default_kwargs self.expected['output'] = 'outdir/label_nessai/' + self.expected['seed'] = 150914 def tearDown(self): del self.likelihood @@ -54,14 +56,7 @@ class TestNessai(unittest.TestCase): self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs_seed(self): - expected = self.expected.copy() - expected["seed"] = 150914 - for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs: - new_kwargs = self.sampler.kwargs.copy() - del new_kwargs["seed"] - new_kwargs[equiv] = 150914 - self.sampler.kwargs = new_kwargs - self.assertDictEqual(expected, self.sampler.kwargs) + assert self.expected["seed"] == 150914 def test_npool_max_threads(self): expected = self.expected.copy() -- GitLab