From 84c73998ab59832d2daa0f3614b7f10639e716b8 Mon Sep 17 00:00:00 2001
From: Michael Williams <michael.williams@ligo.org>
Date: Wed, 14 Feb 2024 15:05:13 +0000
Subject: [PATCH] BUG: Fix dynesty reproducibility

---
 bilby/core/sampler/dynesty.py     | 18 ++++++++
 test/core/sampler/dynesty_test.py | 74 +++++++++++++++++++++++++++++++
 2 files changed, 92 insertions(+)

diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 2767debd2..d8910a907 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -123,6 +123,9 @@ class Dynesty(NestedSampler):
         The proposal methods to use during MCMC. This can be some combination of
         :code:`"diff", "volumetric"`. See the dynesty guide in the Bilby docs
         for more details. default=:code:`["diff"]`.
+    rstate: numpy.random.Generator (None)
+        Instance of a numpy random generator for generating random numbers.
+        Also see :code:`seed` in 'Other Parameters'.
 
     Other Parameters
     ================
@@ -143,8 +146,13 @@ class Dynesty(NestedSampler):
         has no impact on the sampling.
     dlogz: float, (0.1)
         Stopping criteria
+    seed: int (None)
+        Use to seed the random number generator if :code:`rstate` is not
+        specified.
     """
 
+    sampling_seed_key = "seed"
+
     @property
     def _dynesty_init_kwargs(self):
         params = inspect.signature(self.sampler_init).parameters
@@ -176,6 +184,7 @@ class Dynesty(NestedSampler):
     def default_kwargs(self):
         kwargs = self._dynesty_init_kwargs
         kwargs.update(self._dynesty_sampler_kwargs)
+        kwargs["seed"] = None
         return kwargs
 
     def __init__(
@@ -265,6 +274,14 @@ class Dynesty(NestedSampler):
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs["queue_size"] = kwargs.pop(equiv)
+        if "seed" in kwargs:
+            seed = kwargs.get("seed")
+            if "rstate" not in kwargs:
+                kwargs["rstate"] = np.random.default_rng(seed)
+            else:
+                logger.warning(
+                    "Kwargs contain both 'rstate' and 'seed', ignoring 'seed'."
+                )
 
     def _verify_kwargs_against_default_kwargs(self):
         if not self.kwargs["walks"]:
@@ -604,6 +621,7 @@ class Dynesty(NestedSampler):
             sampling_time_s=self.sampling_time.seconds,
             ncores=self.kwargs.get("queue_size", 1),
         )
+        self.kwargs["rstate"] = None
 
     def _update_sampling_time(self):
         end_time = datetime.datetime.now()
diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py
index 9640495bc..d88ba4de9 100644
--- a/test/core/sampler/dynesty_test.py
+++ b/test/core/sampler/dynesty_test.py
@@ -7,6 +7,7 @@ import numpy as np
 import parameterized
 from bilby.core.sampler import dynesty_utils
 from scipy.stats import gamma, ks_1samp, uniform, powerlaw
+import shutil
 
 
 @define
@@ -271,5 +272,78 @@ class TestEstimateNMCMC(unittest.TestCase):
         self.assertAlmostEqual(estimated, expected)
 
 
+class TestReproducibility(unittest.TestCase):
+
+    @staticmethod
+    def model(x, m, c):
+        return m * x + c
+
+    def setUp(self):
+        bilby.core.utils.random.seed(42)
+        bilby.core.utils.command_line_args.bilby_test_mode = False
+        rng = bilby.core.utils.random.rng
+        self.x = np.linspace(0, 1, 11)
+        self.injection_parameters = dict(m=0.5, c=0.2)
+        self.sigma = 0.1
+        self.y = self.model(self.x, **self.injection_parameters) + rng.normal(
+            0, self.sigma, len(self.x)
+        )
+        self.likelihood = bilby.likelihood.GaussianLikelihood(
+            self.x, self.y, self.model, self.sigma
+        )
+
+        self.priors = bilby.core.prior.PriorDict()
+        self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
+        self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
+        # Evaluate prior once to ensure normalization constant have been set
+        theta = self.priors.sample()
+        self.priors.ln_prob(theta)
+        self._remove_tree()
+        bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
+
+    def tearDown(self):
+        del self.likelihood
+        del self.priors
+        bilby.core.utils.command_line_args.bilby_test_mode = False
+        self._remove_tree()
+
+    def _remove_tree(self):
+        try:
+            shutil.rmtree("outdir")
+        except OSError:
+            pass
+
+    def _run_sampler(self, **kwargs):
+        bilby.core.utils.random.seed(42)
+        return bilby.run_sampler(
+            likelihood=self.likelihood,
+            priors=self.priors,
+            sampler="dynesty",
+            save=False,
+            resume=False,
+            dlogz=1.0,
+            nlive=20,
+            **kwargs,
+        )
+
+    def test_reproducibility_seed(self):
+        res0 = self._run_sampler(seed=1234)
+        res1 = self._run_sampler(seed=1234)
+        assert res0.log_evidence == res1.log_evidence
+
+    def test_reproducibility_state(self):
+        rstate = np.random.default_rng(1234)
+        res0 = self._run_sampler(rstate=rstate)
+        rstate = np.random.default_rng(1234)
+        res1 = self._run_sampler(rstate=rstate)
+        assert res0.log_evidence == res1.log_evidence
+
+    def test_reproducibility_state_and_seed(self):
+        rstate = np.random.default_rng(1234)
+        res0 = self._run_sampler(rstate=rstate)
+        res1 = self._run_sampler(seed=1234)
+        assert res0.log_evidence == res1.log_evidence
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab