Skip to content
Snippets Groups Projects

Pickle dump entire sampler in dynesty

Merged Colm Talbot requested to merge improve-dynesty-checkpointing into master
Files
2
+ 113
133
from __future__ import absolute_import
import datetime
import dill
import os
import sys
import pickle
import signal
import time
import tqdm
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect
from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect, safe_file_dump
from .base_sampler import Sampler, NestedSampler
from numpy import linalg
@@ -70,14 +72,14 @@ class Dynesty(NestedSampler):
check_point_plot: bool,
If true, generate a trace plot along with the check-point
check_point_delta_t: float (600)
The approximate checkpoint period (in seconds). Should the run be
interrupted, it can be resumed from the last checkpoint. Set to
`None` to turn-off check pointing
The minimum checkpoint period (in seconds). Should the run be
interrupted, it can be resumed from the last checkpoint.
n_check_point: int, optional (None)
The number of steps to take before check pointing (override
check_point_delta_t).
The number of steps to take before checking whether to check_point.
resume: bool
If true, resume run from checkpoint (if available)
exit_code: int
The code which the same exits on if it hasn't finished sampling
"""
default_kwargs = dict(bound='multi', sample='rwalk',
verbose=True, periodic=None, reflective=None,
@@ -98,7 +100,7 @@ class Dynesty(NestedSampler):
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
check_point=True, check_point_plot=True, n_check_point=None,
check_point_delta_t=600, resume=True, **kwargs):
check_point_delta_t=600, resume=True, exit_code=130, **kwargs):
super(Dynesty, self).__init__(likelihood=likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio,
plot=plot, skip_import_verification=skip_import_verification,
@@ -110,19 +112,16 @@ class Dynesty(NestedSampler):
self._periodic = list()
self._reflective = list()
self._apply_dynesty_boundaries()
if self.n_check_point is None:
# If the log_likelihood_eval_time is not calculable then
# check_point is set to False.
if np.isnan(self._log_likelihood_eval_time):
self.check_point = False
n_check_point_raw = (check_point_delta_t / self._log_likelihood_eval_time)
n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
self.n_check_point = n_check_point_rnd
logger.info("Checkpoint every n_check_point = {}".format(self.n_check_point))
if self.n_check_point is None:
self.n_check_point = 1000
self.check_point_delta_t = check_point_delta_t
logger.info("Checkpoint every check_point_delta_t = {}s"
.format(check_point_delta_t))
self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
self.sampling_time = datetime.timedelta()
self.exit_code = exit_code
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
@@ -131,7 +130,8 @@ class Dynesty(NestedSampler):
def __getstate__(self):
""" For pickle: remove external_sampler, which can be an unpicklable "module" """
state = self.__dict__.copy()
del state['external_sampler']
if "external_sampler" in state:
del state['external_sampler']
return state
@property
@@ -319,9 +319,9 @@ class Dynesty(NestedSampler):
def _run_external_sampler_with_checkpointing(self):
logger.debug("Running sampler with checkpointing")
if self.resume:
resume = self.read_saved_state(continuing=True)
if resume:
logger.info('Resuming from previous run.')
resume_file_loaded = self.read_saved_state(continuing=True)
if resume_file_loaded:
logger.info('Resume file successfully loaded.')
old_ncall = self.sampler.ncall
sampler_kwargs = self.sampler_function_kwargs.copy()
@@ -334,8 +334,15 @@ class Dynesty(NestedSampler):
break
old_ncall = self.sampler.ncall
self.sampler._remove_live_points()
self.write_current_state()
if os.path.isfile(self.resume_file):
last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
else:
last_checkpoint_s = (datetime.datetime.now() - self.start_time).total_seconds()
if last_checkpoint_s > self.check_point_delta_t:
self.write_current_state()
self.plot_current_state()
if self.sampler.added_live:
self.sampler._remove_live_points()
sampler_kwargs['add_live'] = True
self._run_nested_wrapper(sampler_kwargs)
@@ -348,152 +355,126 @@ class Dynesty(NestedSampler):
def read_saved_state(self, continuing=False):
"""
Read a saved state of the sampler to disk.
Read a pickled saved state of the sampler to disk.
The required information to reconstruct the state of the run is read
from a pickle file.
This currently adds the whole chain to the sampler.
We then remove the old checkpoint and write all unnecessary items back
to disk.
FIXME: Load only the necessary quantities, rather than read/write?
If the live points are present and the run is continuing
they are removed.
The random state must be reset, as this isn't saved by the pickle.
`nqueue` is set to a negative number to trigger the queue to be
refilled before the first iteration.
The previous run time is set to self.
Parameters
----------
sampler: `dynesty.NestedSampler`
NestedSampler instance to reconstruct from the saved state.
continuing: bool
Whether the run is continuing or terminating, if True, the loaded
state is mostly written back to disk.
"""
import dynesty
if os.path.isfile(self.resume_file):
logger.info("Reading resume file {}".format(self.resume_file))
try:
with open(self.resume_file, 'rb') as file:
saved = pickle.load(file)
logger.info(
"Succesfuly read resume file {}".format(self.resume_file))
except EOFError as e:
logger.warning("Resume file reading failed with error {}".format(e))
return False
self.sampler.saved_u = list(saved['unit_cube_samples'])
self.sampler.saved_v = list(saved['physical_samples'])
self.sampler.saved_logl = list(saved['sample_likelihoods'])
self.sampler.saved_logvol = list(saved['sample_log_volume'])
self.sampler.saved_logwt = list(saved['sample_log_weights'])
self.sampler.saved_logz = list(saved['cumulative_log_evidence'])
self.sampler.saved_logzvar = list(saved['cumulative_log_evidence_error'])
self.sampler.saved_id = list(saved['id'])
self.sampler.saved_it = list(saved['it'])
self.sampler.saved_nc = list(saved['nc'])
self.sampler.saved_boundidx = list(saved['boundidx'])
self.sampler.saved_bounditer = list(saved['bounditer'])
self.sampler.saved_scale = list(saved['scale'])
self.sampler.saved_h = list(saved['cumulative_information'])
self.sampler.ncall = saved['ncall']
self.sampler.live_logl = list(saved['live_logl'])
self.sampler.it = saved['iteration'] + 1
self.sampler.live_u = saved['live_u']
self.sampler.live_v = saved['live_v']
self.sampler.nlive = saved['nlive']
self.sampler.live_bound = saved['live_bound']
self.sampler.live_it = saved['live_it']
self.sampler.added_live = saved['added_live']
self.sampler.bound = saved['bound']
self.sampler.nbound = saved['nbound']
self.sampling_time += datetime.timedelta(seconds=saved['sampling_time'])
with open(self.resume_file, 'rb') as file:
sampler = dill.load(file)
if isinstance(sampler, dynesty.nestedsamplers.MultiEllipsoidSampler) is False:
logger.warning(
"The resume file {} is corrupted or the version of "
"bilby has changed between runs. This resume file will "
"be ignored."
.format(self.resume_file))
return False
self.sampler = sampler
if self.sampler.added_live and continuing:
self.sampler._remove_live_points()
self.sampler.nqueue = -1
self.sampler.rstate = np.random
self.start_time = self.sampler.kwargs.pop("start_time")
self.sampling_time = self.sampler.kwargs.pop("sampling_time")
return True
else:
logger.debug(
"No resume file {}".format(self.resume_file))
logger.info(
"Resume file {} does not exist.".format(self.resume_file))
return False
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state(plot=False)
sys.exit(130)
if signum == 14:
logger.info(
"Run interrupted by alarm signal {}: checkpoint and exit on {}"
.format(signum, self.exit_code))
else:
logger.info(
"Run interrupted by signal {}: checkpoint and exit on {}"
.format(signum, self.exit_code))
self.write_current_state()
os._exit(self.exit_code)
def write_current_state(self, plot=True):
def write_current_state(self):
"""
Write the current state of the sampler to disk.
The required information to reconstruct the state of the run are written
to an hdf5 file.
All but the most recent removed live point in the chain are removed from
the sampler to reduce memory usage.
This means it is necessary to not append the first live point to the
file if updating a previous checkpoint.
The sampler is pickle dumped using `dill`.
The sampling time is also stored to get the full CPU time for the run.
Parameters
----------
sampler: `dynesty.NestedSampler`
NestedSampler to write to disk.
The check of whether the sampler is picklable is to catch an error
when using pytest. Hopefully, this message won't be triggered during
normal running.
"""
check_directory_exists_and_if_not_mkdir(self.outdir)
print("")
logger.info("Writing checkpoint file {}".format(self.resume_file))
check_directory_exists_and_if_not_mkdir(self.outdir)
end_time = datetime.datetime.now()
if hasattr(self, 'start_time'):
self.sampling_time += end_time - self.start_time
self.start_time = end_time
self.sampler.kwargs["sampling_time"] = self.sampling_time
self.sampler.kwargs["start_time"] = self.start_time
if dill.pickles(self.sampler):
safe_file_dump(self.sampler, self.resume_file, dill)
logger.info("Written checkpoint file {}".format(self.resume_file))
else:
logger.warning(
"Cannot write pickle resume file! "
"Job will not resume if interrupted."
)
current_state = dict(
unit_cube_samples=self.sampler.saved_u,
physical_samples=self.sampler.saved_v,
sample_likelihoods=self.sampler.saved_logl,
sample_log_volume=self.sampler.saved_logvol,
sample_log_weights=self.sampler.saved_logwt,
cumulative_log_evidence=self.sampler.saved_logz,
cumulative_log_evidence_error=self.sampler.saved_logzvar,
cumulative_information=self.sampler.saved_h,
id=self.sampler.saved_id,
it=self.sampler.saved_it,
nc=self.sampler.saved_nc,
bound=self.sampler.bound,
nbound=self.sampler.nbound,
boundidx=self.sampler.saved_boundidx,
bounditer=self.sampler.saved_bounditer,
scale=self.sampler.saved_scale,
sampling_time=self.sampling_time.total_seconds()
)
current_state.update(
ncall=self.sampler.ncall, live_logl=self.sampler.live_logl,
iteration=self.sampler.it - 1, live_u=self.sampler.live_u,
live_v=self.sampler.live_v, nlive=self.sampler.nlive,
live_bound=self.sampler.live_bound, live_it=self.sampler.live_it,
added_live=self.sampler.added_live
)
try:
weights = np.exp(current_state['sample_log_weights'] -
current_state['cumulative_log_evidence'][-1])
from dynesty.utils import resample_equal
current_state['posterior'] = resample_equal(
np.array(current_state['physical_samples']), weights)
current_state['search_parameter_keys'] = self.search_parameter_keys
except ValueError:
logger.debug("Unable to create posterior")
with open(self.resume_file, 'wb') as file:
pickle.dump(current_state, file)
if plot and self.check_point_plot:
def plot_current_state(self):
if self.check_point_plot:
import dynesty.plotting as dyplot
labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
try:
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(filename)
plt.close('all')
except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty state plot at checkpoint')
finally:
plt.close("all")
try:
filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
fig, axs = dyplot.runplot(self.sampler.results)
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty run plot at checkpoint')
finally:
plt.close('all')
try:
filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(self.sampler, f"saved_{name}"), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, ValueError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty stats plot at checkpoint')
finally:
plt.close('all')
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
@@ -573,7 +554,6 @@ def sample_rwalk_bilby(args):
logl_list = [loglikelihood(v_list[-1])]
max_walk_warning = True
drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan
while len(u_list) < nact * act:
# Propose a direction on the unit n-sphere.
Loading