Skip to content
Snippets Groups Projects
Commit 99ba5346 authored by Michael Williams's avatar Michael Williams Committed by Colm Talbot
Browse files

Improve how sampling seed is handled

parent c761682e
No related branches found
Tags 0.5.0
1 merge request!1134Improve how sampling seed is handled
Showing with 47 additions and 15 deletions
...@@ -202,6 +202,7 @@ class Bilby_MCMC(MCMCSampler): ...@@ -202,6 +202,7 @@ class Bilby_MCMC(MCMCSampler):
logger.warning("Burn-in inefficiency fraction greater than 10%") logger.warning("Burn-in inefficiency fraction greater than 10%")
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "printdt" not in kwargs: if "printdt" not in kwargs:
for equiv in ["print_dt", "print_update"]: for equiv in ["print_dt", "print_update"]:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -196,7 +196,13 @@ class Sampler(object): ...@@ -196,7 +196,13 @@ class Sampler(object):
"cores", "cores",
"n_pool", "n_pool",
] ]
sampling_seed_equiv_kwargs = ["sampling_seed", "seed", "random_seed"]
hard_exit = False hard_exit = False
sampling_seed_key = None
"""Name of keyword argument for setting the sampling for the specific sampler.
If a specific sampler does not have a sampling seed option, then it should be
left as None.
"""
def __init__( def __init__(
self, self,
...@@ -289,8 +295,16 @@ class Sampler(object): ...@@ -289,8 +295,16 @@ class Sampler(object):
self._verify_kwargs_against_default_kwargs() self._verify_kwargs_against_default_kwargs()
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
"""Template for child classes""" """Translate keyword arguments.
pass
Default only translates the sampling seed if the sampler has
:code:`sampling_seed_key` set.
"""
if self.sampling_seed_key and self.sampling_seed_key not in kwargs:
for equiv in self.sampling_seed_equiv_kwargs:
if equiv in kwargs:
kwargs[self.sampling_seed_key] = kwargs.pop(equiv)
return kwargs
@property @property
def external_sampler_name(self): def external_sampler_name(self):
......
...@@ -54,8 +54,10 @@ class Cpnest(NestedSampler): ...@@ -54,8 +54,10 @@ class Cpnest(NestedSampler):
n_periodic_checkpoint=8000, n_periodic_checkpoint=8000,
) )
hard_exit = True hard_exit = True
sampling_seed_key = "seed"
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "nlive" not in kwargs: if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -114,6 +114,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -114,6 +114,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
) )
short_name = "dn4" short_name = "dn4"
hard_exit = True hard_exit = True
sampling_seed_key = "seed"
def __init__( def __init__(
self, self,
...@@ -254,6 +255,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -254,6 +255,7 @@ class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
return self.result return self.result
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "num_steps" not in kwargs: if "num_steps" not in kwargs:
for equiv in self.walks_equiv_kwargs: for equiv in self.walks_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -239,6 +239,7 @@ class Dynesty(NestedSampler): ...@@ -239,6 +239,7 @@ class Dynesty(NestedSampler):
} }
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "nlive" not in kwargs: if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -113,6 +113,7 @@ class Emcee(MCMCSampler): ...@@ -113,6 +113,7 @@ class Emcee(MCMCSampler):
return emcee return emcee
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "nwalkers" not in kwargs: if "nwalkers" not in kwargs:
for equiv in self.nwalkers_equiv_kwargs: for equiv in self.nwalkers_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -19,7 +19,7 @@ class Nessai(NestedSampler): ...@@ -19,7 +19,7 @@ class Nessai(NestedSampler):
""" """
_default_kwargs = None _default_kwargs = None
seed_equiv_kwargs = ["sampling_seed"] sampling_seed_key = "seed"
@property @property
def default_kwargs(self): def default_kwargs(self):
...@@ -165,6 +165,7 @@ class Nessai(NestedSampler): ...@@ -165,6 +165,7 @@ class Nessai(NestedSampler):
return self.result return self.result
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
super()._translate_kwargs(kwargs)
if "nlive" not in kwargs: if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
...@@ -175,10 +176,6 @@ class Nessai(NestedSampler): ...@@ -175,10 +176,6 @@ class Nessai(NestedSampler):
kwargs["n_pool"] = kwargs.pop(equiv) kwargs["n_pool"] = kwargs.pop(equiv)
if "n_pool" not in kwargs: if "n_pool" not in kwargs:
kwargs["n_pool"] = self._npool kwargs["n_pool"] = self._npool
if "seed" not in kwargs:
for equiv in self.seed_equiv_kwargs:
if equiv in kwargs:
kwargs["seed"] = kwargs.pop(equiv)
def _verify_kwargs_against_default_kwargs(self): def _verify_kwargs_against_default_kwargs(self):
""" """
......
...@@ -42,6 +42,7 @@ class Nestle(NestedSampler): ...@@ -42,6 +42,7 @@ class Nestle(NestedSampler):
) )
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "npoints" not in kwargs: if "npoints" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -50,6 +50,7 @@ class PyPolyChord(NestedSampler): ...@@ -50,6 +50,7 @@ class PyPolyChord(NestedSampler):
nlives={}, nlives={},
) )
hard_exit = True hard_exit = True
sampling_seed_key = "seed"
@signal_wrapper @signal_wrapper
def run_sampler(self): def run_sampler(self):
...@@ -100,6 +101,7 @@ class PyPolyChord(NestedSampler): ...@@ -100,6 +101,7 @@ class PyPolyChord(NestedSampler):
self.kwargs["num_repeats"] = self.ndim * 5 self.kwargs["num_repeats"] = self.ndim * 5
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "nlive" not in kwargs: if "nlive" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -286,6 +286,7 @@ class Ptemcee(MCMCSampler): ...@@ -286,6 +286,7 @@ class Ptemcee(MCMCSampler):
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
"""Translate kwargs""" """Translate kwargs"""
kwargs = super()._translate_kwargs(kwargs)
if "nwalkers" not in kwargs: if "nwalkers" not in kwargs:
for equiv in self.nwalkers_equiv_kwargs: for equiv in self.nwalkers_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -116,6 +116,7 @@ class PTMCMCSampler(MCMCSampler): ...@@ -116,6 +116,7 @@ class PTMCMCSampler(MCMCSampler):
) )
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "Niter" not in kwargs: if "Niter" not in kwargs:
for equiv in self.nwalkers_equiv_kwargs: for equiv in self.nwalkers_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -91,6 +91,8 @@ class Pymc3(MCMCSampler): ...@@ -91,6 +91,8 @@ class Pymc3(MCMCSampler):
default_kwargs.update(default_nuts_kwargs) default_kwargs.update(default_nuts_kwargs)
sampling_seed_key = "random_seed"
def __init__( def __init__(
self, self,
likelihood, likelihood,
......
...@@ -62,6 +62,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -62,6 +62,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
) )
short_name = "pm" short_name = "pm"
hard_exit = True hard_exit = True
sampling_seed_key = "seed"
def __init__( def __init__(
self, self,
...@@ -104,6 +105,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -104,6 +105,7 @@ class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
self.use_temporary_directory = temporary_directory and not using_mpi self.use_temporary_directory = temporary_directory and not using_mpi
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "n_live_points" not in kwargs: if "n_live_points" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -104,6 +104,7 @@ class Ultranest(_TemporaryFileSamplerMixin, NestedSampler): ...@@ -104,6 +104,7 @@ class Ultranest(_TemporaryFileSamplerMixin, NestedSampler):
self.callback_interval = callback_interval self.callback_interval = callback_interval
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
kwargs = super()._translate_kwargs(kwargs)
if "num_live_points" not in kwargs: if "num_live_points" not in kwargs:
for equiv in self.npoints_equiv_kwargs: for equiv in self.npoints_equiv_kwargs:
if equiv in kwargs: if equiv in kwargs:
......
...@@ -62,6 +62,15 @@ class TestSampler(unittest.TestCase): ...@@ -62,6 +62,15 @@ class TestSampler(unittest.TestCase):
def test_label(self): def test_label(self):
self.assertEqual(self.sampler.label, "label") self.assertEqual(self.sampler.label, "label")
@parameterized.expand(["sampling_seed", "seed", "random_seed"])
def test_translate_kwargs(self, key):
self.sampler.sampling_seed_key = key
for k in self.sampler.sampling_seed_equiv_kwargs:
kwargs = {k: 1234}
updated_kwargs = self.sampler._translate_kwargs(kwargs)
self.assertDictEqual(updated_kwargs, {key: 1234})
self.sampler.sampling_seed_key = None
def test_prior_transform_transforms_search_parameter_keys(self): def test_prior_transform_transforms_search_parameter_keys(self):
self.sampler.prior_transform([0]) self.sampler.prior_transform([0])
expected_prior = prior.Uniform(0, 1) expected_prior = prior.Uniform(0, 1)
......
...@@ -20,9 +20,11 @@ class TestNessai(unittest.TestCase): ...@@ -20,9 +20,11 @@ class TestNessai(unittest.TestCase):
use_ratio=False, use_ratio=False,
plot=False, plot=False,
skip_import_verification=True, skip_import_verification=True,
sampling_seed=150914,
) )
self.expected = self.sampler.default_kwargs self.expected = self.sampler.default_kwargs
self.expected['output'] = 'outdir/label_nessai/' self.expected['output'] = 'outdir/label_nessai/'
self.expected['seed'] = 150914
def tearDown(self): def tearDown(self):
del self.likelihood del self.likelihood
...@@ -54,14 +56,7 @@ class TestNessai(unittest.TestCase): ...@@ -54,14 +56,7 @@ class TestNessai(unittest.TestCase):
self.assertDictEqual(expected, self.sampler.kwargs) self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs_seed(self): def test_translate_kwargs_seed(self):
expected = self.expected.copy() assert self.expected["seed"] == 150914
expected["seed"] = 150914
for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs["seed"]
new_kwargs[equiv] = 150914
self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs)
def test_npool_max_threads(self): def test_npool_max_threads(self):
expected = self.expected.copy() expected = self.expected.copy()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment