From 2224e43caa2235e1e14b105ab776522975f78f96 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 2 Apr 2020 09:13:21 -0400
Subject: [PATCH] Neaten up some logic

---
 bilby/core/sampler/dynesty.py | 66 ++++++++++++++++++++---------------
 1 file changed, 37 insertions(+), 29 deletions(-)

diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 342fad2f..1ecbf29e 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -259,6 +259,39 @@ class Dynesty(NestedSampler):
         self.kwargs["periodic"] = self._periodic
         self.kwargs["reflective"] = self._reflective
 
+    def _setup_pool(self):
+        if self.kwargs["pool"] is not None:
+            logger.info("Using user defined pool.")
+            self.pool = self.kwargs["pool"]
+        elif self.kwargs["queue_size"] > 1:
+            logger.info(
+                "Setting up multiproccesing pool with {} processes.".format(
+                    self.kwargs["queue_size"]
+                )
+            )
+            import multiprocessing
+            self.pool = multiprocessing.Pool(
+                processes=self.kwargs["queue_size"],
+                initializer=_initialize_global_variables,
+                initargs=(self.likelihood, self.priors, self._search_parameter_keys)
+            )
+            self.kwargs["pool"] = self.pool
+        else:
+            _initialize_global_variables(
+                likelihood=self.likelihood,
+                priors=self.priors,
+                search_parameter_keys=self._search_parameter_keys
+            )
+            self.pool = None
+
+    def _close_pool(self):
+        if getattr(self, "pool", None) is not None:
+            logger.info("Starting to close worker pool.")
+            self.pool.close()
+            self.pool.join()
+            self.pool = None
+            logger.info("Finished closing worker pool.")
+
     def run_sampler(self):
         import dynesty
         logger.info("Using dynesty version {}".format(dynesty.__version__))
@@ -285,26 +318,7 @@ class Dynesty(NestedSampler):
             logger.info(
                 "Using the dynesty-implemented rstagger sample method")
 
-        if self.kwargs["queue_size"] > 1:
-            logger.info(
-                "Setting up multiproccesing pool with {} processes.".format(
-                    self.kwargs["queue_size"]
-                )
-            )
-            import multiprocessing
-            self.pool = multiprocessing.Pool(
-                processes=self.kwargs["queue_size"],
-                initializer=_initialize_global_variables,
-                initargs=(self.likelihood, self.priors, self._search_parameter_keys)
-            )
-            self.kwargs["pool"] = self.pool
-        else:
-            _initialize_global_variables(
-                likelihood=self.likelihood,
-                priors=self.priors,
-                search_parameter_keys=self._search_parameter_keys
-            )
-            self.pool = None
+        self._setup_pool()
 
         self.sampler = dynesty.NestedSampler(
             loglikelihood=_log_likelihood_wrapper,
@@ -316,11 +330,7 @@ class Dynesty(NestedSampler):
         else:
             out = self._run_external_sampler_without_checkpointing()
 
-        if self.kwargs["queue_size"] > 1:
-            # stop and remove the pool object as it can't be json serialised
-            self.pool.close()
-            self.pool.join()
-            self.kwargs["pool"] = None
+        self._close_pool()
 
         # Flushes the output to force a line break
         if self.kwargs["verbose"]:
@@ -477,12 +487,10 @@ class Dynesty(NestedSampler):
         Make sure that if a pool of jobs is running only the parent tries to
         checkpoint and exit. Only the parent has a 'pool' attribute.
         """
-        if hasattr(self, "pool"):
+        if getattr(self, "pool", None) is not None:
             logger.warning("Run terminated with signal {}".format(signum))
-            if self.pool is not None:
-                self.pool.close()
-                self.pool.join()
             self.write_current_state(plot=False)
+            self._close_pool()
             sys.exit(130)
 
     def write_current_state(self, plot=True):
-- 
GitLab