Skip to content
Snippets Groups Projects

Pickle dump entire sampler in dynesty

Merged Colm Talbot requested to merge improve-dynesty-checkpointing into master
Files
2
@@ -6,13 +6,14 @@ import os
import sys
import pickle
import signal
import time
import tqdm
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect
from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect, safe_file_dump
from .base_sampler import Sampler, NestedSampler
from numpy import linalg
@@ -71,14 +72,14 @@ class Dynesty(NestedSampler):
check_point_plot: bool,
If true, generate a trace plot along with the check-point
check_point_delta_t: float (600)
The approximate checkpoint period (in seconds). Should the run be
interrupted, it can be resumed from the last checkpoint. Set to
`None` to turn-off check pointing
The minimum checkpoint period (in seconds). Should the run be
interrupted, it can be resumed from the last checkpoint.
n_check_point: int, optional (None)
The number of steps to take before check pointing (override
check_point_delta_t).
The number of steps to take before checking whether to check_point.
resume: bool
If true, resume run from checkpoint (if available)
exit_code: int
The code which the sampler exits on if it hasn't finished sampling
"""
default_kwargs = dict(bound='multi', sample='rwalk',
verbose=True, periodic=None, reflective=None,
@@ -99,7 +100,7 @@ class Dynesty(NestedSampler):
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
check_point=True, check_point_plot=True, n_check_point=None,
check_point_delta_t=600, resume=True, **kwargs):
check_point_delta_t=600, resume=True, exit_code=130, **kwargs):
super(Dynesty, self).__init__(likelihood=likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio,
plot=plot, skip_import_verification=skip_import_verification,
@@ -111,19 +112,16 @@ class Dynesty(NestedSampler):
self._periodic = list()
self._reflective = list()
self._apply_dynesty_boundaries()
if self.n_check_point is None:
# If the log_likelihood_eval_time is not calculable then
# check_point is set to False.
if np.isnan(self._log_likelihood_eval_time):
self.check_point = False
n_check_point_raw = (check_point_delta_t / self._log_likelihood_eval_time)
n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
self.n_check_point = n_check_point_rnd
logger.info("Checkpoint every n_check_point = {}".format(self.n_check_point))
if self.n_check_point is None:
self.n_check_point = 1000
self.check_point_delta_t = check_point_delta_t
logger.info("Checkpoint every check_point_delta_t = {}s"
.format(check_point_delta_t))
self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
self.sampling_time = datetime.timedelta()
self.exit_code = exit_code
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
@@ -336,8 +334,13 @@ class Dynesty(NestedSampler):
break
old_ncall = self.sampler.ncall
self.write_current_state()
self.plot_current_state()
if os.path.isfile(self.resume_file):
last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
else:
last_checkpoint_s = np.inf
if last_checkpoint_s > self.check_point_delta_t:
self.write_current_state()
self.plot_current_state()
if self.sampler.added_live:
self.sampler._remove_live_points()
@@ -379,14 +382,14 @@ class Dynesty(NestedSampler):
self.sampling_time = self.sampler.kwargs.pop("sampling_time")
return True
else:
logger.debug(
logger.info(
"Resume file {} does not exist.".format(self.resume_file))
return False
# Handler registered for SIGTERM and SIGINT via signal.signal(): logs the
# received signal number, checkpoints the sampler through
# write_current_state(), and then terminates the process.
# Only `signum` is read (for the log message); `frame` is accepted but unused.
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state()
# NOTE(review): this listing looks like a rendered diff with its +/- markers
# stripped, so the two exit calls below are presumably the old and new form of
# the same statement rather than consecutive statements - confirm which is live.
sys.exit(130)
os._exit(self.exit_code)
def write_current_state(self):
"""
@@ -399,6 +402,7 @@ class Dynesty(NestedSampler):
when using pytest. Hopefully, this message won't be triggered during
normal running.
"""
check_directory_exists_and_if_not_mkdir(self.outdir)
end_time = datetime.datetime.now()
if hasattr(self, 'start_time'):
@@ -407,8 +411,8 @@ class Dynesty(NestedSampler):
self.sampler.kwargs["sampling_time"] = self.sampling_time
self.sampler.kwargs["start_time"] = self.start_time
if dill.pickles(self.sampler):
with open(self.resume_file, 'wb') as file:
dill.dump(self.sampler, file)
safe_file_dump(self.sampler, self.resume_file, dill)
logger.info("Written checkpoint file {}".format(self.resume_file))
else:
logger.warning(
"Cannot write pickle resume file! "
Loading