Skip to content
Snippets Groups Projects
Commit 374d810f authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'improve-sampling-seed-fixed' into 'master'

Improve how sampling seed is handled

See merge request lscsoft/bilby!1134
parents eac3f84e 99ba5346
No related branches found
No related tags found
1 merge request!1134Improve how sampling seed is handled
Pipeline #440714 passed
Showing with 47 additions and 15 deletions
...@@ -202,6 +202,7 @@ class Bilby_MCMC(MCMCSampler): ...@@ -202,6 +202,7 @@ class Bilby_MCMC(MCMCSampler):
logger.warning("Burn-in inefficiency fraction greater than 10%") logger.warning("Burn-in inefficiency fraction greater than 10%")
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "printdt" not in kwargs: if "printdt" not in kwargs:
for equiv in ["print_dt", "print_update"]: for equiv in ["print_dt", "print_update"]:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -196,7 +196,13 @@ class Sampler(object): ...@@ -196,7 +196,13 @@ class Sampler(object):
"cores", "cores",
"n_pool", "n_pool",
] ]
sampling_seed_equiv_kwargs = ["sampling_seed", "seed", "random_seed"]
hard_exit = False hard_exit = False
sampling_seed_key = None
"""Name of keyword argument for setting the sampling for the specific sampler.
If a specific sampler does not have a sampling seed option, then it should be
left as None.
"""
def __init__( def __init__(
self, self,
...@@ -289,8 +295,16 @@ class Sampler(object): ...@@ -289,8 +295,16 @@ class Sampler(object):
self._verify_kwargs_against_default_kwargs() self._verify_kwargs_against_default_kwargs()
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
"""Template for child classes""" """Translate keyword arguments.
pass
Default only translates the sampling seed if the sampler has
:code:`sampling_seed_key` set.
"""
if self.sampling_seed_key and self.sampling_seed_key not in kwargs:
for equiv in self.sampling_seed_equiv_kwargs:
if equiv in kwargs:
kwargs[self.sampling_seed_key] = kwargs.pop(equiv)
return kwargs
@property @property
def external_sampler_name(self): def external_sampler_name(self):
......
...@@ -54,8 +54,10 @@ class Cpnest(NestedSampler): ...@@ -54,8 +54,10 @@ class Cpnest(NestedSampler):
n_periodic_checkpoint=8000, n_periodic_checkpoint=8000,
) )
hard_exit = True hard_exit = True
sampling_seed_key = "seed"
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "nlive" not in kwargs: if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -114,6 +114,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -114,6 +114,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
) )
short_name = "dn4" short_name = "dn4"
hard_exit = True hard_exit = True
sampling_seed_key = "seed"
def __init__( def __init__(
self, self,
...@@ -254,6 +255,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -254,6 +255,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
return self.result return self.result
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "num_steps" not in kwargs: if "num_steps" not in kwargs:
for equiv in self.walks_equiv_kwargs: for equiv in self.walks_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -239,6 +239,7 @@ class Dynesty(NestedSampler): ...@@ -239,6 +239,7 @@ class Dynesty(NestedSampler):
} }
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "nlive" not in kwargs: if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -113,6 +113,7 @@ class Emcee(MCMCSampler): ...@@ -113,6 +113,7 @@ class Emcee(MCMCSampler):
return emcee return emcee
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "nwalkers" not in kwargs: if "nwalkers" not in kwargs:
for equiv in self.nwalkers_equiv_kwargs: for equiv in self.nwalkers_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -19,7 +19,7 @@ class Nessai(NestedSampler): ...@@ -19,7 +19,7 @@ class Nessai(NestedSampler):
""" """
_default_kwargs = None _default_kwargs = None
seed_equiv_kwargs = ["sampling_seed"] sampling_seed_key = "seed"
@property @property
def default_kwargs(self): def default_kwargs(self):
...@@ -165,6 +165,7 @@ class Nessai(NestedSampler): ...@@ -165,6 +165,7 @@ class Nessai(NestedSampler):
return self.result return self.result
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
super()._translate_kwargs(kwargs)
if "nlive" not in kwargs: if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
...@@ -175,10 +176,6 @@ class Nessai(NestedSampler): ...@@ -175,10 +176,6 @@ class Nessai(NestedSampler):
kwargs["n_pool"] = kwargs.pop(equiv) kwargs["n_pool"] = kwargs.pop(equiv)
if "n_pool" not in kwargs: if "n_pool" not in kwargs:
kwargs["n_pool"] = self._npool kwargs["n_pool"] = self._npool
if "seed" not in kwargs:
for equiv in self.seed_equiv_kwargs:
if equiv in kwargs:
kwargs["seed"] = kwargs.pop(equiv)
def _verify_kwargs_against_default_kwargs(self): def _verify_kwargs_against_default_kwargs(self):
""" """
......
...@@ -42,6 +42,7 @@ class Nestle(NestedSampler): ...@@ -42,6 +42,7 @@ class Nestle(NestedSampler):
) )
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "npoints" not in kwargs: if "npoints" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -50,6 +50,7 @@ class PyPolyChord(NestedSampler): ...@@ -50,6 +50,7 @@ class PyPolyChord(NestedSampler):
nlives={}, nlives={},
) )
hard_exit = True hard_exit = True
sampling_seed_key = "seed"
@signal_wrapper @signal_wrapper
def run_sampler(self): def run_sampler(self):
...@@ -100,6 +101,7 @@ class PyPolyChord(NestedSampler): ...@@ -100,6 +101,7 @@ class PyPolyChord(NestedSampler):
self.kwargs["num_repeats"] = self.ndim * 5 self.kwargs["num_repeats"] = self.ndim * 5
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "nlive" not in kwargs: if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -286,6 +286,7 @@ class Ptemcee(MCMCSampler): ...@@ -286,6 +286,7 @@ class Ptemcee(MCMCSampler):
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
"""Translate kwargs""" """Translate kwargs"""
kwargs = super()._translate_kwargs(kwargs)
if "nwalkers" not in kwargs: if "nwalkers" not in kwargs:
for equiv in self.nwalkers_equiv_kwargs: for equiv in self.nwalkers_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -116,6 +116,7 @@ class PTMCMCSampler(MCMCSampler): ...@@ -116,6 +116,7 @@ class PTMCMCSampler(MCMCSampler):
) )
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "Niter" not in kwargs: if "Niter" not in kwargs:
for equiv in self.nwalkers_equiv_kwargs: for equiv in self.nwalkers_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -91,6 +91,8 @@ class Pymc3(MCMCSampler): ...@@ -91,6 +91,8 @@ class Pymc3(MCMCSampler):
default_kwargs.update(default_nuts_kwargs) default_kwargs.update(default_nuts_kwargs)
sampling_seed_key = "random_seed"
def __init__( def __init__(
self, self,
likelihood, likelihood,
......
...@@ -62,6 +62,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -62,6 +62,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
) )
short_name = "pm" short_name = "pm"
hard_exit = True hard_exit = True
sampling_seed_key = "seed"
def __init__( def __init__(
self, self,
...@@ -104,6 +105,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -104,6 +105,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
self.use_temporary_directory = temporary_directory and not using_mpi self.use_temporary_directory = temporary_directory and not using_mpi
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "n_live_points" not in kwargs: if "n_live_points" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -104,6 +104,7 @@ class Ultranest(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -104,6 +104,7 @@ class Ultranest(_TemporaryFileSamplerMixin, NestedSampler):
self.callback_interval = callback_interval self.callback_interval = callback_interval
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "num_live_points" not in kwargs: if "num_live_points" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -62,6 +62,15 @@ class TestSampler(unittest.TestCase): ...@@ -62,6 +62,15 @@ class TestSampler(unittest.TestCase):
def test_label(self): def test_label(self):
self.assertEqual(self.sampler.label, "label") self.assertEqual(self.sampler.label, "label")
@parameterized.expand(["sampling_seed", "seed", "random_seed"])
def test_translate_kwargs(self, key):
self.sampler.sampling_seed_key = key
for k in self.sampler.sampling_seed_equiv_kwargs:
kwargs = {k: 1234}
updated_kwargs = self.sampler._translate_kwargs(kwargs)
self.assertDictEqual(updated_kwargs, {key: 1234})
self.sampler.sampling_seed_key = None
def test_prior_transform_transforms_search_parameter_keys(self): def test_prior_transform_transforms_search_parameter_keys(self):
self.sampler.prior_transform([0]) self.sampler.prior_transform([0])
expected_prior = prior.Uniform(0, 1) expected_prior = prior.Uniform(0, 1)
......
...@@ -20,9 +20,11 @@ class TestNessai(unittest.TestCase): ...@@ -20,9 +20,11 @@ class TestNessai(unittest.TestCase):
use_ratio=False, use_ratio=False,
plot=False, plot=False,
skip_import_verification=True, skip_import_verification=True,
sampling_seed=150914,
) )
self.expected = self.sampler.default_kwargs self.expected = self.sampler.default_kwargs
self.expected['output'] = 'outdir/label_nessai/' self.expected['output'] = 'outdir/label_nessai/'
self.expected['seed'] = 150914
def tearDown(self): def tearDown(self):
del self.likelihood del self.likelihood
...@@ -54,14 +56,7 @@ class TestNessai(unittest.TestCase): ...@@ -54,14 +56,7 @@ class TestNessai(unittest.TestCase):
self.assertDictEqual(expected, self.sampler.kwargs) self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs_seed(self): def test_translate_kwargs_seed(self):
expected = self.expected.copy() assert self.expected["seed"] == 150914
expected["seed"] = 150914
for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs["seed"]
new_kwargs[equiv] = 150914
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)
def test_npool_max_threads(self): def test_npool_max_threads(self):
expected = self.expected.copy() expected = self.expected.copy()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment