Skip to content
Snippets Groups Projects
Commit 2b239f75 authored by Sylvia Biscoveanu's avatar Sylvia Biscoveanu Committed by Colm Talbot
Browse files

Fix the multinest temporary file transfer

parent 2b12d397
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,10 @@ import importlib ...@@ -2,7 +2,10 @@ import importlib
import os import os
import tempfile import tempfile
import shutil import shutil
import distutils.dir_util
import signal import signal
import time
import datetime
import numpy as np import numpy as np
...@@ -115,8 +118,15 @@ class Pymultinest(NestedSampler): ...@@ -115,8 +118,15 @@ class Pymultinest(NestedSampler):
# for PyMultiNest >=2.9 the n_params kwarg cannot be None # for PyMultiNest >=2.9 the n_params kwarg cannot be None
if self.kwargs["n_params"] is None: if self.kwargs["n_params"] is None:
self.kwargs["n_params"] = self.ndim self.kwargs["n_params"] = self.ndim
if self.kwargs['dump_callback'] is None:
self.kwargs['dump_callback'] = self._dump_callback
NestedSampler._verify_kwargs_against_default_kwargs(self) NestedSampler._verify_kwargs_against_default_kwargs(self)
def _dump_callback(self, *args, **kwargs):
if self.use_temporary_directory:
self._copy_temporary_directory_contents_to_proper_path()
self._calculate_and_save_sampling_time()
def _apply_multinest_boundaries(self): def _apply_multinest_boundaries(self):
if self.kwargs["wrapped_params"] is None: if self.kwargs["wrapped_params"] is None:
self.kwargs["wrapped_params"] = [] self.kwargs["wrapped_params"] = []
...@@ -154,10 +164,6 @@ class Pymultinest(NestedSampler): ...@@ -154,10 +164,6 @@ class Pymultinest(NestedSampler):
shutil.copytree( shutil.copytree(
self.outputfiles_basename, self.temporary_outputfiles_basename self.outputfiles_basename, self.temporary_outputfiles_basename
) )
if os.path.islink(self.outputfiles_basename):
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename)
def write_current_state_and_exit(self, signum=None, frame=None): def write_current_state_and_exit(self, signum=None, frame=None):
""" Write current state and exit on exit_code """ """ Write current state and exit on exit_code """
...@@ -166,15 +172,15 @@ class Pymultinest(NestedSampler): ...@@ -166,15 +172,15 @@ class Pymultinest(NestedSampler):
signum, self.exit_code signum, self.exit_code
) )
) )
self._calculate_and_save_sampling_time()
if self.use_temporary_directory: if self.use_temporary_directory:
self._move_temporary_directory_to_proper_path() self._move_temporary_directory_to_proper_path()
os._exit(self.exit_code) os._exit(self.exit_code)
def _move_temporary_directory_to_proper_path(self): def _copy_temporary_directory_contents_to_proper_path(self):
""" """
Move the temporary back to the proper path Copy the temporary back to the proper path.
Do not delete the temporary directory.
Anything in the proper path at this point is removed including links
""" """
logger.info( logger.info(
"Overwriting {} with {}".format( "Overwriting {} with {}".format(
...@@ -185,11 +191,16 @@ class Pymultinest(NestedSampler): ...@@ -185,11 +191,16 @@ class Pymultinest(NestedSampler):
outputfiles_basename_stripped = self.outputfiles_basename[:-1] outputfiles_basename_stripped = self.outputfiles_basename[:-1]
else: else:
outputfiles_basename_stripped = self.outputfiles_basename outputfiles_basename_stripped = self.outputfiles_basename
if os.path.islink(outputfiles_basename_stripped): distutils.dir_util.copy_tree(self.temporary_outputfiles_basename, outputfiles_basename_stripped)
os.unlink(outputfiles_basename_stripped)
elif os.path.isdir(outputfiles_basename_stripped): def _move_temporary_directory_to_proper_path(self):
shutil.rmtree(outputfiles_basename_stripped) """
shutil.move(self.temporary_outputfiles_basename, outputfiles_basename_stripped) Copy the temporary back to the proper path
Anything in the temporary directory at this point is removed
"""
self._copy_temporary_directory_contents_to_proper_path()
shutil.rmtree(self.temporary_outputfiles_basename)
def run_sampler(self): def run_sampler(self):
import pymultinest import pymultinest
...@@ -197,17 +208,20 @@ class Pymultinest(NestedSampler): ...@@ -197,17 +208,20 @@ class Pymultinest(NestedSampler):
self._verify_kwargs_against_default_kwargs() self._verify_kwargs_against_default_kwargs()
self._setup_run_directory() self._setup_run_directory()
self._check_and_load_sampling_time_file()
# Overwrite pymultinest's signal handling function # Overwrite pymultinest's signal handling function
pm_run = importlib.import_module("pymultinest.run") pm_run = importlib.import_module("pymultinest.run")
pm_run.interrupt_handler = self.write_current_state_and_exit pm_run.interrupt_handler = self.write_current_state_and_exit
self.start_time = time.time()
out = pymultinest.solve( out = pymultinest.solve(
LogLikelihood=self.log_likelihood, LogLikelihood=self.log_likelihood,
Prior=self.prior_transform, Prior=self.prior_transform,
n_dims=self.ndim, n_dims=self.ndim,
**self.kwargs **self.kwargs
) )
self._calculate_and_save_sampling_time()
self._clean_up_run_directory() self._clean_up_run_directory()
...@@ -222,26 +236,22 @@ class Pymultinest(NestedSampler): ...@@ -222,26 +236,22 @@ class Pymultinest(NestedSampler):
self.result.log_evidence_err = out["logZerr"] self.result.log_evidence_err = out["logZerr"]
self.calc_likelihood_count() self.calc_likelihood_count()
self.result.outputfiles_basename = self.outputfiles_basename self.result.outputfiles_basename = self.outputfiles_basename
self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)
return self.result return self.result
def _setup_run_directory(self): def _setup_run_directory(self):
""" """
If using a temporary directory, the output directory is moved to the If using a temporary directory, the output directory is moved to the
temporary directory and symlinked back. temporary directory.
""" """
if self.use_temporary_directory: if self.use_temporary_directory:
temporary_outputfiles_basename = tempfile.TemporaryDirectory().name temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
self.temporary_outputfiles_basename = temporary_outputfiles_basename self.temporary_outputfiles_basename = temporary_outputfiles_basename
if os.path.exists(self.outputfiles_basename): if os.path.exists(self.outputfiles_basename):
shutil.move(self.outputfiles_basename, self.temporary_outputfiles_basename) distutils.dir_util.copy_tree(self.outputfiles_basename, self.temporary_outputfiles_basename)
check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename) check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
os.symlink(
os.path.abspath(self.temporary_outputfiles_basename),
os.path.abspath(self.outputfiles_basename),
target_is_directory=True,
)
self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename
logger.info("Using temporary file {}".format(temporary_outputfiles_basename)) logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
else: else:
...@@ -249,6 +259,21 @@ class Pymultinest(NestedSampler): ...@@ -249,6 +259,21 @@ class Pymultinest(NestedSampler):
self.kwargs["outputfiles_basename"] = self.outputfiles_basename self.kwargs["outputfiles_basename"] = self.outputfiles_basename
logger.info("Using output file {}".format(self.outputfiles_basename)) logger.info("Using output file {}".format(self.outputfiles_basename))
def _check_and_load_sampling_time_file(self):
self.time_file_path = self.kwargs["outputfiles_basename"] + '/sampling_time.dat'
if os.path.exists(self.time_file_path):
with open(self.time_file_path, 'r') as time_file:
self.total_sampling_time = float(time_file.readline())
else:
self.total_sampling_time = 0
def _calculate_and_save_sampling_time(self):
current_time = time.time()
new_sampling_time = current_time - self.start_time
self.total_sampling_time += new_sampling_time
with open(self.time_file_path, 'w') as time_file:
time_file.write(str(self.total_sampling_time))
def _clean_up_run_directory(self): def _clean_up_run_directory(self):
if self.use_temporary_directory: if self.use_temporary_directory:
self._move_temporary_directory_to_proper_path() self._move_temporary_directory_to_proper_path()
......
...@@ -767,7 +767,8 @@ class TestPymultinest(unittest.TestCase): ...@@ -767,7 +767,8 @@ class TestPymultinest(unittest.TestCase):
n_iter_before_update=100, null_log_evidence=-1e90, n_iter_before_update=100, null_log_evidence=-1e90,
max_modes=100, mode_tolerance=-1e90, seed=-1, max_modes=100, mode_tolerance=-1e90, seed=-1,
context=0, write_output=True, log_zero=-1e100, context=0, write_output=True, log_zero=-1e100,
max_iter=0, init_MPI=False, dump_callback=None) max_iter=0, init_MPI=False, dump_callback='dumper')
self.sampler.kwargs['dump_callback'] = 'dumper' # Check like the dynesty print_func
self.assertListEqual([1, 0], self.sampler.kwargs['wrapped_params']) # Check this separately self.assertListEqual([1, 0], self.sampler.kwargs['wrapped_params']) # Check this separately
self.sampler.kwargs['wrapped_params'] = None # The dict comparison can't handle lists self.sampler.kwargs['wrapped_params'] = None # The dict comparison can't handle lists
self.assertDictEqual(expected, self.sampler.kwargs) self.assertDictEqual(expected, self.sampler.kwargs)
...@@ -782,7 +783,7 @@ class TestPymultinest(unittest.TestCase): ...@@ -782,7 +783,7 @@ class TestPymultinest(unittest.TestCase):
n_iter_before_update=100, null_log_evidence=-1e90, n_iter_before_update=100, null_log_evidence=-1e90,
max_modes=100, mode_tolerance=-1e90, seed=-1, max_modes=100, mode_tolerance=-1e90, seed=-1,
context=0, write_output=True, log_zero=-1e100, context=0, write_output=True, log_zero=-1e100,
max_iter=0, init_MPI=False, dump_callback=None) max_iter=0, init_MPI=False, dump_callback='dumper')
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy() new_kwargs = self.sampler.kwargs.copy()
...@@ -790,6 +791,7 @@ class TestPymultinest(unittest.TestCase): ...@@ -790,6 +791,7 @@ class TestPymultinest(unittest.TestCase):
new_kwargs[ new_kwargs[
"wrapped_params" "wrapped_params"
] = None # The dict comparison can't handle lists ] = None # The dict comparison can't handle lists
new_kwargs['dump_callback'] = 'dumper' # Check this like Dynesty print_func
new_kwargs[equiv] = 123 new_kwargs[equiv] = 123
self.sampler.kwargs = new_kwargs self.sampler.kwargs = new_kwargs
self.assertDictEqual(expected, self.sampler.kwargs) self.assertDictEqual(expected, self.sampler.kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment