Skip to content
Snippets Groups Projects
Commit 825ee34c authored by Michael Williams's avatar Michael Williams
Browse files

TST: add reproducibility test for dynesty

parent 33a2e752
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ import numpy as np
import parameterized
from bilby.core.sampler import dynesty_utils
from scipy.stats import gamma, ks_1samp, uniform, powerlaw
import shutil
@define
......@@ -271,5 +272,75 @@ class TestEstimateNMCMC(unittest.TestCase):
self.assertAlmostEqual(estimated, expected)
class TestReproducibility(unittest.TestCase):
@staticmethod
def model(x, m, c):
return m * x + c
def setUp(self):
bilby.core.utils.random.seed(42)
bilby.core.utils.command_line_args.bilby_test_mode = False
rng = bilby.core.utils.random.rng
self.x = np.linspace(0, 1, 11)
self.injection_parameters = dict(m=0.5, c=0.2)
self.sigma = 0.1
self.y = self.model(self.x, **self.injection_parameters) + rng.normal(
0, self.sigma, len(self.x)
)
self.likelihood = bilby.likelihood.GaussianLikelihood(
self.x, self.y, self.model, self.sigma
)
self.priors = bilby.core.prior.PriorDict()
self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
self._remove_tree()
bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
def tearDown(self):
del self.likelihood
del self.priors
bilby.core.utils.command_line_args.bilby_test_mode = False
self._remove_tree()
def _remove_tree(self):
try:
shutil.rmtree("outdir")
except OSError:
pass
def _run_sampler(self, **kwargs):
bilby.core.utils.random.seed(42)
return bilby.run_sampler(
likelihood=self.likelihood,
priors=self.priors,
sampler="dynesty",
save=False,
resume=False,
dlogz=1.0,
nlive=20,
**kwargs,
)
def test_reproducibility_seed(self):
res0 = self._run_sampler(seed=1234)
res1 = self._run_sampler(seed=1234)
assert res0.log_evidence == res1.log_evidence
def test_reproducibility_state(self):
rstate = np.random.default_rng(1234)
res0 = self._run_sampler(rstate=rstate)
rstate = np.random.default_rng(1234)
res1 = self._run_sampler(rstate=rstate)
assert res0.log_evidence == res1.log_evidence
def test_reproducibility_state_and_seed(self):
rstate = np.random.default_rng(1234)
res0 = self._run_sampler(rstate=rstate)
res1 = self._run_sampler(seed=1234)
assert res0.log_evidence == res1.log_evidence
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment