From ad8a0d683f43ccffe0a3ce4b1de1aa8a5df28949 Mon Sep 17 00:00:00 2001 From: Matthew David Pitkin <matthew.pitkin@ligo.org> Date: Thu, 11 Jun 2020 19:35:04 -0500 Subject: [PATCH] Apply similar changes to those in !804 to help file transfer on Condor for ultranest --- bilby/core/sampler/ultranest.py | 145 ++++++++++++++++++++++++-------- test/sampler_test.py | 10 ++- 2 files changed, 114 insertions(+), 41 deletions(-) diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py index 530321df1..687ed3c08 100644 --- a/bilby/core/sampler/ultranest.py +++ b/bilby/core/sampler/ultranest.py @@ -1,9 +1,13 @@ from __future__ import absolute_import +import datetime +import distutils.dir_util +import inspect import os import shutil import signal import tempfile +import time import numpy as np from pandas import DataFrame @@ -59,7 +63,7 @@ class Ultranest(NestedSampler): dlogz=None, max_iters=None, update_interval_iter_fraction=0.2, - viz_callback="auto", + viz_callback=None, dKL=0.5, frac_remain=0.01, Lepsilon=0.001, @@ -81,6 +85,8 @@ class Ultranest(NestedSampler): plot=False, exit_code=77, skip_import_verification=False, + temporary_directory=True, + callback_interval=10, **kwargs, ): super(Ultranest, self).__init__( @@ -95,6 +101,12 @@ class Ultranest(NestedSampler): **kwargs, ) self._apply_ultranest_boundaries() + self.use_temporary_directory = temporary_directory + + if self.use_temporary_directory: + # set callback interval, so copying of results does not thrash the + # disk (ultranest will call viz_callback quite a lot) + self.callback_interval = callback_interval signal.signal(signal.SIGTERM, self.write_current_state_and_exit) signal.signal(signal.SIGINT, self.write_current_state_and_exit) @@ -113,9 +125,18 @@ class Ultranest(NestedSampler): """ Check the kwargs """ self.outputfiles_basename = self.kwargs.pop("log_dir", None) + if self.kwargs["viz_callback"] is None: + self.kwargs["viz_callback"] = self._viz_callback NestedSampler._verify_kwargs_against_default_kwargs(self) + def _viz_callback(self, *args, **kwargs): + if self.use_temporary_directory: + if not (self._viz_callback_counter % self.callback_interval): + self._copy_temporary_directory_contents_to_proper_path() + self._calculate_and_save_sampling_time() + self._viz_callback_counter += 1 + def _apply_ultranest_boundaries(self): if ( self.kwargs["wrapped_params"] is None @@ -136,9 +157,11 @@ class Ultranest(NestedSampler): @outputfiles_basename.setter def outputfiles_basename(self, outputfiles_basename): if outputfiles_basename is None: - outputfiles_basename = "{}/ultra_{}".format(self.outdir, self.label) - if outputfiles_basename.endswith("/") is True: - outputfiles_basename = outputfiles_basename.rstrip("/") + outputfiles_basename = os.path.join( + self.outdir, "ultra_{}/".format(self.label) + ) + if not outputfiles_basename.endswith("/"): + outputfiles_basename += "/" check_directory_exists_and_if_not_mkdir(self.outdir) self._outputfiles_basename = outputfiles_basename @@ -148,7 +171,7 @@ class Ultranest(NestedSampler): @temporary_outputfiles_basename.setter def temporary_outputfiles_basename(self, temporary_outputfiles_basename): - if temporary_outputfiles_basename.endswith("/") is False: + if not temporary_outputfiles_basename.endswith("/"): temporary_outputfiles_basename = "{}/".format( temporary_outputfiles_basename ) @@ -157,10 +180,6 @@ class Ultranest(NestedSampler): shutil.copytree( self.outputfiles_basename, self.temporary_outputfiles_basename ) - if os.path.islink(self.outputfiles_basename): - os.unlink(self.outputfiles_basename) - else: - shutil.rmtree(self.outputfiles_basename) def write_current_state_and_exit(self, signum=None, frame=None): """ Write current state and exit on exit_code """ @@ -169,24 +188,38 @@ class Ultranest(NestedSampler): signum, self.exit_code ) ) - # self.copy_temporary_directory_to_proper_path() + self._calculate_and_save_sampling_time() + if self.use_temporary_directory: + self._move_temporary_directory_to_proper_path() os._exit(self.exit_code) - def copy_temporary_directory_to_proper_path(self): - logger.info( - "Overwriting {} with {}".format( - self.outputfiles_basename, self.temporary_outputfiles_basename + def _copy_temporary_directory_contents_to_proper_path(self): + """ + Copy the temporary back to the proper path. + Do not delete the temporary directory. + """ + if inspect.stack()[1].function != "_viz_callback": + logger.info( + "Overwriting {} with {}".format( + self.outputfiles_basename, self.temporary_outputfiles_basename + ) ) + if self.outputfiles_basename.endswith("/"): + outputfiles_basename_stripped = self.outputfiles_basename[:-1] + else: + outputfiles_basename_stripped = self.outputfiles_basename + distutils.dir_util.copy_tree( + self.temporary_outputfiles_basename, outputfiles_basename_stripped ) - # First remove anything in the outputfiles_basename for overwriting - if os.path.exists(self.outputfiles_basename): - if os.path.islink(self.outputfiles_basename): - os.unlink(self.outputfiles_basename) - else: - shutil.rmtree(self.outputfiles_basename, ignore_errors=True) + def _move_temporary_directory_to_proper_path(self): + """ + Move the temporary back to the proper path - shutil.copytree(self.temporary_outputfiles_basename, self.outputfiles_basename) + Anything in the proper path at this point is removed including links + """ + self._copy_temporary_directory_contents_to_proper_path() + shutil.rmtree(self.temporary_outputfiles_basename) @property def sampler_function_kwargs(self): @@ -253,19 +286,8 @@ class Ultranest(NestedSampler): stepsampler = self.kwargs.pop("step_sampler", None) - temporary_outputfiles_basename = tempfile.TemporaryDirectory().name - self.temporary_outputfiles_basename = temporary_outputfiles_basename - logger.info("Using temporary file {}".format(temporary_outputfiles_basename)) - - check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename) - self.kwargs["log_dir"] = self.temporary_outputfiles_basename - - # Symlink the temporary directory with the target directory: ensures data is stored on exit - os.symlink( - os.path.abspath(self.temporary_outputfiles_basename), - os.path.abspath(self.outputfiles_basename), - target_is_directory=True, - ) + self._setup_run_directory() + self._check_and_load_sampling_time_file() # use reactive nested sampler when no live points are given if self.kwargs.get("num_live_points", None) is not None: @@ -289,18 +311,66 @@ class Ultranest(NestedSampler): "The default step sampling will be used instead." ) - results = sampler.run(**self.sampler_function_kwargs) + if self.use_temporary_directory: + self._viz_callback_counter = 1 - self.copy_temporary_directory_to_proper_path() + self.start_time = time.time() + results = sampler.run(**self.sampler_function_kwargs) + self._calculate_and_save_sampling_time() # Clean up - shutil.rmtree(temporary_outputfiles_basename) + self._clean_up_run_directory() self._generate_result(results) self.calc_likelihood_count() return self.result + def _setup_run_directory(self): + """ + If using a temporary directory, the output directory is moved to the + temporary directory and symlinked back. + """ + if self.use_temporary_directory: + temporary_outputfiles_basename = tempfile.TemporaryDirectory().name + self.temporary_outputfiles_basename = temporary_outputfiles_basename + + if os.path.exists(self.outputfiles_basename): + distutils.dir_util.copy_tree( + self.outputfiles_basename, self.temporary_outputfiles_basename + ) + check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename) + + self.kwargs["log_dir"] = self.temporary_outputfiles_basename + logger.info( + "Using temporary file {}".format(temporary_outputfiles_basename) + ) + else: + check_directory_exists_and_if_not_mkdir(self.outputfiles_basename) + self.kwargs["log_dir"] = self.outputfiles_basename + logger.info("Using output file {}".format(self.outputfiles_basename)) + + def _clean_up_run_directory(self): + if self.use_temporary_directory: + self._move_temporary_directory_to_proper_path() + self.kwargs["log_dir"] = self.outputfiles_basename + + def _check_and_load_sampling_time_file(self): + self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat") + if os.path.exists(self.time_file_path): + with open(self.time_file_path, "r") as time_file: + self.total_sampling_time = float(time_file.readline()) + else: + self.total_sampling_time = 0 + + def _calculate_and_save_sampling_time(self): + current_time = time.time() + new_sampling_time = current_time - self.start_time + self.total_sampling_time += new_sampling_time + with open(self.time_file_path, "w") as time_file: + time_file.write(str(self.total_sampling_time)) + self.start_time = current_time + def _generate_result(self, out): # extract results (samples stored in "v" will change to "points", # weights stored in "w" will change to "weights") @@ -325,3 +395,4 @@ class Ultranest(NestedSampler): self.result.log_evidence_err = out["logzerr"] self.result.outputfiles_basename = self.outputfiles_basename + self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time) diff --git a/test/sampler_test.py b/test/sampler_test.py index d9879543c..c3c5e4d41 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -836,7 +836,7 @@ class TestUltranest(unittest.TestCase): dlogz=None, max_iters=None, update_interval_iter_fraction=0.2, - viz_callback="auto", + viz_callback=None, dKL=0.5, frac_remain=0.01, Lepsilon=0.001, @@ -850,6 +850,7 @@ class TestUltranest(unittest.TestCase): self.assertListEqual([1, 0], self.sampler.kwargs["wrapped_params"]) # Check this separately self.sampler.kwargs["wrapped_params"] = None # The dict comparison can't handle lists self.sampler.kwargs["derived_param_names"] = None + self.sampler.kwargs["viz_callback"] = None self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs(self): @@ -870,7 +871,7 @@ class TestUltranest(unittest.TestCase): dlogz=None, max_iters=None, update_interval_iter_fraction=0.2, - viz_callback="auto", + viz_callback=None, dKL=0.5, frac_remain=0.01, Lepsilon=0.001, @@ -884,10 +885,11 @@ class TestUltranest(unittest.TestCase): for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() del new_kwargs['num_live_points'] - new_kwargs['wrapped_params'] = None # The dict comparison can't handle lists - new_kwargs["derived_param_names"] = None new_kwargs[equiv] = 123 self.sampler.kwargs = new_kwargs + self.sampler.kwargs["wrapped_params"] = None + self.sampler.kwargs["derived_param_names"] = None + self.sampler.kwargs["viz_callback"] = None self.assertDictEqual(expected, self.sampler.kwargs) -- GitLab