Skip to content
Snippets Groups Projects
Commit ad8a0d68 authored by Matthew David Pitkin's avatar Matthew David Pitkin Committed by Gregory Ashton
Browse files

Apply similar changes to those in !804 to help file transfer on Condor for ultranest
parent ed8bef95
No related branches found
No related tags found
No related merge requests found
from __future__ import absolute_import
import datetime
import distutils.dir_util
import inspect
import os
import shutil
import signal
import tempfile
import time
import numpy as np
from pandas import DataFrame
......@@ -59,7 +63,7 @@ class Ultranest(NestedSampler):
dlogz=None,
max_iters=None,
update_interval_iter_fraction=0.2,
viz_callback="auto",
viz_callback=None,
dKL=0.5,
frac_remain=0.01,
Lepsilon=0.001,
......@@ -81,6 +85,8 @@ class Ultranest(NestedSampler):
plot=False,
exit_code=77,
skip_import_verification=False,
temporary_directory=True,
callback_interval=10,
**kwargs,
):
super(Ultranest, self).__init__(
......@@ -95,6 +101,12 @@ class Ultranest(NestedSampler):
**kwargs,
)
self._apply_ultranest_boundaries()
self.use_temporary_directory = temporary_directory
if self.use_temporary_directory:
# set callback interval, so copying of results does not thrash the
# disk (ultranest will call viz_callback quite a lot)
self.callback_interval = callback_interval
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
......@@ -113,9 +125,18 @@ class Ultranest(NestedSampler):
""" Check the kwargs """
self.outputfiles_basename = self.kwargs.pop("log_dir", None)
if self.kwargs["viz_callback"] is None:
self.kwargs["viz_callback"] = self._viz_callback
NestedSampler._verify_kwargs_against_default_kwargs(self)
def _viz_callback(self, *args, **kwargs):
    """
    Periodic hook installed as the sampler's ``viz_callback``.

    On every ``callback_interval``-th invocation the accumulated sampling
    time is written to disk and the temporary run directory is copied to
    the proper output path. Nothing happens unless a temporary directory
    is in use.

    Parameters
    ----------
    *args, **kwargs:
        Accepted so the hook can be called with any signature; ignored.
    """
    if not self.use_temporary_directory:
        return
    if self._viz_callback_counter % self.callback_interval == 0:
        # Save the sampling time *before* copying: the time file lives in
        # the run directory being copied, so this order keeps the snapshot
        # in the proper path up to date (same order as on exit/clean-up).
        self._calculate_and_save_sampling_time()
        self._copy_temporary_directory_contents_to_proper_path()
    self._viz_callback_counter += 1
