From 581775da065162bcbd63f62c4ef0e88667c88baf Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 17 Jun 2019 19:41:11 -0500
Subject: [PATCH] Resolve sampling_time persistence between restarts for
 dynesty

This stores the sampling time in the resume file for dynesty and loads
it allowing the sampling time for jobs which get restarted to reflect
the total sampling time. Also defines the general mechanism for any
other samper to act similarly.
---
 bilby/core/sampler/__init__.py | 11 +++++++----
 bilby/core/sampler/dynesty.py  | 11 ++++++++++-
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index 5ff4c7ffe..424f1e108 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -163,15 +163,18 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         return sampler.cached_result
 
     start_time = datetime.datetime.now()
-
     if command_line_args.bilby_test_mode:
         result = sampler._run_test()
     else:
         result = sampler.run_sampler()
-
     end_time = datetime.datetime.now()
-    result.sampling_time = (end_time - start_time).total_seconds()
-    logger.info('Sampling time: {}'.format(end_time - start_time))
+
+    # Some samplers calculate the sampling time internally
+    if result.sampling_time is None:
+        result.sampling_time = end_time - start_time
+    logger.info('Sampling time: {}'.format(result.sampling_time))
+    # Convert sampling time into seconds
+    result.sampling_time = result.sampling_time.total_seconds()
 
     if sampler.use_ratio:
         result.log_noise_evidence = likelihood.noise_log_likelihood()
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 9f0f8d80f..625264719 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -1,5 +1,6 @@
 from __future__ import absolute_import
 
+import datetime
 import os
 import sys
 import pickle
@@ -116,6 +117,7 @@ class Dynesty(NestedSampler):
         logger.info("Checkpoint every n_check_point = {}".format(self.n_check_point))
 
         self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
+        self.sampling_time = datetime.timedelta()
 
         signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
         signal.signal(signal.SIGINT, self.write_current_state_and_exit)
@@ -245,6 +247,7 @@ class Dynesty(NestedSampler):
         self.calc_likelihood_count()
         self.result.log_evidence = out.logz[-1]
         self.result.log_evidence_err = out.logzerr[-1]
+        self.result.sampling_time = self.sampling_time
 
         if self.plot:
             self.generate_trace_plots(out)
@@ -267,6 +270,7 @@ class Dynesty(NestedSampler):
         sampler_kwargs = self.sampler_function_kwargs.copy()
         sampler_kwargs['maxcall'] = self.n_check_point
         sampler_kwargs['add_live'] = False
+        self.start_time = datetime.datetime.now()
         while True:
             sampler_kwargs['maxcall'] += self.n_check_point
             self.sampler.run_nested(**sampler_kwargs)
@@ -276,7 +280,6 @@ class Dynesty(NestedSampler):
 
             self.write_current_state()
 
-        self.read_saved_state()
         sampler_kwargs['add_live'] = True
         self.sampler.run_nested(**sampler_kwargs)
         return self.sampler.results
@@ -340,6 +343,7 @@ class Dynesty(NestedSampler):
             self.sampler.live_bound = saved['live_bound']
             self.sampler.live_it = saved['live_it']
             self.sampler.added_live = saved['added_live']
+            self.sampling_time += datetime.timedelta(seconds=saved['sampling_time'])
             return True
 
         else:
@@ -371,6 +375,10 @@ class Dynesty(NestedSampler):
         check_directory_exists_and_if_not_mkdir(self.outdir)
         logger.info("Writing checkpoint file {}".format(self.resume_file))
 
+        end_time = datetime.datetime.now()
+        self.sampling_time += end_time - self.start_time
+        self.start_time = end_time
+
         current_state = dict(
             unit_cube_samples=self.sampler.saved_u,
             physical_samples=self.sampler.saved_v,
@@ -386,6 +394,7 @@ class Dynesty(NestedSampler):
             boundidx=self.sampler.saved_boundidx,
             bounditer=self.sampler.saved_bounditer,
             scale=self.sampler.saved_scale,
+            sampling_time=self.sampling_time.total_seconds()
         )
 
         current_state.update(
-- 
GitLab