Commit c4987ddd authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'dynesty-multiprocessing' into 'master'

Allow dynesty to run with multiprocessing

See merge request !754
parents 2ca8a475 104e29a5
Pipeline #118689 failed with stages
in 60 minutes and 1 second
......@@ -21,6 +21,49 @@ from dynesty.utils import unitcheck
import warnings
_likelihood = None
_priors = None
_search_parameter_keys = None
_use_ratio = False
def _initialize_global_variables(
likelihood, priors, search_parameter_keys, use_ratio
):
"""
Store a global copy of the likelihood, priors, and search keys for
multiprocessing.
"""
global _likelihood
global _priors
global _search_parameter_keys
global _use_ratio
_likelihood = likelihood
_priors = priors
_search_parameter_keys = search_parameter_keys
_use_ratio = use_ratio
def _prior_transform_wrapper(theta):
"""Wrapper to the prior transformation. Needed for multiprocessing."""
return _priors.rescale(_search_parameter_keys, theta)
def _log_likelihood_wrapper(theta):
"""Wrapper to the log likelihood. Needed for multiprocessing."""
if _priors.evaluate_constraints({
key: theta[ii] for ii, key in enumerate(_search_parameter_keys)
}):
params = {key: t for key, t in zip(_search_parameter_keys, theta)}
_likelihood.parameters.update(params)
if _use_ratio:
return _likelihood.log_likelihood_ratio()
else:
return _likelihood.log_likelihood()
else:
return np.nan_to_num(-np.inf)
class Dynesty(NestedSampler):
"""
bilby wrapper of `dynesty.NestedSampler`
......@@ -85,7 +128,7 @@ class Dynesty(NestedSampler):
verbose=True, periodic=None, reflective=None,
check_point_delta_t=600, nlive=1000,
first_update=None, walks=100,
npdim=None, rstate=None, queue_size=None, pool=None,
npdim=None, rstate=None, queue_size=1, pool=None,
use_pool=None, live_points=None,
logl_args=None, logl_kwargs=None,
ptform_args=None, ptform_kwargs=None,
......@@ -224,6 +267,46 @@ class Dynesty(NestedSampler):
self.kwargs["periodic"] = self._periodic
self.kwargs["reflective"] = self._reflective
def _setup_pool(self):
if self.kwargs["pool"] is not None:
logger.info("Using user defined pool.")
self.pool = self.kwargs["pool"]
elif self.kwargs["queue_size"] > 1:
logger.info(
"Setting up multiproccesing pool with {} processes.".format(
self.kwargs["queue_size"]
)
)
import multiprocessing
self.pool = multiprocessing.Pool(
processes=self.kwargs["queue_size"],
initializer=_initialize_global_variables,
initargs=(
self.likelihood,
self.priors,
self._search_parameter_keys,
self.use_ratio
)
)
else:
_initialize_global_variables(
likelihood=self.likelihood,
priors=self.priors,
search_parameter_keys=self._search_parameter_keys,
use_ratio=self.use_ratio
)
self.pool = None
self.kwargs["pool"] = self.pool
def _close_pool(self):
if getattr(self, "pool", None) is not None:
logger.info("Starting to close worker pool.")
self.pool.close()
self.pool.join()
self.pool = None
self.kwargs["pool"] = self.pool
logger.info("Finished closing worker pool.")
def run_sampler(self):
import dynesty
logger.info("Using dynesty version {}".format(dynesty.__version__))
......@@ -250,9 +333,11 @@ class Dynesty(NestedSampler):
logger.info(
"Using the dynesty-implemented rstagger sample method")
self._setup_pool()
self.sampler = dynesty.NestedSampler(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
loglikelihood=_log_likelihood_wrapper,
prior_transform=_prior_transform_wrapper,
ndim=self.ndim, **self.sampler_init_kwargs)
if self.check_point:
......@@ -260,6 +345,8 @@ class Dynesty(NestedSampler):
else:
out = self._run_external_sampler_without_checkpointing()
self._close_pool()
# Flushes the output to force a line break
if self.kwargs["verbose"]:
self.pbar.close()
......@@ -407,6 +494,11 @@ class Dynesty(NestedSampler):
self.sampler.rstate = np.random
self.start_time = self.sampler.kwargs.pop("start_time")
self.sampling_time = self.sampler.kwargs.pop("sampling_time")
self.sampler.pool = self.pool
if self.pool is not None:
self.sampler.M = self.pool.map
else:
self.sampler.M = map
return True
else:
logger.info(
......@@ -414,16 +506,22 @@ class Dynesty(NestedSampler):
return False
def write_current_state_and_exit(self, signum=None, frame=None):
if signum == 14:
logger.info(
"Run interrupted by alarm signal {}: checkpoint and exit on {}"
.format(signum, self.exit_code))
else:
logger.info(
"Run interrupted by signal {}: checkpoint and exit on {}"
.format(signum, self.exit_code))
self.write_current_state()
os._exit(self.exit_code)
"""
Make sure that if a pool of jobs is running only the parent tries to
checkpoint and exit. Only the parent has a 'pool' attribute.
"""
if self.kwargs["queue_size"] == 1 or getattr(self, "pool", None) is not None:
if signum == 14:
logger.info(
"Run interrupted by alarm signal {}: checkpoint and exit on {}"
.format(signum, self.exit_code))
else:
logger.info(
"Run interrupted by signal {}: checkpoint and exit on {}"
.format(signum, self.exit_code))
self.write_current_state()
self._close_pool()
os._exit(self.exit_code)
def write_current_state(self):
"""
......@@ -449,6 +547,8 @@ class Dynesty(NestedSampler):
self.sampler.versions = dict(
bilby=bilby_version, dynesty=dynesty_version
)
self.sampler.pool = None
self.sampler.M = map
if dill.pickles(self.sampler):
safe_file_dump(self.sampler, self.resume_file, dill)
logger.info("Written checkpoint file {}".format(self.resume_file))
......@@ -457,6 +557,9 @@ class Dynesty(NestedSampler):
"Cannot write pickle resume file! "
"Job will not resume if interrupted."
)
self.sampler.pool = self.pool
if self.sampler.pool is not None:
self.sampler.M = self.sampler.pool.map
def plot_current_state(self):
if self.check_point_plot:
......
......@@ -151,7 +151,7 @@ class TestDynesty(unittest.TestCase):
def test_default_kwargs(self):
expected = dict(bound='multi', sample='rwalk', periodic=None, reflective=None, verbose=True,
check_point_delta_t=600, nlive=1000, first_update=None,
npdim=None, rstate=None, queue_size=None, pool=None,
npdim=None, rstate=None, queue_size=1, pool=None,
use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
ptform_args=None, ptform_kwargs=None,
enlarge=1.5, bootstrap=None, vol_dec=0.5, vol_check=8.0,
......@@ -173,7 +173,7 @@ class TestDynesty(unittest.TestCase):
def test_translate_kwargs(self):
expected = dict(bound='multi', sample='rwalk', periodic=[], reflective=[], verbose=True,
check_point_delta_t=600, nlive=1000, first_update=None,
npdim=None, rstate=None, queue_size=None, pool=None,
npdim=None, rstate=None, queue_size=1, pool=None,
use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
ptform_args=None, ptform_kwargs=None,
enlarge=1.5, bootstrap=None, vol_dec=0.5, vol_check=8.0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment