From ad8a0d683f43ccffe0a3ce4b1de1aa8a5df28949 Mon Sep 17 00:00:00 2001
From: Matthew David Pitkin <matthew.pitkin@ligo.org>
Date: Thu, 11 Jun 2020 19:35:04 -0500
Subject: [PATCH] Apply similar changes to those in !804 to help file transfer
 on Condor for ultranest

---
 bilby/core/sampler/ultranest.py | 145 ++++++++++++++++++++++++--------
 test/sampler_test.py            |  10 ++-
 2 files changed, 114 insertions(+), 41 deletions(-)

diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py
index 530321df1..687ed3c08 100644
--- a/bilby/core/sampler/ultranest.py
+++ b/bilby/core/sampler/ultranest.py
@@ -1,9 +1,13 @@
 from __future__ import absolute_import
 
+import datetime
+import distutils.dir_util
+import inspect
 import os
 import shutil
 import signal
 import tempfile
+import time
 
 import numpy as np
 from pandas import DataFrame
@@ -59,7 +63,7 @@ class Ultranest(NestedSampler):
         dlogz=None,
         max_iters=None,
         update_interval_iter_fraction=0.2,
-        viz_callback="auto",
+        viz_callback=None,
         dKL=0.5,
         frac_remain=0.01,
         Lepsilon=0.001,
@@ -81,6 +85,8 @@ class Ultranest(NestedSampler):
         plot=False,
         exit_code=77,
         skip_import_verification=False,
+        temporary_directory=True,
+        callback_interval=10,
         **kwargs,
     ):
         super(Ultranest, self).__init__(
@@ -95,6 +101,12 @@ class Ultranest(NestedSampler):
             **kwargs,
         )
         self._apply_ultranest_boundaries()
+        self.use_temporary_directory = temporary_directory
+
+        if self.use_temporary_directory:
+            # set callback interval, so copying of results does not thrash the
+            # disk (ultranest will call viz_callback quite a lot)
+            self.callback_interval = callback_interval
 
         signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
         signal.signal(signal.SIGINT, self.write_current_state_and_exit)
@@ -113,9 +125,18 @@ class Ultranest(NestedSampler):
         """ Check the kwargs """
 
         self.outputfiles_basename = self.kwargs.pop("log_dir", None)
+        if self.kwargs["viz_callback"] is None:
+            self.kwargs["viz_callback"] = self._viz_callback
 
         NestedSampler._verify_kwargs_against_default_kwargs(self)
 
+    def _viz_callback(self, *args, **kwargs):
+        if self.use_temporary_directory:
+            if not (self._viz_callback_counter % self.callback_interval):
+                self._copy_temporary_directory_contents_to_proper_path()
+                self._calculate_and_save_sampling_time()
+            self._viz_callback_counter += 1
+
     def _apply_ultranest_boundaries(self):
         if (
             self.kwargs["wrapped_params"] is None
@@ -136,9 +157,11 @@ class Ultranest(NestedSampler):
     @outputfiles_basename.setter
     def outputfiles_basename(self, outputfiles_basename):
         if outputfiles_basename is None:
-            outputfiles_basename = "{}/ultra_{}".format(self.outdir, self.label)
-        if outputfiles_basename.endswith("/") is True:
-            outputfiles_basename = outputfiles_basename.rstrip("/")
+            outputfiles_basename = os.path.join(
+                self.outdir, "ultra_{}/".format(self.label)
+            )
+        if not outputfiles_basename.endswith("/"):
+            outputfiles_basename += "/"
         check_directory_exists_and_if_not_mkdir(self.outdir)
         self._outputfiles_basename = outputfiles_basename
 
@@ -148,7 +171,7 @@ class Ultranest(NestedSampler):
 
     @temporary_outputfiles_basename.setter
     def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
-        if temporary_outputfiles_basename.endswith("/") is False:
+        if not temporary_outputfiles_basename.endswith("/"):
             temporary_outputfiles_basename = "{}/".format(
                 temporary_outputfiles_basename
             )
@@ -157,10 +180,6 @@ class Ultranest(NestedSampler):
             shutil.copytree(
                 self.outputfiles_basename, self.temporary_outputfiles_basename
             )
-            if os.path.islink(self.outputfiles_basename):
-                os.unlink(self.outputfiles_basename)
-            else:
-                shutil.rmtree(self.outputfiles_basename)
 
     def write_current_state_and_exit(self, signum=None, frame=None):
         """ Write current state and exit on exit_code """
@@ -169,24 +188,38 @@ class Ultranest(NestedSampler):
                 signum, self.exit_code
             )
         )
-        # self.copy_temporary_directory_to_proper_path()
+        self._calculate_and_save_sampling_time()
+        if self.use_temporary_directory:
+            self._move_temporary_directory_to_proper_path()
         os._exit(self.exit_code)
 
-    def copy_temporary_directory_to_proper_path(self):
-        logger.info(
-            "Overwriting {} with {}".format(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
+    def _copy_temporary_directory_contents_to_proper_path(self):
+        """
+        Copy the contents of the temporary directory back to the proper path.
+        Do not delete the temporary directory.
+        """
+        if inspect.stack()[1].function != "_viz_callback":
+            logger.info(
+                "Overwriting {} with {}".format(
+                    self.outputfiles_basename, self.temporary_outputfiles_basename
+                )
             )
+        if self.outputfiles_basename.endswith("/"):
+            outputfiles_basename_stripped = self.outputfiles_basename[:-1]
+        else:
+            outputfiles_basename_stripped = self.outputfiles_basename
+        distutils.dir_util.copy_tree(
+            self.temporary_outputfiles_basename, outputfiles_basename_stripped
         )
 
-        # First remove anything in the outputfiles_basename for overwriting
-        if os.path.exists(self.outputfiles_basename):
-            if os.path.islink(self.outputfiles_basename):
-                os.unlink(self.outputfiles_basename)
-            else:
-                shutil.rmtree(self.outputfiles_basename, ignore_errors=True)
+    def _move_temporary_directory_to_proper_path(self):
+        """
+        Move the temporary directory back to the proper path
 
-        shutil.copytree(self.temporary_outputfiles_basename, self.outputfiles_basename)
+        Same-named files in the proper path are overwritten; the temporary directory is then removed
+        """
+        self._copy_temporary_directory_contents_to_proper_path()
+        shutil.rmtree(self.temporary_outputfiles_basename)
 
     @property
     def sampler_function_kwargs(self):
@@ -253,19 +286,8 @@ class Ultranest(NestedSampler):
 
         stepsampler = self.kwargs.pop("step_sampler", None)
 
-        temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
-        self.temporary_outputfiles_basename = temporary_outputfiles_basename
-        logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
-
-        check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
-        self.kwargs["log_dir"] = self.temporary_outputfiles_basename
-
-        # Symlink the temporary directory with the target directory: ensures data is stored on exit
-        os.symlink(
-            os.path.abspath(self.temporary_outputfiles_basename),
-            os.path.abspath(self.outputfiles_basename),
-            target_is_directory=True,
-        )
+        self._setup_run_directory()
+        self._check_and_load_sampling_time_file()
 
         # use reactive nested sampler when no live points are given
         if self.kwargs.get("num_live_points", None) is not None:
@@ -289,18 +311,66 @@ class Ultranest(NestedSampler):
                     "The default step sampling will be used instead."
                 )
 
-        results = sampler.run(**self.sampler_function_kwargs)
+        if self.use_temporary_directory:
+            self._viz_callback_counter = 1
 
-        self.copy_temporary_directory_to_proper_path()
+        self.start_time = time.time()
+        results = sampler.run(**self.sampler_function_kwargs)
+        self._calculate_and_save_sampling_time()
 
         # Clean up
-        shutil.rmtree(temporary_outputfiles_basename)
+        self._clean_up_run_directory()
 
         self._generate_result(results)
         self.calc_likelihood_count()
 
         return self.result
 
+    def _setup_run_directory(self):
+        """
+        If using a temporary directory, the contents of any existing output
+        directory are copied into it (no symlink is made) and it becomes log_dir.
+        """
+        if self.use_temporary_directory:
+            temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
+            self.temporary_outputfiles_basename = temporary_outputfiles_basename
+
+            if os.path.exists(self.outputfiles_basename):
+                distutils.dir_util.copy_tree(
+                    self.outputfiles_basename, self.temporary_outputfiles_basename
+                )
+            check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
+
+            self.kwargs["log_dir"] = self.temporary_outputfiles_basename
+            logger.info(
+                "Using temporary file {}".format(temporary_outputfiles_basename)
+            )
+        else:
+            check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
+            self.kwargs["log_dir"] = self.outputfiles_basename
+            logger.info("Using output file {}".format(self.outputfiles_basename))
+
+    def _clean_up_run_directory(self):
+        if self.use_temporary_directory:
+            self._move_temporary_directory_to_proper_path()
+            self.kwargs["log_dir"] = self.outputfiles_basename
+
+    def _check_and_load_sampling_time_file(self):
+        self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat")
+        if os.path.exists(self.time_file_path):
+            with open(self.time_file_path, "r") as time_file:
+                self.total_sampling_time = float(time_file.readline())
+        else:
+            self.total_sampling_time = 0
+
+    def _calculate_and_save_sampling_time(self):
+        current_time = time.time()
+        new_sampling_time = current_time - self.start_time
+        self.total_sampling_time += new_sampling_time
+        with open(self.time_file_path, "w") as time_file:
+            time_file.write(str(self.total_sampling_time))
+        self.start_time = current_time
+
     def _generate_result(self, out):
         # extract results (samples stored in "v" will change to "points",
         # weights stored in "w" will change to "weights")
@@ -325,3 +395,4 @@ class Ultranest(NestedSampler):
         self.result.log_evidence_err = out["logzerr"]
 
         self.result.outputfiles_basename = self.outputfiles_basename
+        self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)
diff --git a/test/sampler_test.py b/test/sampler_test.py
index d9879543c..c3c5e4d41 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -836,7 +836,7 @@ class TestUltranest(unittest.TestCase):
             dlogz=None,
             max_iters=None,
             update_interval_iter_fraction=0.2,
-            viz_callback="auto",
+            viz_callback=None,
             dKL=0.5,
             frac_remain=0.01,
             Lepsilon=0.001,
@@ -850,6 +850,7 @@ class TestUltranest(unittest.TestCase):
         self.assertListEqual([1, 0], self.sampler.kwargs["wrapped_params"])  # Check this separately
         self.sampler.kwargs["wrapped_params"] = None  # The dict comparison can't handle lists
         self.sampler.kwargs["derived_param_names"] = None
+        self.sampler.kwargs["viz_callback"] = None
         self.assertDictEqual(expected, self.sampler.kwargs)
 
     def test_translate_kwargs(self):
@@ -870,7 +871,7 @@ class TestUltranest(unittest.TestCase):
             dlogz=None,
             max_iters=None,
             update_interval_iter_fraction=0.2,
-            viz_callback="auto",
+            viz_callback=None,
             dKL=0.5,
             frac_remain=0.01,
             Lepsilon=0.001,
@@ -884,10 +885,11 @@ class TestUltranest(unittest.TestCase):
         for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
             new_kwargs = self.sampler.kwargs.copy()
             del new_kwargs['num_live_points']
-            new_kwargs['wrapped_params'] = None  # The dict comparison can't handle lists
-            new_kwargs["derived_param_names"] = None
             new_kwargs[equiv] = 123
             self.sampler.kwargs = new_kwargs
+            self.sampler.kwargs["wrapped_params"] = None
+            self.sampler.kwargs["derived_param_names"] = None
+            self.sampler.kwargs["viz_callback"] = None
             self.assertDictEqual(expected, self.sampler.kwargs)
 
 
-- 
GitLab