Skip to content
Snippets Groups Projects
Commit 2224e43c authored by Colm Talbot's avatar Colm Talbot
Browse files

Neaten up some logic

parent 519dac91
No related branches found
No related tags found
No related merge requests found
......@@ -259,6 +259,39 @@ class Dynesty(NestedSampler):
self.kwargs["periodic"] = self._periodic
self.kwargs["reflective"] = self._reflective
def _setup_pool(self):
if self.kwargs["pool"] is not None:
logger.info("Using user defined pool.")
self.pool = self.kwargs["pool"]
elif self.kwargs["queue_size"] > 1:
logger.info(
"Setting up multiproccesing pool with {} processes.".format(
self.kwargs["queue_size"]
)
)
import multiprocessing
self.pool = multiprocessing.Pool(
processes=self.kwargs["queue_size"],
initializer=_initialize_global_variables,
initargs=(self.likelihood, self.priors, self._search_parameter_keys)
)
self.kwargs["pool"] = self.pool
else:
_initialize_global_variables(
likelihood=self.likelihood,
priors=self.priors,
search_parameter_keys=self._search_parameter_keys
)
self.pool = None
def _close_pool(self):
if getattr(self, "pool", None) is not None:
logger.info("Starting to close worker pool.")
self.pool.close()
self.pool.join()
self.pool = None
logger.info("Finished closing worker pool.")
def run_sampler(self):
import dynesty
logger.info("Using dynesty version {}".format(dynesty.__version__))
......@@ -285,26 +318,7 @@ class Dynesty(NestedSampler):
logger.info(
"Using the dynesty-implemented rstagger sample method")
if self.kwargs["queue_size"] > 1:
logger.info(
"Setting up multiproccesing pool with {} processes.".format(
self.kwargs["queue_size"]
)
)
import multiprocessing
self.pool = multiprocessing.Pool(
processes=self.kwargs["queue_size"],
initializer=_initialize_global_variables,
initargs=(self.likelihood, self.priors, self._search_parameter_keys)
)
self.kwargs["pool"] = self.pool
else:
_initialize_global_variables(
likelihood=self.likelihood,
priors=self.priors,
search_parameter_keys=self._search_parameter_keys
)
self.pool = None
self._setup_pool()
self.sampler = dynesty.NestedSampler(
loglikelihood=_log_likelihood_wrapper,
......@@ -316,11 +330,7 @@ class Dynesty(NestedSampler):
else:
out = self._run_external_sampler_without_checkpointing()
if self.kwargs["queue_size"] > 1:
# stop and remove the pool object as it can't be json serialised
self.pool.close()
self.pool.join()
self.kwargs["pool"] = None
self._close_pool()
# Flushes the output to force a line break
if self.kwargs["verbose"]:
......@@ -477,12 +487,10 @@ class Dynesty(NestedSampler):
Make sure that if a pool of jobs is running only the parent tries to
checkpoint and exit. Only the parent has a 'pool' attribute.
"""
if hasattr(self, "pool"):
if getattr(self, "pool", None) is not None:
logger.warning("Run terminated with signal {}".format(signum))
if self.pool is not None:
self.pool.close()
self.pool.join()
self.write_current_state(plot=False)
self._close_pool()
sys.exit(130)
def write_current_state(self, plot=True):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment