From 581775da065162bcbd63f62c4ef0e88667c88baf Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Mon, 17 Jun 2019 19:41:11 -0500 Subject: [PATCH] Resolve sampling_time persistence between restarts for dynesty This stores the sampling time in the resume file for dynesty and loads it, allowing the sampling time for jobs which get restarted to reflect the total sampling time. Also defines the general mechanism for any other sampler to act similarly. --- bilby/core/sampler/__init__.py | 11 +++++++---- bilby/core/sampler/dynesty.py | 11 ++++++++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 5ff4c7ffe..424f1e108 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -163,15 +163,18 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', return sampler.cached_result start_time = datetime.datetime.now() - if command_line_args.bilby_test_mode: result = sampler._run_test() else: result = sampler.run_sampler() - end_time = datetime.datetime.now() - result.sampling_time = (end_time - start_time).total_seconds() - logger.info('Sampling time: {}'.format(end_time - start_time)) + + # Some samplers calculate the sampling time internally + if result.sampling_time is None: + result.sampling_time = end_time - start_time + logger.info('Sampling time: {}'.format(result.sampling_time)) + # Convert sampling time into seconds + result.sampling_time = result.sampling_time.total_seconds() if sampler.use_ratio: result.log_noise_evidence = likelihood.noise_log_likelihood() diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 9f0f8d80f..625264719 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -1,5 +1,6 @@ from __future__ import absolute_import +import datetime import os import sys import pickle @@ -116,6 +117,7 @@ class Dynesty(NestedSampler): logger.info("Checkpoint 
every n_check_point = {}".format(self.n_check_point)) self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label) + self.sampling_time = datetime.timedelta() signal.signal(signal.SIGTERM, self.write_current_state_and_exit) signal.signal(signal.SIGINT, self.write_current_state_and_exit) @@ -245,6 +247,7 @@ class Dynesty(NestedSampler): self.calc_likelihood_count() self.result.log_evidence = out.logz[-1] self.result.log_evidence_err = out.logzerr[-1] + self.result.sampling_time = self.sampling_time if self.plot: self.generate_trace_plots(out) @@ -267,6 +270,7 @@ class Dynesty(NestedSampler): sampler_kwargs = self.sampler_function_kwargs.copy() sampler_kwargs['maxcall'] = self.n_check_point sampler_kwargs['add_live'] = False + self.start_time = datetime.datetime.now() while True: sampler_kwargs['maxcall'] += self.n_check_point self.sampler.run_nested(**sampler_kwargs) @@ -276,7 +280,6 @@ class Dynesty(NestedSampler): self.write_current_state() - self.read_saved_state() sampler_kwargs['add_live'] = True self.sampler.run_nested(**sampler_kwargs) return self.sampler.results @@ -340,6 +343,7 @@ class Dynesty(NestedSampler): self.sampler.live_bound = saved['live_bound'] self.sampler.live_it = saved['live_it'] self.sampler.added_live = saved['added_live'] + self.sampling_time += datetime.timedelta(seconds=saved['sampling_time']) return True else: @@ -371,6 +375,10 @@ class Dynesty(NestedSampler): check_directory_exists_and_if_not_mkdir(self.outdir) logger.info("Writing checkpoint file {}".format(self.resume_file)) + end_time = datetime.datetime.now() + self.sampling_time += end_time - self.start_time + self.start_time = end_time + current_state = dict( unit_cube_samples=self.sampler.saved_u, physical_samples=self.sampler.saved_v, @@ -386,6 +394,7 @@ class Dynesty(NestedSampler): boundidx=self.sampler.saved_boundidx, bounditer=self.sampler.saved_bounditer, scale=self.sampler.saved_scale, + sampling_time=self.sampling_time.total_seconds() ) current_state.update( 
-- GitLab