diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py index 555e321496cc1273db669daf43f21e86ea485a11..3decaf74a443d8668212e6a32c9cb4a44850f1a0 100644 --- a/bilby/bilby_mcmc/sampler.py +++ b/bilby/bilby_mcmc/sampler.py @@ -202,6 +202,7 @@ class Bilby_MCMC(MCMCSampler): logger.warning("Burn-in inefficiency fraction greater than 10%") def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "printdt" not in kwargs: for equiv in ["print_dt", "print_update"]: if equiv in kwargs: diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index c30f76045d285d0362e6d972330694202e4db381..47eee4ab67174e6fe34bcdbc4251a9639fc5a29e 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -196,7 +196,13 @@ class Sampler(object): "cores", "n_pool", ] + sampling_seed_equiv_kwargs = ["sampling_seed", "seed", "random_seed"] hard_exit = False + sampling_seed_key = None + """Name of keyword argument for setting the sampling for the specific sampler. + If a specific sampler does not have a sampling seed option, then it should be + left as None. + """ def __init__( self, @@ -289,8 +295,16 @@ class Sampler(object): self._verify_kwargs_against_default_kwargs() def _translate_kwargs(self, kwargs): - """Template for child classes""" - pass + """Translate keyword arguments. + + Default only translates the sampling seed if the sampler has + :code:`sampling_seed_key` set. + """ + if self.sampling_seed_key and self.sampling_seed_key not in kwargs: + for equiv in self.sampling_seed_equiv_kwargs: + if equiv in kwargs: + kwargs[self.sampling_seed_key] = kwargs.pop(equiv) + return kwargs @property def external_sampler_name(self): diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index b124643759fc742465d612918523161b585bb17c..bc3b364656d26bcff0c14e3852bbbd394c5887cf 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -54,8 +54,10 @@ class Cpnest(NestedSampler): n_periodic_checkpoint=8000, ) hard_exit = True + sampling_seed_key = "seed" def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/dnest4.py b/bilby/core/sampler/dnest4.py index 7d5b97092919e0951de2329de8633e4a6bc7fc2b..5c3d7566e729fb92e5427e9c8b5968c3df6c6abe 100644 --- a/bilby/core/sampler/dnest4.py +++ b/bilby/core/sampler/dnest4.py @@ -114,6 +114,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler): ) short_name = "dn4" hard_exit = True + sampling_seed_key = "seed" def __init__( self, @@ -254,6 +255,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler): return self.result def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "num_steps" not in kwargs: for equiv in self.walks_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index ab0af61be6dc55ffac996931e4ed48a00cc348f1..82da609e77a17613f90c248b6b04b7b1c2854bf9 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -239,6 +239,7 @@ class Dynesty(NestedSampler): } def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 5afa169b8959faa4c8e1e2df3a3d7397819ecd2f..18a36fd1371aed4ccfb47d0c412cb57c77fb803b 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -113,6 +113,7 @@ class Emcee(MCMCSampler): return emcee def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "nwalkers" not in kwargs: for equiv in self.nwalkers_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py index fdee87b058647a47a5671aed9a9ab14d23ad0573..d0d05037031383ff9a22a08898856e06a6ddbf8d 100644 --- a/bilby/core/sampler/nessai.py +++ b/bilby/core/sampler/nessai.py @@ -19,7 +19,7 @@ class Nessai(NestedSampler): """ _default_kwargs = None - seed_equiv_kwargs = ["sampling_seed"] + sampling_seed_key = "seed" @property def default_kwargs(self): @@ -165,6 +165,7 @@ class Nessai(NestedSampler): return self.result def _translate_kwargs(self, kwargs): + super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: @@ -175,10 +176,6 @@ class Nessai(NestedSampler): kwargs["n_pool"] = kwargs.pop(equiv) if "n_pool" not in kwargs: kwargs["n_pool"] = self._npool - if "seed" not in kwargs: - for equiv in self.seed_equiv_kwargs: - if equiv in kwargs: - kwargs["seed"] = kwargs.pop(equiv) def _verify_kwargs_against_default_kwargs(self): """ diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py index 2ea8787a63a85b62101e6f0d193bae85765f3884..41318e9628d63e996f50113448b4db36488382f2 100644 --- a/bilby/core/sampler/nestle.py +++ b/bilby/core/sampler/nestle.py @@ -42,6 +42,7 @@ class Nestle(NestedSampler): ) def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "npoints" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/polychord.py b/bilby/core/sampler/polychord.py index 617d6c7d17b22569f1b9bc23e37e58e292033625..e43c5d50b248ba0fb12cd8d5bca97b0fee726c45 100644 --- a/bilby/core/sampler/polychord.py +++ b/bilby/core/sampler/polychord.py @@ -50,6 +50,7 @@ class PyPolyChord(NestedSampler): nlives={}, ) hard_exit = True + sampling_seed_key = "seed" @signal_wrapper def run_sampler(self): @@ -100,6 +101,7 @@ class PyPolyChord(NestedSampler): self.kwargs["num_repeats"] = self.ndim * 5 def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 063e2af7e472812e0a9a8f88cc2ea7c1fe384cc0..2534b0369d1de8d1b75e8561f3cafb4cda26ec73 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -286,6 +286,7 @@ class Ptemcee(MCMCSampler): def _translate_kwargs(self, kwargs): """Translate kwargs""" + kwargs = super()._translate_kwargs(kwargs) if "nwalkers" not in kwargs: for equiv in self.nwalkers_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/ptmcmc.py b/bilby/core/sampler/ptmcmc.py index 6b9c3c96eb83a81486df1e38dd5948b668cfb358..42279e018ed124cd117118e75949b60d74e3a302 100644 --- a/bilby/core/sampler/ptmcmc.py +++ b/bilby/core/sampler/ptmcmc.py @@ -116,6 +116,7 @@ class PTMCMCSampler(MCMCSampler): ) def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "Niter" not in kwargs: for equiv in self.nwalkers_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index 1c6a2790a74c61001577bc8606ee0834a1b80455..a67094eb0dc510a9849f6e1544239901421408a5 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -91,6 +91,8 @@ class Pymc3(MCMCSampler): default_kwargs.update(default_nuts_kwargs) + sampling_seed_key = "random_seed" + def __init__( self, likelihood, diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py index da6e7a9778273f0bb25a2702bddcbc2fd18f4ae3..6f0349fe33964ad381675dbc713edfe2dcc4ab1f 100644 --- a/bilby/core/sampler/pymultinest.py +++ b/bilby/core/sampler/pymultinest.py @@ -62,6 +62,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): ) short_name = "pm" hard_exit = True + sampling_seed_key = "seed" def __init__( self, @@ -104,6 +105,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): self.use_temporary_directory = temporary_directory and not using_mpi def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "n_live_points" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py index fc70b38ad4f0b04169254e8031c0a996087abdd2..4cc14a9fa7ff9ba1ce4f19979571f8dde2c71e7b 100644 --- a/bilby/core/sampler/ultranest.py +++ b/bilby/core/sampler/ultranest.py @@ -104,6 +104,7 @@ class Ultranest(_TemporaryFileSamplerMixin, NestedSampler): self.callback_interval = callback_interval def _translate_kwargs(self, kwargs): + kwargs = super()._translate_kwargs(kwargs) if "num_live_points" not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index 3a1059e0dd82ac27b1897713b6b40f9f702bf644..4856a9e7df4922bf94fe26746972bb1864774535 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -62,6 +62,15 @@ class TestSampler(unittest.TestCase): def test_label(self): self.assertEqual(self.sampler.label, "label") + @parameterized.expand(["sampling_seed", "seed", "random_seed"]) + def test_translate_kwargs(self, key): + self.sampler.sampling_seed_key = key + for k in self.sampler.sampling_seed_equiv_kwargs: + kwargs = {k: 1234} + updated_kwargs = self.sampler._translate_kwargs(kwargs) + self.assertDictEqual(updated_kwargs, {key: 1234}) + self.sampler.sampling_seed_key = None + def test_prior_transform_transforms_search_parameter_keys(self): self.sampler.prior_transform([0]) expected_prior = prior.Uniform(0, 1) diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py index 86b03fb38e74afb5ce75c06b5fbd91add3d7f49e..7f6ec21a8a5d26b606e8c6f8aa3cae3ede905ca2 100644 --- a/test/core/sampler/nessai_test.py +++ b/test/core/sampler/nessai_test.py @@ -20,9 +20,11 @@ class TestNessai(unittest.TestCase): use_ratio=False, plot=False, skip_import_verification=True, + sampling_seed=150914, ) self.expected = self.sampler.default_kwargs self.expected['output'] = 'outdir/label_nessai/' + self.expected['seed'] = 150914 def tearDown(self): del self.likelihood @@ -54,14 +56,7 @@ class TestNessai(unittest.TestCase): self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs_seed(self): - expected = self.expected.copy() - expected["seed"] = 150914 - for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs: - new_kwargs = self.sampler.kwargs.copy() - del new_kwargs["seed"] - new_kwargs[equiv] = 150914 - self.sampler.kwargs = new_kwargs - self.assertDictEqual(expected, self.sampler.kwargs) + assert self.expected["seed"] == 150914 def test_npool_max_threads(self): expected = self.expected.copy()