def _apply_ultranest_boundaries(self):
if (
self.kwargs["wrapped_params"] is None
......@@ -136,9 +157,11 @@ class Ultranest(NestedSampler):
@outputfiles_basename.setter
def outputfiles_basename(self, outputfiles_basename):
if outputfiles_basename is None:
outputfiles_basename = "{}/ultra_{}".format(self.outdir, self.label)
if outputfiles_basename.endswith("/") is True:
outputfiles_basename = outputfiles_basename.rstrip("/")
outputfiles_basename = os.path.join(
self.outdir, "ultra_{}/".format(self.label)
)
if not outputfiles_basename.endswith("/"):
outputfiles_basename += "/"
check_directory_exists_and_if_not_mkdir(self.outdir)
self._outputfiles_basename = outputfiles_basename
......@@ -148,7 +171,7 @@ class Ultranest(NestedSampler):
@temporary_outputfiles_basename.setter
def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
if temporary_outputfiles_basename.endswith("/") is False:
if not temporary_outputfiles_basename.endswith("/"):
temporary_outputfiles_basename = "{}/".format(
temporary_outputfiles_basename
)
......@@ -157,10 +180,6 @@ class Ultranest(NestedSampler):
shutil.copytree(
self.outputfiles_basename, self.temporary_outputfiles_basename
)
if os.path.islink(self.outputfiles_basename):
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename)
def write_current_state_and_exit(self, signum=None, frame=None):
""" Write current state and exit on exit_code """
......@@ -169,24 +188,38 @@ class Ultranest(NestedSampler):
signum, self.exit_code
)
)
# self.copy_temporary_directory_to_proper_path()
self._calculate_and_save_sampling_time()
if self.use_temporary_directory:
self._move_temporary_directory_to_proper_path()
os._exit(self.exit_code)
def copy_temporary_directory_to_proper_path(self):
logger.info(
"Overwriting {} with {}".format(
self.outputfiles_basename, self.temporary_outputfiles_basename
def _copy_temporary_directory_contents_to_proper_path(self):
    """
    Copy the contents of the temporary directory into the proper path.

    The temporary directory itself is left in place. An info message is
    logged unless the immediate caller is ``_viz_callback``.
    """
    caller_name = inspect.stack()[1].function
    if caller_name != "_viz_callback":
        logger.info(
            "Overwriting {} with {}".format(
                self.outputfiles_basename, self.temporary_outputfiles_basename
            )
        )

    # The destination is handed to copy_tree without its trailing slash.
    outputfiles_basename_stripped = self.outputfiles_basename
    if outputfiles_basename_stripped.endswith("/"):
        outputfiles_basename_stripped = outputfiles_basename_stripped[:-1]

    distutils.dir_util.copy_tree(
        self.temporary_outputfiles_basename, outputfiles_basename_stripped
    )
# First remove anything in the outputfiles_basename for overwriting
if os.path.exists(self.outputfiles_basename):
if os.path.islink(self.outputfiles_basename):
os.unlink(self.outputfiles_basename)
else:
shutil.rmtree(self.outputfiles_basename, ignore_errors=True)
def _move_temporary_directory_to_proper_path(self):
    """
    Move the temporary directory back to the proper path.

    The contents of the temporary directory are first copied to the
    proper path, then the temporary directory is deleted.
    """
    self._copy_temporary_directory_contents_to_proper_path()
    shutil.rmtree(self.temporary_outputfiles_basename)
@property
def sampler_function_kwargs(self):
......@@ -253,19 +286,8 @@ class Ultranest(NestedSampler):
stepsampler = self.kwargs.pop("step_sampler", None)
temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
self.temporary_outputfiles_basename = temporary_outputfiles_basename
logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
self.kwargs["log_dir"] = self.temporary_outputfiles_basename
# Symlink the temporary directory with the target directory: ensures data is stored on exit
os.symlink(
os.path.abspath(self.temporary_outputfiles_basename),
os.path.abspath(self.outputfiles_basename),
target_is_directory=True,
)
self._setup_run_directory()
self._check_and_load_sampling_time_file()
# use reactive nested sampler when no live points are given
if self.kwargs.get("num_live_points", None) is not None:
......@@ -289,18 +311,66 @@ class Ultranest(NestedSampler):
"The default step sampling will be used instead."
)
results = sampler.run(**self.sampler_function_kwargs)
if self.use_temporary_directory:
self._viz_callback_counter = 1
self.copy_temporary_directory_to_proper_path()
self.start_time = time.time()
results = sampler.run(**self.sampler_function_kwargs)
self._calculate_and_save_sampling_time()
# Clean up
shutil.rmtree(temporary_outputfiles_basename)
self._clean_up_run_directory()
self._generate_result(results)
self.calc_likelihood_count()
return self.result
def _setup_run_directory(self):
    """
    Decide where the sampler writes and store it in ``kwargs["log_dir"]``.

    When a temporary directory is in use, a temporary location is chosen,
    any existing contents of ``outputfiles_basename`` are copied into it,
    and ``log_dir`` points at the temporary location. Otherwise ``log_dir``
    is ``outputfiles_basename`` itself.
    """
    if not self.use_temporary_directory:
        check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
        self.kwargs["log_dir"] = self.outputfiles_basename
        logger.info("Using output file {}".format(self.outputfiles_basename))
        return

    # NOTE(review): only the *name* of the TemporaryDirectory is kept; the
    # object itself is discarded straight away. Presumably the path is
    # expected not to exist yet when the property setter runs -- confirm
    # before replacing this with e.g. tempfile.mkdtemp().
    temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
    self.temporary_outputfiles_basename = temporary_outputfiles_basename

    if os.path.exists(self.outputfiles_basename):
        distutils.dir_util.copy_tree(
            self.outputfiles_basename, self.temporary_outputfiles_basename
        )
    # Needed for a fresh run, where nothing was copied above.
    check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)

    self.kwargs["log_dir"] = self.temporary_outputfiles_basename
    logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
def _clean_up_run_directory(self):
    """
    Finalise the run directory after sampling.

    When a temporary directory is in use, it is moved back to the proper
    path and ``kwargs["log_dir"]`` is pointed at ``outputfiles_basename``.
    Otherwise nothing is done.
    """
    if not self.use_temporary_directory:
        return
    self._move_temporary_directory_to_proper_path()
    self.kwargs["log_dir"] = self.outputfiles_basename
def _check_and_load_sampling_time_file(self):
    """
    Locate the sampling-time file and load the accumulated sampling time.

    Sets ``self.time_file_path`` to ``sampling_time.dat`` inside
    ``kwargs["log_dir"]`` and ``self.total_sampling_time`` to the value on
    the first line of that file. If the file is missing, empty or does not
    hold a number, the accumulated time starts from 0.
    """
    self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat")
    self.total_sampling_time = 0
    if not os.path.exists(self.time_file_path):
        return
    with open(self.time_file_path, "r") as time_file:
        first_line = time_file.readline()
    try:
        self.total_sampling_time = float(first_line)
    except ValueError:
        # The writer truncates the file before writing, so an interrupted
        # run can leave it empty or partial. Restart the timer at 0 rather
        # than make the run impossible to resume.
        self.total_sampling_time = 0
def _calculate_and_save_sampling_time(self):
    """
    Add the time elapsed since ``start_time`` to the running total and
    write the total to ``time_file_path``.

    ``start_time`` is then reset to the current time, so consecutive calls
    never count the same interval twice.
    """
    now = time.time()
    self.total_sampling_time += now - self.start_time
    with open(self.time_file_path, "w") as time_file:
        time_file.write(str(self.total_sampling_time))
    # Reset only after the write, as the next interval starts from here.
    self.start_time = now
def _generate_result(self, out):
# extract results (samples stored in "v" will change to "points",
# weights stored in "w" will change to "weights")
......@@ -325,3 +395,4 @@ class Ultranest(NestedSampler):
self.result.log_evidence_err = out["logzerr"]
self.result.outputfiles_basename = self.outputfiles_basename
self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)
......@@ -836,7 +836,7 @@ class TestUltranest(unittest.TestCase):
dlogz=None,
max_iters=None,
update_interval_iter_fraction=0.2,
viz_callback="auto",
viz_callback=None,
dKL=0.5,
frac_remain=0.01,
Lepsilon=0.001,
......@@ -850,6 +850,7 @@ class TestUltranest(unittest.TestCase):
self.assertListEqual([1, 0], self.sampler.kwargs["wrapped_params"]) # Check this separately
self.sampler.kwargs["wrapped_params"] = None # The dict comparison can't handle lists
self.sampler.kwargs["derived_param_names"] = None
self.sampler.kwargs["viz_callback"] = None
self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs(self):
......@@ -870,7 +871,7 @@ class TestUltranest(unittest.TestCase):
dlogz=None,
max_iters=None,
update_interval_iter_fraction=0.2,
viz_callback="auto",
viz_callback=None,
dKL=0.5,
frac_remain=0.01,
Lepsilon=0.001,
......@@ -884,10 +885,11 @@ class TestUltranest(unittest.TestCase):
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
del new_kwargs['num_live_points']
new_kwargs['wrapped_params'] = None # The dict comparison can't handle lists
new_kwargs["derived_param_names"] = None
new_kwargs[equiv] = 123
self.sampler.kwargs = new_kwargs
self.sampler.kwargs["wrapped_params"] = None
self.sampler.kwargs["derived_param_names"] = None
self.sampler.kwargs["viz_callback"] = None
self.assertDictEqual(expected, self.sampler.kwargs)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment