Skip to content
Snippets Groups Projects
Commit 7f9f9cf3 authored by Sylvia Biscoveanu's avatar Sylvia Biscoveanu Committed by Colm Talbot
Browse files

Change using_mpi test

parent 6cf9a859
No related branches found
No related tags found
1 merge request!1039Change using_mpi test
...@@ -79,6 +79,12 @@ class Pymultinest(NestedSampler): ...@@ -79,6 +79,12 @@ class Pymultinest(NestedSampler):
temporary_directory=True, temporary_directory=True,
**kwargs **kwargs
): ):
try:
from mpi4py import MPI
using_mpi = MPI.COMM_WORLD.Get_size() > 1
except ImportError:
using_mpi = False
super(Pymultinest, self).__init__( super(Pymultinest, self).__init__(
likelihood=likelihood, likelihood=likelihood,
priors=priors, priors=priors,
...@@ -92,7 +98,6 @@ class Pymultinest(NestedSampler): ...@@ -92,7 +98,6 @@ class Pymultinest(NestedSampler):
) )
self._apply_multinest_boundaries() self._apply_multinest_boundaries()
self.exit_code = exit_code self.exit_code = exit_code
using_mpi = len([key for key in os.environ if "MPI" in key])
if using_mpi and temporary_directory: if using_mpi and temporary_directory:
logger.info( logger.info(
"Temporary directory incompatible with MPI, " "Temporary directory incompatible with MPI, "
...@@ -111,15 +116,15 @@ class Pymultinest(NestedSampler): ...@@ -111,15 +116,15 @@ class Pymultinest(NestedSampler):
kwargs["n_live_points"] = kwargs.pop(equiv) kwargs["n_live_points"] = kwargs.pop(equiv)
def _verify_kwargs_against_default_kwargs(self): def _verify_kwargs_against_default_kwargs(self):
""" Check the kwargs """ """Check the kwargs"""
self.outputfiles_basename = self.kwargs.pop("outputfiles_basename", None) self.outputfiles_basename = self.kwargs.pop("outputfiles_basename", None)
# for PyMultiNest >=2.9 the n_params kwarg cannot be None # for PyMultiNest >=2.9 the n_params kwarg cannot be None
if self.kwargs["n_params"] is None: if self.kwargs["n_params"] is None:
self.kwargs["n_params"] = self.ndim self.kwargs["n_params"] = self.ndim
if self.kwargs['dump_callback'] is None: if self.kwargs["dump_callback"] is None:
self.kwargs['dump_callback'] = self._dump_callback self.kwargs["dump_callback"] = self._dump_callback
NestedSampler._verify_kwargs_against_default_kwargs(self) NestedSampler._verify_kwargs_against_default_kwargs(self)
def _dump_callback(self, *args, **kwargs): def _dump_callback(self, *args, **kwargs):
...@@ -166,7 +171,7 @@ class Pymultinest(NestedSampler): ...@@ -166,7 +171,7 @@ class Pymultinest(NestedSampler):
) )
def write_current_state_and_exit(self, signum=None, frame=None): def write_current_state_and_exit(self, signum=None, frame=None):
""" Write current state and exit on exit_code """ """Write current state and exit on exit_code"""
logger.info( logger.info(
"Run interrupted by signal {}: checkpoint and exit on {}".format( "Run interrupted by signal {}: checkpoint and exit on {}".format(
signum, self.exit_code signum, self.exit_code
...@@ -187,11 +192,13 @@ class Pymultinest(NestedSampler): ...@@ -187,11 +192,13 @@ class Pymultinest(NestedSampler):
self.outputfiles_basename, self.temporary_outputfiles_basename self.outputfiles_basename, self.temporary_outputfiles_basename
) )
) )
if self.outputfiles_basename.endswith('/'): if self.outputfiles_basename.endswith("/"):
outputfiles_basename_stripped = self.outputfiles_basename[:-1] outputfiles_basename_stripped = self.outputfiles_basename[:-1]
else: else:
outputfiles_basename_stripped = self.outputfiles_basename outputfiles_basename_stripped = self.outputfiles_basename
distutils.dir_util.copy_tree(self.temporary_outputfiles_basename, outputfiles_basename_stripped) distutils.dir_util.copy_tree(
self.temporary_outputfiles_basename, outputfiles_basename_stripped
)
def _move_temporary_directory_to_proper_path(self): def _move_temporary_directory_to_proper_path(self):
""" """
...@@ -241,9 +248,9 @@ class Pymultinest(NestedSampler): ...@@ -241,9 +248,9 @@ class Pymultinest(NestedSampler):
return self.result return self.result
def _check_and_load_sampling_time_file(self): def _check_and_load_sampling_time_file(self):
self.time_file_path = self.kwargs["outputfiles_basename"] + '/sampling_time.dat' self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat"
if os.path.exists(self.time_file_path): if os.path.exists(self.time_file_path):
with open(self.time_file_path, 'r') as time_file: with open(self.time_file_path, "r") as time_file:
self.total_sampling_time = float(time_file.readline()) self.total_sampling_time = float(time_file.readline())
else: else:
self.total_sampling_time = 0 self.total_sampling_time = 0
...@@ -253,7 +260,7 @@ class Pymultinest(NestedSampler): ...@@ -253,7 +260,7 @@ class Pymultinest(NestedSampler):
new_sampling_time = current_time - self.start_time new_sampling_time = current_time - self.start_time
self.total_sampling_time += new_sampling_time self.total_sampling_time += new_sampling_time
self.start_time = current_time self.start_time = current_time
with open(self.time_file_path, 'w') as time_file: with open(self.time_file_path, "w") as time_file:
time_file.write(str(self.total_sampling_time)) time_file.write(str(self.total_sampling_time))
def _clean_up_run_directory(self): def _clean_up_run_directory(self):
...@@ -271,16 +278,18 @@ class Pymultinest(NestedSampler): ...@@ -271,16 +278,18 @@ class Pymultinest(NestedSampler):
estimate of `remaining_prior_volume / N`. estimate of `remaining_prior_volume / N`.
""" """
import pandas as pd import pandas as pd
dir_ = self.kwargs["outputfiles_basename"] dir_ = self.kwargs["outputfiles_basename"]
dead_points = np.genfromtxt(dir_ + "/ev.dat") dead_points = np.genfromtxt(dir_ + "/ev.dat")
live_points = np.genfromtxt(dir_ + "/phys_live.points") live_points = np.genfromtxt(dir_ + "/phys_live.points")
nlive = self.kwargs["n_live_points"] nlive = self.kwargs["n_live_points"]
final_log_prior_volume = - len(dead_points) / nlive - np.log(nlive) final_log_prior_volume = -len(dead_points) / nlive - np.log(nlive)
live_points = np.insert(live_points, -1, final_log_prior_volume, axis=-1) live_points = np.insert(live_points, -1, final_log_prior_volume, axis=-1)
nested_samples = pd.DataFrame( nested_samples = pd.DataFrame(
np.vstack([dead_points, live_points]).copy(), np.vstack([dead_points, live_points]).copy(),
columns=self.search_parameter_keys + ["log_likelihood", "log_prior_volume", "mode"] columns=self.search_parameter_keys
+ ["log_likelihood", "log_prior_volume", "mode"],
) )
return nested_samples return nested_samples
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment