Commit c4db9452 authored by Colm Talbot

Allow dynesty to run with multiprocessing

parent c54802e5
@@ -19,6 +19,41 @@ from dynesty.utils import unitcheck
import warnings
_likelihood = None
_priors = None
_search_parameter_keys = None
def initialize_likelihood_and_prior(likelihood, priors, search_parameter_keys):
"""
Store a global copy of the likelihood, priors, and search keys for
multiprocessing.
"""
global _likelihood
global _priors
global _search_parameter_keys
_likelihood = likelihood
_priors = priors
_search_parameter_keys = search_parameter_keys
def _prior_transform_wrapper(theta):
"""Wrapper to the prior transformation. Needed for multiprocessing."""
return _priors.rescale(_search_parameter_keys, theta)
def _log_likelihood_wrapper(theta):
"""Wrapper to the log likelihood. Needed for multiprocessing."""
if _priors.evaluate_constraints({
key: theta[ii] for ii, key in enumerate(_search_parameter_keys)
}):
params = {key: t for key, t in zip(_search_parameter_keys, theta)}
_likelihood.parameters.update(params)
return _likelihood.log_likelihood_ratio()
else:
return np.nan_to_num(-np.inf)
class Dynesty(NestedSampler):
"""
bilby wrapper of `dynesty.NestedSampler`
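The module-level wrappers added above follow the usual `multiprocessing.Pool` initializer pattern: the likelihood and prior objects are shipped to each worker once (via `initargs`) and stored in per-process globals, so the picklable top-level wrapper functions are all that each task has to carry. A minimal, self-contained sketch of that pattern (the toy `model`, `_init_worker`, and `_evaluate` names are illustrative, not part of bilby):

import multiprocessing

# Per-process global filled in by the pool initializer; each worker gets
# its own copy once, instead of pickling the object with every task.
_model = None


def _init_worker(model):
    """Store the shared object in a module-level global."""
    global _model
    _model = model


def _evaluate(x):
    """Top-level (hence picklable) wrapper that uses the per-worker global."""
    return _model["scale"] * x


if __name__ == "__main__":
    model = {"scale": 2.0}  # stand-in for an expensive likelihood object
    with multiprocessing.Pool(
            processes=2, initializer=_init_worker, initargs=(model,)) as pool:
        print(pool.map(_evaluate, [1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]
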
@@ -83,7 +118,7 @@ class Dynesty(NestedSampler):
verbose=True, periodic=None, reflective=None,
check_point_delta_t=600, nlive=1000,
first_update=None, walks=100,
- npdim=None, rstate=None, queue_size=None, pool=None,
+ npdim=None, rstate=None, queue_size=1, pool=None,
use_pool=None, live_points=None,
logl_args=None, logl_kwargs=None,
ptform_args=None, ptform_kwargs=None,
@@ -250,9 +285,30 @@ class Dynesty(NestedSampler):
logger.info(
"Using the dynesty-implemented rstagger sample method")
if self.kwargs["queue_size"] > 1:
logger.info(
"Setting up multiproccesing pool with {} processes.".format(
self.kwargs["queue_size"]
)
)
import multiprocessing
self.pool = multiprocessing.Pool(
processes=self.kwargs["queue_size"],
initializer=initialize_likelihood_and_prior,
initargs=(self.likelihood, self.priors, self._search_parameter_keys)
)
self.kwargs["pool"] = self.pool
else:
initialize_likelihood_and_prior(
likelihood=self.likelihood,
priors=self.priors,
search_parameter_keys=self._search_parameter_keys
)
self.pool = None
self.sampler = dynesty.NestedSampler(
- loglikelihood=self.log_likelihood,
- prior_transform=self.prior_transform,
+ loglikelihood=_log_likelihood_wrapper,
+ prior_transform=_prior_transform_wrapper,
ndim=self.ndim, **self.sampler_init_kwargs)
if self.check_point:
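With `queue_size > 1` the pool and the queue size are simply forwarded to dynesty through the sampler's keyword arguments, and dynesty pushes the likelihood and prior-transform calls out through the pool. A rough standalone sketch of that hand-off, assuming dynesty's documented `pool`/`queue_size` keywords and using a toy Gaussian likelihood rather than bilby's:

import multiprocessing

import numpy as np
import dynesty


def log_likelihood(theta):
    """Toy 2-D unit-Gaussian log-likelihood."""
    return -0.5 * np.sum(theta ** 2)


def prior_transform(u):
    """Map the unit cube onto a flat prior over [-5, 5] in each dimension."""
    return 10.0 * u - 5.0


if __name__ == "__main__":
    with multiprocessing.Pool(processes=2) as pool:
        sampler = dynesty.NestedSampler(
            log_likelihood, prior_transform, ndim=2,
            nlive=250, pool=pool, queue_size=2)
        sampler.run_nested()
    print(sampler.results.logz[-1])  # final log-evidence estimate
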
@@ -260,6 +316,12 @@
else:
out = self._run_external_sampler_without_checkpointing()
if self.kwargs["queue_size"] > 1:
# stop and remove the pool object as it can't be json serialised
self.pool.close()
self.pool.join()
self.kwargs["pool"] = None
# Flushes the output to force a line break
if self.kwargs["verbose"]:
self.pbar.close()
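The teardown above is needed because a live `multiprocessing.Pool` handle cannot be serialised, so it has to be closed and dropped from `kwargs` before the checkpoint is written. A quick sketch of the failure mode, using pickle for illustration:

import multiprocessing
import pickle

if __name__ == "__main__":
    pool = multiprocessing.Pool(processes=2)
    kwargs = {"queue_size": 2, "pool": pool}

    try:
        pickle.dumps(kwargs)  # fails: pool objects cannot be pickled
    except Exception as error:
        print("cannot serialise:", error)

    # Shut the workers down and drop the handle, mirroring the patch;
    # the remaining kwargs then serialise cleanly.
    pool.close()
    pool.join()
    kwargs["pool"] = None
    pickle.dumps(kwargs)
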
@@ -411,9 +473,17 @@
return False
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state(plot=False)
sys.exit(130)
"""
Make sure that if a pool of jobs is running only the parent tries to
checkpoint and exit. Only the parent has a 'pool' attribute.
"""
if hasattr(self, "pool"):
logger.warning("Run terminated with signal {}".format(signum))
if self.pool is not None:
self.pool.close()
self.pool.join()
self.write_current_state(plot=False)
sys.exit(130)
def write_current_state(self, plot=True):
"""
......
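The `hasattr(self, "pool")` guard in the new signal handler makes the handler a no-op in the worker processes, presumably because the workers are forked inside `multiprocessing.Pool(...)` before the `pool` attribute is bound on the sampler object, so only the parent closes the pool, checkpoints, and exits. An illustrative sketch of that guard (class and method names are hypothetical, not bilby's):

import multiprocessing
import signal
import sys


class CheckpointingSampler:
    """Toy parent object that owns the pool and the SIGTERM handler."""

    def __init__(self):
        # Installed before the pool exists; forked workers inherit the
        # handler together with a copy of ``self`` that has no ``pool``.
        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)

    def run(self):
        self.pool = multiprocessing.Pool(processes=2)

    def write_current_state_and_exit(self, signum=None, frame=None):
        if hasattr(self, "pool"):  # True only in the parent process
            if self.pool is not None:
                self.pool.close()
                self.pool.join()
            # ... write the checkpoint here ...
            sys.exit(130)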