Skip to content
Snippets Groups Projects
Commit dbd2831e authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'ultranest-file-transfer' into 'master'

Update ultranest temporary file behaviour to match pymultinest

See merge request !798
parents b9c94fa7 ad8a0d68
No related branches found
No related tags found
1 merge request!798Update ultranest temporary file behaviour to match pymultinest
Pipeline #134134 failed
from __future__ import absolute_import
import datetime
import distutils.dir_util
import inspect
import os
import shutil
import signal
import tempfile
import time
import numpy as np
from pandas import DataFrame
......@@ -59,7 +63,7 @@ class Ultranest(NestedSampler):
dlogz=None,
max_iters=None,
update_interval_iter_fraction=0.2,
viz_callback="auto",
viz_callback=None,
dKL=0.5,
frac_remain=0.01,
Lepsilon=0.001,
......@@ -81,6 +85,8 @@ class Ultranest(NestedSampler):
plot=False,
exit_code=77,
skip_import_verification=False,
temporary_directory=True,
callback_interval=10,
**kwargs,
):
super(Ultranest, self).__init__(
......@@ -95,6 +101,12 @@ class Ultranest(NestedSampler):
**kwargs,
)
self._apply_ultranest_boundaries()
self.use_temporary_directory = temporary_directory
if self.use_temporary_directory:
# set callback interval, so copying of results does not thrash the
# disk (ultranest will call viz_callback quite a lot)
self.callback_interval = callback_interval
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
......@@ -113,9 +125,18 @@ class Ultranest(NestedSampler):
""" Check the kwargs """
self.outputfiles_basename = self.kwargs.pop("log_dir", None)
if self.kwargs["viz_callback"] is None:
self.kwargs["viz_callback"] = self._viz_callback
NestedSampler._verify_kwargs_against_default_kwargs(self)
def _viz_callback(self, *args, **kwargs):
if self.use_temporary_directory:
if not (self._viz_callback_counter % self.callback_interval):
self._copy_temporary_directory_contents_to_proper_path()
self._calculate_and_save_sampling_time()
self._viz_callback_counter += 1
def _apply_ultranest_boundaries(self):
if (
self.kwargs["wrapped_params"] is None
......@@ -136,9 +157,11 @@ class Ultranest(NestedSampler):
@outputfiles_basename.setter
def outputfiles_basename(self, outputfiles_basename):
if outputfiles_basename is None:
outputfiles_basename = "{}/ultra_{}".format(self.outdir, self.label)
if outputfiles_basename.endswith("/") is True:
outputfiles_basename = outputfiles_basename.rstrip("/")
outputfiles_basename = os.path.join(
self.outdir, "ultra_{}/".format(self.label)
)
if not outputfiles_basename.endswith("/"):
outputfiles_basename += "/"
check_directory_exists_and_if_not_mkdir(self.outdir)
self._outputfiles_basename = outputfiles_basename
......@@ -148,7 +171,7 @@ class Ultranest(NestedSampler):
@temporary_outputfiles_basename.setter
def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
if temporary_outputfiles_basename.endswith("/") is False:
if not temporary_outputfiles_basename.endswith("/"):
temporary_outputfiles_basename = "{}/".format(
temporary_outputfiles_basename
)
......@@ -157,10 +180,6 @@ class Ultranest(NestedSampler):
shutil.copytree(
self.outputfiles_basename, self.temporary_outputfiles_basename
)
if os.path.islink(self.outputfiles_basename):
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename)
def write_current_state_and_exit(self, signum=None, frame=None):
""" Write current state and exit on exit_code """
......@@ -169,24 +188,38 @@ class Ultranest(NestedSampler):
signum, self.exit_code
)
)
# self.copy_temporary_directory_to_proper_path()
self._calculate_and_save_sampling_time()
if self.use_temporary_directory:
self._move_temporary_directory_to_proper_path()
os._exit(self.exit_code)
def copy_temporary_directory_to_proper_path(self):
logger.info(
"Overwriting {} with {}".format(
self.outputfiles_basename, self.temporary_outputfiles_basename
def _copy_temporary_directory_contents_to_proper_path(self):
"""
Copy the temporary back to the proper path.
Do not delete the temporary directory.
"""
if inspect.stack()[1].function != "_viz_callback":
logger.info(
"Overwriting {} with {}".format(
self.outputfiles_basename, self.temporary_outputfiles_basename
)
)
if self.outputfiles_basename.endswith("/"):
outputfiles_basename_stripped = self.outputfiles_basename[:-1]
else:
outputfiles_basename_stripped = self.outputfiles_basename
distutils.dir_util.copy_tree(
self.temporary_outputfiles_basename, outputfiles_basename_stripped
)
# First remove anything in the outputfiles_basename for overwriting
if os.path.exists(self.outputfiles_basename):
if os.path.islink(self.outputfiles_basename):
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename, ignore_errors=True)
def _move_temporary_directory_to_proper_path(self):
"""
Move the temporary back to the proper path
shutil.copytree(self.temporary_outputfiles_basename, self.outputfiles_basename)
Anything in the proper path at this point is removed including links
"""
self._copy_temporary_directory_contents_to_proper_path()
shutil.rmtree(self.temporary_outputfiles_basename)
@property
def sampler_function_kwargs(self):
......@@ -253,19 +286,8 @@ class Ultranest(NestedSampler):
stepsampler = self.kwargs.pop("step_sampler", None)
temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
self.temporary_outputfiles_basename = temporary_outputfiles_basename
logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
self.kwargs["log_dir"] = self.temporary_outputfiles_basename
# Symlink the temporary directory with the target directory: ensures data is stored on exit
os.symlink(
os.path.abspath(self.temporary_outputfiles_basename),
os.path.abspath(self.outputfiles_basename),
target_is_directory=True,
)
self._setup_run_directory()
self._check_and_load_sampling_time_file()
# use reactive nested sampler when no live points are given
if self.kwargs.get("num_live_points", None) is not None:
......@@ -289,18 +311,66 @@ class Ultranest(NestedSampler):
"The default step sampling will be used instead."
)
results = sampler.run(**self.sampler_function_kwargs)
if self.use_temporary_directory:
self._viz_callback_counter = 1
self.copy_temporary_directory_to_proper_path()
self.start_time = time.time()
results = sampler.run(**self.sampler_function_kwargs)
self._calculate_and_save_sampling_time()
# Clean up
shutil.rmtree(temporary_outputfiles_basename)
self._clean_up_run_directory()
self._generate_result(results)
self.calc_likelihood_count()
return self.result
def _setup_run_directory(self):
"""
If using a temporary directory, the output directory is moved to the
temporary directory and symlinked back.
"""
if self.use_temporary_directory:
temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
self.temporary_outputfiles_basename = temporary_outputfiles_basename
if os.path.exists(self.outputfiles_basename):
distutils.dir_util.copy_tree(
self.outputfiles_basename, self.temporary_outputfiles_basename
)
check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
self.kwargs["log_dir"] = self.temporary_outputfiles_basename
logger.info(
"Using temporary file {}".format(temporary_outputfiles_basename)
)
else:
check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
self.kwargs["log_dir"] = self.outputfiles_basename
logger.info("Using output file {}".format(self.outputfiles_basename))
def _clean_up_run_directory(self):
if self.use_temporary_directory:
self._move_temporary_directory_to_proper_path()
self.kwargs["log_dir"] = self.outputfiles_basename
def _check_and_load_sampling_time_file(self):
self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat")
if os.path.exists(self.time_file_path):
with open(self.time_file_path, "r") as time_file:
self.total_sampling_time = float(time_file.readline())
else:
self.total_sampling_time = 0
def _calculate_and_save_sampling_time(self):
current_time = time.time()
new_sampling_time = current_time - self.start_time
self.total_sampling_time += new_sampling_time
with open(self.time_file_path, "w") as time_file:
time_file.write(str(self.total_sampling_time))
self.start_time = current_time
def _generate_result(self, out):
# extract results (samples stored in "v" will change to "points",
# weights stored in "w" will change to "weights")
......@@ -325,3 +395,4 @@ class Ultranest(NestedSampler):
self.result.log_evidence_err = out["logzerr"]
self.result.outputfiles_basename = self.outputfiles_basename
self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)
......@@ -836,7 +836,7 @@ class TestUltranest(unittest.TestCase):
dlogz=None,
max_iters=None,
update_interval_iter_fraction=0.2,
viz_callback="auto",
viz_callback=None,
dKL=0.5,
frac_remain=0.01,
Lepsilon=0.001,
......@@ -850,6 +850,7 @@ class TestUltranest(unittest.TestCase):
self.assertListEqual([1, 0], self.sampler.kwargs["wrapped_params"]) # Check this separately
self.sampler.kwargs["wrapped_params"] = None # The dict comparison can't handle lists
self.sampler.kwargs["derived_param_names"] = None
self.sampler.kwargs["viz_callback"] = None
self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs(self):
......@@ -870,7 +871,7 @@ class TestUltranest(unittest.TestCase):
dlogz=None,
max_iters=None,
update_interval_iter_fraction=0.2,
viz_callback="auto",
viz_callback=None,
dKL=0.5,
frac_remain=0.01,
Lepsilon=0.001,
......@@ -884,10 +885,11 @@ class TestUltranest(unittest.TestCase):
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs['num_live_points']
new_kwargs['wrapped_params'] = None # The dict comparison can't handle lists
new_kwargs["derived_param_names"] = None
new_kwargs[equiv] = 123
self.sampler.kwargs = new_kwargs
self.sampler.kwargs["wrapped_params"] = None
self.sampler.kwargs["derived_param_names"] = None
self.sampler.kwargs["viz_callback"] = None
self.assertDictEqual(expected, self.sampler.kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment