Commit 581775da authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner

Resolve sampling_time persistence between restarts for dynesty

This stores the sampling time in the resume file for dynesty and loads
it allowing the sampling time for jobs which get restarted to reflect
the total sampling time. Also defines the general mechanism for any
other samper to act similarly.
parent bfd7510e
......@@ -163,15 +163,18 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
return sampler.cached_result
start_time = datetime.datetime.now()
if command_line_args.bilby_test_mode:
result = sampler._run_test()
else:
result = sampler.run_sampler()
end_time = datetime.datetime.now()
result.sampling_time = (end_time - start_time).total_seconds()
logger.info('Sampling time: {}'.format(end_time - start_time))
# Some samplers calculate the sampling time internally
if result.sampling_time is None:
result.sampling_time = end_time - start_time
logger.info('Sampling time: {}'.format(result.sampling_time))
# Convert sampling time into seconds
result.sampling_time = result.sampling_time.total_seconds()
if sampler.use_ratio:
result.log_noise_evidence = likelihood.noise_log_likelihood()
......
from __future__ import absolute_import
import datetime
import os
import sys
import pickle
......@@ -116,6 +117,7 @@ class Dynesty(NestedSampler):
logger.info("Checkpoint every n_check_point = {}".format(self.n_check_point))
self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
self.sampling_time = datetime.timedelta()
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
......@@ -245,6 +247,7 @@ class Dynesty(NestedSampler):
self.calc_likelihood_count()
self.result.log_evidence = out.logz[-1]
self.result.log_evidence_err = out.logzerr[-1]
self.result.sampling_time = self.sampling_time
if self.plot:
self.generate_trace_plots(out)
......@@ -267,6 +270,7 @@ class Dynesty(NestedSampler):
sampler_kwargs = self.sampler_function_kwargs.copy()
sampler_kwargs['maxcall'] = self.n_check_point
sampler_kwargs['add_live'] = False
self.start_time = datetime.datetime.now()
while True:
sampler_kwargs['maxcall'] += self.n_check_point
self.sampler.run_nested(**sampler_kwargs)
......@@ -276,7 +280,6 @@ class Dynesty(NestedSampler):
self.write_current_state()
self.read_saved_state()
sampler_kwargs['add_live'] = True
self.sampler.run_nested(**sampler_kwargs)
return self.sampler.results
......@@ -340,6 +343,7 @@ class Dynesty(NestedSampler):
self.sampler.live_bound = saved['live_bound']
self.sampler.live_it = saved['live_it']
self.sampler.added_live = saved['added_live']
self.sampling_time += datetime.timedelta(seconds=saved['sampling_time'])
return True
else:
......@@ -371,6 +375,10 @@ class Dynesty(NestedSampler):
check_directory_exists_and_if_not_mkdir(self.outdir)
logger.info("Writing checkpoint file {}".format(self.resume_file))
end_time = datetime.datetime.now()
self.sampling_time += end_time - self.start_time
self.start_time = end_time
current_state = dict(
unit_cube_samples=self.sampler.saved_u,
physical_samples=self.sampler.saved_v,
......@@ -386,6 +394,7 @@ class Dynesty(NestedSampler):
boundidx=self.sampler.saved_boundidx,
bounditer=self.sampler.saved_bounditer,
scale=self.sampler.saved_scale,
sampling_time=self.sampling_time.total_seconds()
)
current_state.update(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment