Skip to content
Snippets Groups Projects
Commit 84c73998 authored by Michael Williams's avatar Michael Williams Committed by Colm Talbot
Browse files

BUG: Fix dynesty reproducibility

parent 34cbf86e
No related branches found
No related tags found
No related merge requests found
......@@ -123,6 +123,9 @@ class Dynesty(NestedSampler):
The proposal methods to use during MCMC. This can be some combination
of :code:`"diff", "volumetric"`. See the dynesty guide in the Bilby docs
for more details. default=:code:`["diff"]`.
rstate: numpy.random.Generator (None)
Instance of a numpy random generator for generating random numbers.
Also see :code:`seed` in 'Other Parameters'.
Other Parameters
================
......@@ -143,8 +146,13 @@ class Dynesty(NestedSampler):
has no impact on the sampling.
dlogz: float, (0.1)
Stopping criteria
seed: int (None)
Use to seed the random number generator if :code:`rstate` is not
specified.
"""
sampling_seed_key = "seed"
@property
def _dynesty_init_kwargs(self):
params = inspect.signature(self.sampler_init).parameters
......@@ -176,6 +184,7 @@ class Dynesty(NestedSampler):
def default_kwargs(self):
kwargs = self._dynesty_init_kwargs
kwargs.update(self._dynesty_sampler_kwargs)
kwargs["seed"] = None
return kwargs
def __init__(
......@@ -265,6 +274,14 @@ class Dynesty(NestedSampler):
for equiv in self.npool_equiv_kwargs:
if equiv in kwargs:
kwargs["queue_size"] = kwargs.pop(equiv)
if "seed" in kwargs:
seed = kwargs.get("seed")
if "rstate" not in kwargs:
kwargs["rstate"] = np.random.default_rng(seed)
else:
logger.warning(
"Kwargs contain both 'rstate' and 'seed', ignoring 'seed'."
)
def _verify_kwargs_against_default_kwargs(self):
if not self.kwargs["walks"]:
......@@ -604,6 +621,7 @@ class Dynesty(NestedSampler):
sampling_time_s=self.sampling_time.seconds,
ncores=self.kwargs.get("queue_size", 1),
)
self.kwargs["rstate"] = None
def _update_sampling_time(self):
end_time = datetime.datetime.now()
......
......@@ -7,6 +7,7 @@ import numpy as np
import parameterized
from bilby.core.sampler import dynesty_utils
from scipy.stats import gamma, ks_1samp, uniform, powerlaw
import shutil
@define
......@@ -271,5 +272,78 @@ class TestEstimateNMCMC(unittest.TestCase):
self.assertAlmostEqual(estimated, expected)
class TestReproducibility(unittest.TestCase):
@staticmethod
def model(x, m, c):
return m * x + c
def setUp(self):
bilby.core.utils.random.seed(42)
bilby.core.utils.command_line_args.bilby_test_mode = False
rng = bilby.core.utils.random.rng
self.x = np.linspace(0, 1, 11)
self.injection_parameters = dict(m=0.5, c=0.2)
self.sigma = 0.1
self.y = self.model(self.x, **self.injection_parameters) + rng.normal(
0, self.sigma, len(self.x)
)
self.likelihood = bilby.likelihood.GaussianLikelihood(
self.x, self.y, self.model, self.sigma
)
self.priors = bilby.core.prior.PriorDict()
self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
# Evaluate prior once to ensure normalization constant have been set
theta = self.priors.sample()
self.priors.ln_prob(theta)
self._remove_tree()
bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
def tearDown(self):
del self.likelihood
del self.priors
bilby.core.utils.command_line_args.bilby_test_mode = False
self._remove_tree()
def _remove_tree(self):
try:
shutil.rmtree("outdir")
except OSError:
pass
def _run_sampler(self, **kwargs):
bilby.core.utils.random.seed(42)
return bilby.run_sampler(
likelihood=self.likelihood,
priors=self.priors,
sampler="dynesty",
save=False,
resume=False,
dlogz=1.0,
nlive=20,
**kwargs,
)
def test_reproducibility_seed(self):
res0 = self._run_sampler(seed=1234)
res1 = self._run_sampler(seed=1234)
assert res0.log_evidence == res1.log_evidence
def test_reproducibility_state(self):
rstate = np.random.default_rng(1234)
res0 = self._run_sampler(rstate=rstate)
rstate = np.random.default_rng(1234)
res1 = self._run_sampler(rstate=rstate)
assert res0.log_evidence == res1.log_evidence
def test_reproducibility_state_and_seed(self):
rstate = np.random.default_rng(1234)
res0 = self._run_sampler(rstate=rstate)
res1 = self._run_sampler(seed=1234)
assert res0.log_evidence == res1.log_evidence
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment