From 7f9f9cf31036944274f44a1b571e39b5477e10be Mon Sep 17 00:00:00 2001
From: Sylvia Biscoveanu <sylvia.biscoveanu@ligo.org>
Date: Mon, 29 Nov 2021 18:08:22 +0000
Subject: [PATCH] Change using_mpi test

---
 bilby/core/sampler/pymultinest.py | 33 ++++++++++++++++++++-----------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py
index 91356d146..d98693622 100644
--- a/bilby/core/sampler/pymultinest.py
+++ b/bilby/core/sampler/pymultinest.py
@@ -79,6 +79,12 @@ class Pymultinest(NestedSampler):
         temporary_directory=True,
         **kwargs
     ):
+        try:
+            from mpi4py import MPI
+
+            using_mpi = MPI.COMM_WORLD.Get_size() > 1
+        except ImportError:
+            using_mpi = False
         super(Pymultinest, self).__init__(
             likelihood=likelihood,
             priors=priors,
@@ -92,7 +98,6 @@ class Pymultinest(NestedSampler):
         )
         self._apply_multinest_boundaries()
         self.exit_code = exit_code
-        using_mpi = len([key for key in os.environ if "MPI" in key])
         if using_mpi and temporary_directory:
             logger.info(
                 "Temporary directory incompatible with MPI, "
@@ -111,15 +116,15 @@ class Pymultinest(NestedSampler):
                     kwargs["n_live_points"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
-        """ Check the kwargs """
+        """Check the kwargs"""
 
         self.outputfiles_basename = self.kwargs.pop("outputfiles_basename", None)
 
         # for PyMultiNest >=2.9 the n_params kwarg cannot be None
         if self.kwargs["n_params"] is None:
             self.kwargs["n_params"] = self.ndim
-        if self.kwargs['dump_callback'] is None:
-            self.kwargs['dump_callback'] = self._dump_callback
+        if self.kwargs["dump_callback"] is None:
+            self.kwargs["dump_callback"] = self._dump_callback
         NestedSampler._verify_kwargs_against_default_kwargs(self)
 
     def _dump_callback(self, *args, **kwargs):
@@ -166,7 +171,7 @@ class Pymultinest(NestedSampler):
             )
 
     def write_current_state_and_exit(self, signum=None, frame=None):
-        """ Write current state and exit on exit_code """
+        """Write current state and exit on exit_code"""
         logger.info(
             "Run interrupted by signal {}: checkpoint and exit on {}".format(
                 signum, self.exit_code
@@ -187,11 +192,13 @@ class Pymultinest(NestedSampler):
                 self.outputfiles_basename, self.temporary_outputfiles_basename
             )
         )
-        if self.outputfiles_basename.endswith('/'):
+        if self.outputfiles_basename.endswith("/"):
             outputfiles_basename_stripped = self.outputfiles_basename[:-1]
         else:
             outputfiles_basename_stripped = self.outputfiles_basename
-        distutils.dir_util.copy_tree(self.temporary_outputfiles_basename, outputfiles_basename_stripped)
+        distutils.dir_util.copy_tree(
+            self.temporary_outputfiles_basename, outputfiles_basename_stripped
+        )
 
     def _move_temporary_directory_to_proper_path(self):
         """
@@ -241,9 +248,9 @@ class Pymultinest(NestedSampler):
         return self.result
 
     def _check_and_load_sampling_time_file(self):
-        self.time_file_path = self.kwargs["outputfiles_basename"] + '/sampling_time.dat'
+        self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat"
         if os.path.exists(self.time_file_path):
-            with open(self.time_file_path, 'r') as time_file:
+            with open(self.time_file_path, "r") as time_file:
                 self.total_sampling_time = float(time_file.readline())
         else:
             self.total_sampling_time = 0
@@ -253,7 +260,7 @@ class Pymultinest(NestedSampler):
         new_sampling_time = current_time - self.start_time
         self.total_sampling_time += new_sampling_time
         self.start_time = current_time
-        with open(self.time_file_path, 'w') as time_file:
+        with open(self.time_file_path, "w") as time_file:
             time_file.write(str(self.total_sampling_time))
 
     def _clean_up_run_directory(self):
@@ -271,16 +278,18 @@ class Pymultinest(NestedSampler):
         estimate of `remaining_prior_volume / N`.
         """
         import pandas as pd
+
         dir_ = self.kwargs["outputfiles_basename"]
         dead_points = np.genfromtxt(dir_ + "/ev.dat")
         live_points = np.genfromtxt(dir_ + "/phys_live.points")
 
         nlive = self.kwargs["n_live_points"]
-        final_log_prior_volume = - len(dead_points) / nlive - np.log(nlive)
+        final_log_prior_volume = -len(dead_points) / nlive - np.log(nlive)
         live_points = np.insert(live_points, -1, final_log_prior_volume, axis=-1)
 
         nested_samples = pd.DataFrame(
             np.vstack([dead_points, live_points]).copy(),
-            columns=self.search_parameter_keys + ["log_likelihood", "log_prior_volume", "mode"]
+            columns=self.search_parameter_keys
+            + ["log_likelihood", "log_prior_volume", "mode"],
         )
         return nested_samples
-- 
GitLab