diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py index 055d4e14ad969f70da745fa13571db693405784a..9421c8207c09dae938d4b37cf6b61a46c2853a42 100644 --- a/bilby/core/sampler/pymultinest.py +++ b/bilby/core/sampler/pymultinest.py @@ -2,7 +2,10 @@ import importlib import os import tempfile import shutil +import distutils.dir_util import signal +import time +import datetime import numpy as np @@ -115,8 +118,15 @@ class Pymultinest(NestedSampler): # for PyMultiNest >=2.9 the n_params kwarg cannot be None if self.kwargs["n_params"] is None: self.kwargs["n_params"] = self.ndim + if self.kwargs['dump_callback'] is None: + self.kwargs['dump_callback'] = self._dump_callback NestedSampler._verify_kwargs_against_default_kwargs(self) + def _dump_callback(self, *args, **kwargs): + if self.use_temporary_directory: + self._copy_temporary_directory_contents_to_proper_path() + self._calculate_and_save_sampling_time() + def _apply_multinest_boundaries(self): if self.kwargs["wrapped_params"] is None: self.kwargs["wrapped_params"] = [] @@ -154,10 +164,6 @@ class Pymultinest(NestedSampler): shutil.copytree( self.outputfiles_basename, self.temporary_outputfiles_basename ) - if os.path.islink(self.outputfiles_basename): - os.unlink(self.outputfiles_basename) - else: - shutil.rmtree(self.outputfiles_basename) def write_current_state_and_exit(self, signum=None, frame=None): """ Write current state and exit on exit_code """ @@ -166,15 +172,15 @@ class Pymultinest(NestedSampler): signum, self.exit_code ) ) + self._calculate_and_save_sampling_time() if self.use_temporary_directory: self._move_temporary_directory_to_proper_path() os._exit(self.exit_code) - def _move_temporary_directory_to_proper_path(self): + def _copy_temporary_directory_contents_to_proper_path(self): """ - Move the temporary back to the proper path - - Anything in the proper path at this point is removed including links + Copy the temporary back to the proper path. + Do not delete the temporary directory. """ logger.info( "Overwriting {} with {}".format( @@ -185,11 +191,16 @@ class Pymultinest(NestedSampler): outputfiles_basename_stripped = self.outputfiles_basename[:-1] else: outputfiles_basename_stripped = self.outputfiles_basename - if os.path.islink(outputfiles_basename_stripped): - os.unlink(outputfiles_basename_stripped) - elif os.path.isdir(outputfiles_basename_stripped): - shutil.rmtree(outputfiles_basename_stripped) - shutil.move(self.temporary_outputfiles_basename, outputfiles_basename_stripped) + distutils.dir_util.copy_tree(self.temporary_outputfiles_basename, outputfiles_basename_stripped) + + def _move_temporary_directory_to_proper_path(self): + """ + Copy the temporary back to the proper path + + Anything in the temporary directory at this point is removed + """ + self._copy_temporary_directory_contents_to_proper_path() + shutil.rmtree(self.temporary_outputfiles_basename) def run_sampler(self): import pymultinest @@ -197,17 +208,20 @@ class Pymultinest(NestedSampler): self._verify_kwargs_against_default_kwargs() self._setup_run_directory() + self._check_and_load_sampling_time_file() # Overwrite pymultinest's signal handling function pm_run = importlib.import_module("pymultinest.run") pm_run.interrupt_handler = self.write_current_state_and_exit + self.start_time = time.time() out = pymultinest.solve( LogLikelihood=self.log_likelihood, Prior=self.prior_transform, n_dims=self.ndim, **self.kwargs ) + self._calculate_and_save_sampling_time() self._clean_up_run_directory() @@ -222,26 +236,22 @@ class Pymultinest(NestedSampler): self.result.log_evidence_err = out["logZerr"] self.calc_likelihood_count() self.result.outputfiles_basename = self.outputfiles_basename + self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time) return self.result def _setup_run_directory(self): """ If using a temporary directory, the output directory is moved to the - temporary directory and symlinked back. + temporary directory. """ if self.use_temporary_directory: temporary_outputfiles_basename = tempfile.TemporaryDirectory().name self.temporary_outputfiles_basename = temporary_outputfiles_basename if os.path.exists(self.outputfiles_basename): - shutil.move(self.outputfiles_basename, self.temporary_outputfiles_basename) + distutils.dir_util.copy_tree(self.outputfiles_basename, self.temporary_outputfiles_basename) check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename) - os.symlink( - os.path.abspath(self.temporary_outputfiles_basename), - os.path.abspath(self.outputfiles_basename), - target_is_directory=True, - ) self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename logger.info("Using temporary file {}".format(temporary_outputfiles_basename)) else: @@ -249,6 +259,21 @@ class Pymultinest(NestedSampler): self.kwargs["outputfiles_basename"] = self.outputfiles_basename logger.info("Using output file {}".format(self.outputfiles_basename)) + def _check_and_load_sampling_time_file(self): + self.time_file_path = self.kwargs["outputfiles_basename"] + '/sampling_time.dat' + if os.path.exists(self.time_file_path): + with open(self.time_file_path, 'r') as time_file: + self.total_sampling_time = float(time_file.readline()) + else: + self.total_sampling_time = 0 + + def _calculate_and_save_sampling_time(self): + current_time = time.time() + new_sampling_time = current_time - self.start_time + self.total_sampling_time += new_sampling_time + with open(self.time_file_path, 'w') as time_file: + time_file.write(str(self.total_sampling_time)) + def _clean_up_run_directory(self): if self.use_temporary_directory: self._move_temporary_directory_to_proper_path() diff --git a/test/sampler_test.py b/test/sampler_test.py index 0140b4f7b8a5c5d6d736adc1350e30ae25cf8f41..d9879543cfe8ab5c1413eb85d85b6ef806621f64 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -767,7 +767,8 @@ class TestPymultinest(unittest.TestCase): n_iter_before_update=100, null_log_evidence=-1e90, max_modes=100, mode_tolerance=-1e90, seed=-1, context=0, write_output=True, log_zero=-1e100, - max_iter=0, init_MPI=False, dump_callback=None) + max_iter=0, init_MPI=False, dump_callback='dumper') + self.sampler.kwargs['dump_callback'] = 'dumper' # Check like the dynesty print_func self.assertListEqual([1, 0], self.sampler.kwargs['wrapped_params']) # Check this separately self.sampler.kwargs['wrapped_params'] = None # The dict comparison can't handle lists self.assertDictEqual(expected, self.sampler.kwargs) @@ -782,7 +783,7 @@ class TestPymultinest(unittest.TestCase): n_iter_before_update=100, null_log_evidence=-1e90, max_modes=100, mode_tolerance=-1e90, seed=-1, context=0, write_output=True, log_zero=-1e100, - max_iter=0, init_MPI=False, dump_callback=None) + max_iter=0, init_MPI=False, dump_callback='dumper') for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() @@ -790,6 +791,7 @@ class TestPymultinest(unittest.TestCase): new_kwargs[ "wrapped_params" ] = None # The dict comparison can't handle lists + new_kwargs['dump_callback'] = 'dumper' # Check this like Dynesty print_func new_kwargs[equiv] = 123 self.sampler.kwargs = new_kwargs self.assertDictEqual(expected, self.sampler.kwargs)