From 2224e43caa2235e1e14b105ab776522975f78f96 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Thu, 2 Apr 2020 09:13:21 -0400 Subject: [PATCH] Neaten up some logic --- bilby/core/sampler/dynesty.py | 66 ++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 342fad2f..1ecbf29e 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -259,6 +259,39 @@ class Dynesty(NestedSampler): self.kwargs["periodic"] = self._periodic self.kwargs["reflective"] = self._reflective + def _setup_pool(self): + if self.kwargs["pool"] is not None: + logger.info("Using user defined pool.") + self.pool = self.kwargs["pool"] + elif self.kwargs["queue_size"] > 1: + logger.info( + "Setting up multiproccesing pool with {} processes.".format( + self.kwargs["queue_size"] + ) + ) + import multiprocessing + self.pool = multiprocessing.Pool( + processes=self.kwargs["queue_size"], + initializer=_initialize_global_variables, + initargs=(self.likelihood, self.priors, self._search_parameter_keys) + ) + self.kwargs["pool"] = self.pool + else: + _initialize_global_variables( + likelihood=self.likelihood, + priors=self.priors, + search_parameter_keys=self._search_parameter_keys + ) + self.pool = None + + def _close_pool(self): + if getattr(self, "pool", None) is not None: + logger.info("Starting to close worker pool.") + self.pool.close() + self.pool.join() + self.pool = None + logger.info("Finished closing worker pool.") + def run_sampler(self): import dynesty logger.info("Using dynesty version {}".format(dynesty.__version__)) @@ -285,26 +318,7 @@ class Dynesty(NestedSampler): logger.info( "Using the dynesty-implemented rstagger sample method") - if self.kwargs["queue_size"] > 1: - logger.info( - "Setting up multiproccesing pool with {} processes.".format( - self.kwargs["queue_size"] - ) - ) - import multiprocessing - self.pool = multiprocessing.Pool( - processes=self.kwargs["queue_size"], - initializer=_initialize_global_variables, - initargs=(self.likelihood, self.priors, self._search_parameter_keys) - ) - self.kwargs["pool"] = self.pool - else: - _initialize_global_variables( - likelihood=self.likelihood, - priors=self.priors, - search_parameter_keys=self._search_parameter_keys - ) - self.pool = None + self._setup_pool() self.sampler = dynesty.NestedSampler( loglikelihood=_log_likelihood_wrapper, @@ -316,11 +330,7 @@ class Dynesty(NestedSampler): else: out = self._run_external_sampler_without_checkpointing() - if self.kwargs["queue_size"] > 1: - # stop and remove the pool object as it can't be json serialised - self.pool.close() - self.pool.join() - self.kwargs["pool"] = None + self._close_pool() # Flushes the output to force a line break if self.kwargs["verbose"]: @@ -477,12 +487,10 @@ class Dynesty(NestedSampler): Make sure that if a pool of jobs is running only the parent tries to checkpoint and exit. Only the parent has a 'pool' attribute. """ - if hasattr(self, "pool"): + if getattr(self, "pool", None) is not None: logger.warning("Run terminated with signal {}".format(signum)) - if self.pool is not None: - self.pool.close() - self.pool.join() self.write_current_state(plot=False) + self._close_pool() sys.exit(130) def write_current_state(self, plot=True): -- GitLab