Skip to content
Snippets Groups Projects
Commit dbd2831e authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'ultranest-file-transfer' into 'master'

Update ultranest temporary file behaviour to match pymultinest

See merge request !798
parents b9c94fa7 ad8a0d68
No related branches found
No related tags found
No related merge requests found
from __future__ import absolute_import from __future__ import absolute_import
import datetime
import distutils.dir_util
import inspect
import os import os
import shutil import shutil
import signal import signal
import tempfile import tempfile
import time
import numpy as np import numpy as np
from pandas import DataFrame from pandas import DataFrame
...@@ -59,7 +63,7 @@ class Ultranest(NestedSampler): ...@@ -59,7 +63,7 @@ class Ultranest(NestedSampler):
dlogz=None, dlogz=None,
max_iters=None, max_iters=None,
update_interval_iter_fraction=0.2, update_interval_iter_fraction=0.2,
viz_callback="auto", viz_callback=None,
dKL=0.5, dKL=0.5,
frac_remain=0.01, frac_remain=0.01,
Lepsilon=0.001, Lepsilon=0.001,
...@@ -81,6 +85,8 @@ class Ultranest(NestedSampler): ...@@ -81,6 +85,8 @@ class Ultranest(NestedSampler):
plot=False, plot=False,
exit_code=77, exit_code=77,
skip_import_verification=False, skip_import_verification=False,
temporary_directory=True,
callback_interval=10,
**kwargs, **kwargs,
): ):
super(Ultranest, self).__init__( super(Ultranest, self).__init__(
...@@ -95,6 +101,12 @@ class Ultranest(NestedSampler): ...@@ -95,6 +101,12 @@ class Ultranest(NestedSampler):
**kwargs, **kwargs,
) )
self._apply_ultranest_boundaries() self._apply_ultranest_boundaries()
self.use_temporary_directory = temporary_directory
if self.use_temporary_directory:
# set callback interval, so copying of results does not thrash the
# disk (ultranest will call viz_callback quite a lot)
self.callback_interval = callback_interval
signal.signal(signal.SIGTERM, self.write_current_state_and_exit) signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit) signal.signal(signal.SIGINT, self.write_current_state_and_exit)
...@@ -113,9 +125,18 @@ class Ultranest(NestedSampler): ...@@ -113,9 +125,18 @@ class Ultranest(NestedSampler):
""" Check the kwargs """ """ Check the kwargs """
self.outputfiles_basename = self.kwargs.pop("log_dir", None) self.outputfiles_basename = self.kwargs.pop("log_dir", None)
if self.kwargs["viz_callback"] is None:
self.kwargs["viz_callback"] = self._viz_callback
NestedSampler._verify_kwargs_against_default_kwargs(self) NestedSampler._verify_kwargs_against_default_kwargs(self)
def _viz_callback(self, *args, **kwargs):
if self.use_temporary_directory:
if not (self._viz_callback_counter % self.callback_interval):
self._copy_temporary_directory_contents_to_proper_path()
self._calculate_and_save_sampling_time()
self._viz_callback_counter += 1
def _apply_ultranest_boundaries(self): def _apply_ultranest_boundaries(self):
if ( if (
self.kwargs["wrapped_params"] is None self.kwargs["wrapped_params"] is None
...@@ -136,9 +157,11 @@ class Ultranest(NestedSampler): ...@@ -136,9 +157,11 @@ class Ultranest(NestedSampler):
@outputfiles_basename.setter @outputfiles_basename.setter
def outputfiles_basename(self, outputfiles_basename): def outputfiles_basename(self, outputfiles_basename):
if outputfiles_basename is None: if outputfiles_basename is None:
outputfiles_basename = "{}/ultra_{}".format(self.outdir, self.label) outputfiles_basename = os.path.join(
if outputfiles_basename.endswith("/") is True: self.outdir, "ultra_{}/".format(self.label)
outputfiles_basename = outputfiles_basename.rstrip("/") )
if not outputfiles_basename.endswith("/"):
outputfiles_basename += "/"
check_directory_exists_and_if_not_mkdir(self.outdir) check_directory_exists_and_if_not_mkdir(self.outdir)
self._outputfiles_basename = outputfiles_basename self._outputfiles_basename = outputfiles_basename
...@@ -148,7 +171,7 @@ class Ultranest(NestedSampler): ...@@ -148,7 +171,7 @@ class Ultranest(NestedSampler):
@temporary_outputfiles_basename.setter @temporary_outputfiles_basename.setter
def temporary_outputfiles_basename(self, temporary_outputfiles_basename): def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
if temporary_outputfiles_basename.endswith("/") is False: if not temporary_outputfiles_basename.endswith("/"):
temporary_outputfiles_basename = "{}/".format( temporary_outputfiles_basename = "{}/".format(
temporary_outputfiles_basename temporary_outputfiles_basename
) )
...@@ -157,10 +180,6 @@ class Ultranest(NestedSampler): ...@@ -157,10 +180,6 @@ class Ultranest(NestedSampler):
shutil.copytree( shutil.copytree(
self.outputfiles_basename, self.temporary_outputfiles_basename self.outputfiles_basename, self.temporary_outputfiles_basename
) )
if os.path.islink(self.outputfiles_basename):
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename)
def write_current_state_and_exit(self, signum=None, frame=None): def write_current_state_and_exit(self, signum=None, frame=None):
""" Write current state and exit on exit_code """ """ Write current state and exit on exit_code """
...@@ -169,24 +188,38 @@ class Ultranest(NestedSampler): ...@@ -169,24 +188,38 @@ class Ultranest(NestedSampler):
signum, self.exit_code signum, self.exit_code
) )
) )
# self.copy_temporary_directory_to_proper_path() self._calculate_and_save_sampling_time()
if self.use_temporary_directory:
self._move_temporary_directory_to_proper_path()
os._exit(self.exit_code) os._exit(self.exit_code)
def copy_temporary_directory_to_proper_path(self): def _copy_temporary_directory_contents_to_proper_path(self):
logger.info( """
"Overwriting {} with {}".format( Copy the temporary back to the proper path.
self.outputfiles_basename, self.temporary_outputfiles_basename Do not delete the temporary directory.
"""
if inspect.stack()[1].function != "_viz_callback":
logger.info(
"Overwriting {} with {}".format(
self.outputfiles_basename, self.temporary_outputfiles_basename
)
) )
if self.outputfiles_basename.endswith("/"):
outputfiles_basename_stripped = self.outputfiles_basename[:-1]
else:
outputfiles_basename_stripped = self.outputfiles_basename
distutils.dir_util.copy_tree(
self.temporary_outputfiles_basename, outputfiles_basename_stripped
) )
# First remove anything in the outputfiles_basename for overwriting def _move_temporary_directory_to_proper_path(self):
if os.path.exists(self.outputfiles_basename): """
if os.path.islink(self.outputfiles_basename): Move the temporary back to the proper path
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename, ignore_errors=True)
shutil.copytree(self.temporary_outputfiles_basename, self.outputfiles_basename) Anything in the proper path at this point is removed including links
"""
self._copy_temporary_directory_contents_to_proper_path()
shutil.rmtree(self.temporary_outputfiles_basename)
@property @property
def sampler_function_kwargs(self): def sampler_function_kwargs(self):
...@@ -253,19 +286,8 @@ class Ultranest(NestedSampler): ...@@ -253,19 +286,8 @@ class Ultranest(NestedSampler):
stepsampler = self.kwargs.pop("step_sampler", None) stepsampler = self.kwargs.pop("step_sampler", None)
temporary_outputfiles_basename = tempfile.TemporaryDirectory().name self._setup_run_directory()
self.temporary_outputfiles_basename = temporary_outputfiles_basename self._check_and_load_sampling_time_file()
logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
self.kwargs["log_dir"] = self.temporary_outputfiles_basename
# Symlink the temporary directory with the target directory: ensures data is stored on exit
os.symlink(
os.path.abspath(self.temporary_outputfiles_basename),
os.path.abspath(self.outputfiles_basename),
target_is_directory=True,
)
# use reactive nested sampler when no live points are given # use reactive nested sampler when no live points are given
if self.kwargs.get("num_live_points", None) is not None: if self.kwargs.get("num_live_points", None) is not None:
...@@ -289,18 +311,66 @@ class Ultranest(NestedSampler): ...@@ -289,18 +311,66 @@ class Ultranest(NestedSampler):
"The default step sampling will be used instead." "The default step sampling will be used instead."
) )
results = sampler.run(**self.sampler_function_kwargs) if self.use_temporary_directory:
self._viz_callback_counter = 1
self.copy_temporary_directory_to_proper_path() self.start_time = time.time()
results = sampler.run(**self.sampler_function_kwargs)
self._calculate_and_save_sampling_time()
# Clean up # Clean up
shutil.rmtree(temporary_outputfiles_basename) self._clean_up_run_directory()
self._generate_result(results) self._generate_result(results)
self.calc_likelihood_count() self.calc_likelihood_count()
return self.result return self.result
def _setup_run_directory(self):
"""
If using a temporary directory, the output directory is moved to the
temporary directory and symlinked back.
"""
if self.use_temporary_directory:
temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
self.temporary_outputfiles_basename = temporary_outputfiles_basename
if os.path.exists(self.outputfiles_basename):
distutils.dir_util.copy_tree(
self.outputfiles_basename, self.temporary_outputfiles_basename
)
check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
self.kwargs["log_dir"] = self.temporary_outputfiles_basename
logger.info(
"Using temporary file {}".format(temporary_outputfiles_basename)
)
else:
check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
self.kwargs["log_dir"] = self.outputfiles_basename
logger.info("Using output file {}".format(self.outputfiles_basename))
def _clean_up_run_directory(self):
if self.use_temporary_directory:
self._move_temporary_directory_to_proper_path()
self.kwargs["log_dir"] = self.outputfiles_basename
def _check_and_load_sampling_time_file(self):
self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat")
if os.path.exists(self.time_file_path):
with open(self.time_file_path, "r") as time_file:
self.total_sampling_time = float(time_file.readline())
else:
self.total_sampling_time = 0
def _calculate_and_save_sampling_time(self):
current_time = time.time()
new_sampling_time = current_time - self.start_time
self.total_sampling_time += new_sampling_time
with open(self.time_file_path, "w") as time_file:
time_file.write(str(self.total_sampling_time))
self.start_time = current_time
def _generate_result(self, out): def _generate_result(self, out):
# extract results (samples stored in "v" will change to "points", # extract results (samples stored in "v" will change to "points",
# weights stored in "w" will change to "weights") # weights stored in "w" will change to "weights")
...@@ -325,3 +395,4 @@ class Ultranest(NestedSampler): ...@@ -325,3 +395,4 @@ class Ultranest(NestedSampler):
self.result.log_evidence_err = out["logzerr"] self.result.log_evidence_err = out["logzerr"]
self.result.outputfiles_basename = self.outputfiles_basename self.result.outputfiles_basename = self.outputfiles_basename
self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)
...@@ -836,7 +836,7 @@ class TestUltranest(unittest.TestCase): ...@@ -836,7 +836,7 @@ class TestUltranest(unittest.TestCase):
dlogz=None, dlogz=None,
max_iters=None, max_iters=None,
update_interval_iter_fraction=0.2, update_interval_iter_fraction=0.2,
viz_callback="auto", viz_callback=None,
dKL=0.5, dKL=0.5,
frac_remain=0.01, frac_remain=0.01,
Lepsilon=0.001, Lepsilon=0.001,
...@@ -850,6 +850,7 @@ class TestUltranest(unittest.TestCase): ...@@ -850,6 +850,7 @@ class TestUltranest(unittest.TestCase):
self.assertListEqual([1, 0], self.sampler.kwargs["wrapped_params"]) # Check this separately self.assertListEqual([1, 0], self.sampler.kwargs["wrapped_params"]) # Check this separately
self.sampler.kwargs["wrapped_params"] = None # The dict comparison can't handle lists self.sampler.kwargs["wrapped_params"] = None # The dict comparison can't handle lists
self.sampler.kwargs["derived_param_names"] = None self.sampler.kwargs["derived_param_names"] = None
self.sampler.kwargs["viz_callback"] = None
self.assertDictEqual(expected, self.sampler.kwargs) self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs(self): def test_translate_kwargs(self):
...@@ -870,7 +871,7 @@ class TestUltranest(unittest.TestCase): ...@@ -870,7 +871,7 @@ class TestUltranest(unittest.TestCase):
dlogz=None, dlogz=None,
max_iters=None, max_iters=None,
update_interval_iter_fraction=0.2, update_interval_iter_fraction=0.2,
viz_callback="auto", viz_callback=None,
dKL=0.5, dKL=0.5,
frac_remain=0.01, frac_remain=0.01,
Lepsilon=0.001, Lepsilon=0.001,
...@@ -884,10 +885,11 @@ class TestUltranest(unittest.TestCase): ...@@ -884,10 +885,11 @@ class TestUltranest(unittest.TestCase):
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy() new_kwargs = self.sampler.kwargs.copy()
del new_kwargs['num_live_points'] del new_kwargs['num_live_points']
new_kwargs['wrapped_params'] = None # The dict comparison can't handle lists
new_kwargs["derived_param_names"] = None
new_kwargs[equiv] = 123 new_kwargs[equiv] = 123
self.sampler.kwargs = new_kwargs self.sampler.kwargs = new_kwargs
self.sampler.kwargs["wrapped_params"] = None
self.sampler.kwargs["derived_param_names"] = None
self.sampler.kwargs["viz_callback"] = None
self.assertDictEqual(expected, self.sampler.kwargs) self.assertDictEqual(expected, self.sampler.kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment