diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 30e7ac6ac81aa8a704b1fc39a2a3b8520176fb77..688efc417721fa0d9b600b5e011ccf668fe2770b 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -21,6 +21,49 @@ from dynesty.utils import unitcheck
 import warnings
 
 
+_likelihood = None
+_priors = None
+_search_parameter_keys = None
+_use_ratio = False
+
+
+def _initialize_global_variables(
+    likelihood, priors, search_parameter_keys, use_ratio
+):
+    """
+    Store a global copy of the likelihood, priors, and search keys for
+    multiprocessing.
+    """
+    global _likelihood
+    global _priors
+    global _search_parameter_keys
+    global _use_ratio
+    _likelihood = likelihood
+    _priors = priors
+    _search_parameter_keys = search_parameter_keys
+    _use_ratio = use_ratio
+
+
+def _prior_transform_wrapper(theta):
+    """Wrapper to the prior transformation. Needed for multiprocessing."""
+    return _priors.rescale(_search_parameter_keys, theta)
+
+
+def _log_likelihood_wrapper(theta):
+    """Wrapper to the log likelihood. Needed for multiprocessing."""
+    if _priors.evaluate_constraints({
+        key: theta[ii] for ii, key in enumerate(_search_parameter_keys)
+    }):
+        params = {key: t for key, t in zip(_search_parameter_keys, theta)}
+        _likelihood.parameters.update(params)
+        if _use_ratio:
+            return _likelihood.log_likelihood_ratio()
+        else:
+            return _likelihood.log_likelihood()
+    else:
+        return np.nan_to_num(-np.inf)
+
+
 class Dynesty(NestedSampler):
     """
     bilby wrapper of `dynesty.NestedSampler`
@@ -85,7 +128,7 @@ class Dynesty(NestedSampler):
                           verbose=True, periodic=None, reflective=None,
                           check_point_delta_t=600, nlive=1000,
                           first_update=None, walks=100,
-                          npdim=None, rstate=None, queue_size=None, pool=None,
+                          npdim=None, rstate=None, queue_size=1, pool=None,
                           use_pool=None, live_points=None,
                           logl_args=None, logl_kwargs=None,
                           ptform_args=None, ptform_kwargs=None,
@@ -224,6 +267,62 @@ class Dynesty(NestedSampler):
         self.kwargs["periodic"] = self._periodic
         self.kwargs["reflective"] = self._reflective
 
+    def _setup_pool(self):
+        """
+        Select the worker pool handed to dynesty and store it on ``self.pool``
+        and ``self.kwargs["pool"]``.
+
+        A user supplied ``pool`` takes precedence. Otherwise a
+        ``multiprocessing.Pool`` with ``queue_size`` workers is created when
+        ``queue_size`` is greater than one, and no pool is used for any other
+        value, including ``None``.
+        """
+        queue_size = self.kwargs["queue_size"]
+        if self.kwargs["pool"] is not None:
+            logger.info("Using user defined pool.")
+            self.pool = self.kwargs["pool"]
+        elif queue_size is not None and queue_size > 1:
+            logger.info(
+                "Setting up multiproccesing pool with {} processes.".format(
+                    queue_size
+                )
+            )
+            import multiprocessing
+            self.pool = multiprocessing.Pool(
+                processes=queue_size,
+                initializer=_initialize_global_variables,
+                initargs=(
+                    self.likelihood,
+                    self.priors,
+                    self._search_parameter_keys,
+                    self.use_ratio
+                )
+            )
+        else:
+            self.pool = None
+        # The module-level wrappers read these globals, so the parent needs
+        # them as well for any evaluation that is not sent to a worker.
+        _initialize_global_variables(
+            likelihood=self.likelihood,
+            priors=self.priors,
+            search_parameter_keys=self._search_parameter_keys,
+            use_ratio=self.use_ratio
+        )
+        self.kwargs["pool"] = self.pool
+
+    def _close_pool(self):
+        """
+        Close and join the worker pool, if there is one, and clear the
+        references to it held on the sampler and in ``self.kwargs``.
+        """
+        if getattr(self, "pool", None) is not None:
+            logger.info("Starting to close worker pool.")
+            self.pool.close()
+            self.pool.join()
+            self.pool = None
+            self.kwargs["pool"] = self.pool
+            logger.info("Finished closing worker pool.")
+
     def run_sampler(self):
         import dynesty
         logger.info("Using dynesty version {}".format(dynesty.__version__))
@@ -250,9 +349,11 @@ class Dynesty(NestedSampler):
             logger.info(
                 "Using the dynesty-implemented rstagger sample method")
 
+        self._setup_pool()
+
         self.sampler = dynesty.NestedSampler(
-            loglikelihood=self.log_likelihood,
-            prior_transform=self.prior_transform,
+            loglikelihood=_log_likelihood_wrapper,
+            prior_transform=_prior_transform_wrapper,
             ndim=self.ndim, **self.sampler_init_kwargs)
 
         if self.check_point:
@@ -260,6 +361,8 @@ class Dynesty(NestedSampler):
         else:
             out = self._run_external_sampler_without_checkpointing()
 
+        self._close_pool()
+
         # Flushes the output to force a line break
         if self.kwargs["verbose"]:
             self.pbar.close()
@@ -407,6 +510,11 @@ class Dynesty(NestedSampler):
                 self.sampler.rstate = np.random
                 self.start_time = self.sampler.kwargs.pop("start_time")
                 self.sampling_time = self.sampler.kwargs.pop("sampling_time")
+                self.sampler.pool = self.pool
+                if self.pool is not None:
+                    self.sampler.M = self.pool.map
+                else:
+                    self.sampler.M = map
             return True
         else:
             logger.info(
@@ -414,16 +522,26 @@ class Dynesty(NestedSampler):
             return False
 
     def write_current_state_and_exit(self, signum=None, frame=None):
-        if signum == 14:
-            logger.info(
-                "Run interrupted by alarm signal {}: checkpoint and exit on {}"
-                .format(signum, self.exit_code))
-        else:
-            logger.info(
-                "Run interrupted by signal {}: checkpoint and exit on {}"
-                .format(signum, self.exit_code))
-        self.write_current_state()
-        os._exit(self.exit_code)
+        """
+        Make sure that if a pool of jobs is running only the parent tries to
+        checkpoint and exit. Only the parent has a 'pool' attribute.
+        """
+        queue_size = self.kwargs["queue_size"]
+        # No worker pool is created for these values, so this process is the
+        # only one that can checkpoint.
+        no_workers = queue_size is None or queue_size <= 1
+        if no_workers or getattr(self, "pool", None) is not None:
+            if signum == 14:
+                logger.info(
+                    "Run interrupted by alarm signal {}: checkpoint and exit on {}"
+                    .format(signum, self.exit_code))
+            else:
+                logger.info(
+                    "Run interrupted by signal {}: checkpoint and exit on {}"
+                    .format(signum, self.exit_code))
+            self.write_current_state()
+            self._close_pool()
+            os._exit(self.exit_code)
 
     def write_current_state(self):
         """
@@ -449,6 +567,8 @@ class Dynesty(NestedSampler):
         self.sampler.versions = dict(
             bilby=bilby_version, dynesty=dynesty_version
         )
+        self.sampler.pool = None
+        self.sampler.M = map
         if dill.pickles(self.sampler):
             safe_file_dump(self.sampler, self.resume_file, dill)
             logger.info("Written checkpoint file {}".format(self.resume_file))
@@ -457,6 +577,9 @@ class Dynesty(NestedSampler):
                 "Cannot write pickle resume file! "
                 "Job will not resume if interrupted."
             )
+        self.sampler.pool = self.pool
+        if self.sampler.pool is not None:
+            self.sampler.M = self.sampler.pool.map
 
     def plot_current_state(self):
         if self.check_point_plot:
diff --git a/test/sampler_test.py b/test/sampler_test.py
index ddd54b72bd1a749f2c0fe2985a4db18826d9e120..e2c71d6250849a3f3740cf125eaf366dff25f377 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -151,7 +151,7 @@ class TestDynesty(unittest.TestCase):
     def test_default_kwargs(self):
         expected = dict(bound='multi', sample='rwalk', periodic=None, reflective=None, verbose=True,
                         check_point_delta_t=600, nlive=1000, first_update=None,
-                        npdim=None, rstate=None, queue_size=None, pool=None,
+                        npdim=None, rstate=None, queue_size=1, pool=None,
                         use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
                         ptform_args=None, ptform_kwargs=None,
                         enlarge=1.5, bootstrap=None, vol_dec=0.5, vol_check=8.0,
@@ -173,7 +173,7 @@ class TestDynesty(unittest.TestCase):
     def test_translate_kwargs(self):
         expected = dict(bound='multi', sample='rwalk', periodic=[], reflective=[], verbose=True,
                         check_point_delta_t=600, nlive=1000, first_update=None,
-                        npdim=None, rstate=None, queue_size=None, pool=None,
+                        npdim=None, rstate=None, queue_size=1, pool=None,
                         use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
                         ptform_args=None, ptform_kwargs=None,
                         enlarge=1.5, bootstrap=None, vol_dec=0.5, vol_check=8.0,