Skip to content
Snippets Groups Projects

Pickle dump entire sampler in dynesty

Merged Colm Talbot requested to merge improve-dynesty-checkpointing into master
All threads resolved!
Files
14
@@ -6,13 +6,14 @@ import os
import sys
import pickle
import signal
import time
import tqdm
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect
from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect, safe_file_dump
from .base_sampler import Sampler, NestedSampler
from numpy import linalg
@@ -71,14 +72,14 @@ class Dynesty(NestedSampler):
check_point_plot: bool,
If true, generate a trace plot along with the check-point
check_point_delta_t: float (600)
The approximate checkpoint period (in seconds). Should the run be
interrupted, it can be resumed from the last checkpoint. Set to
`None` to turn-off check pointing
The minimum checkpoint period (in seconds). Should the run be
interrupted, it can be resumed from the last checkpoint.
n_check_point: int, optional (None)
The number of steps to take before check pointing (override
check_point_delta_t).
The number of steps to take before checking whether to check_point.
resume: bool
If true, resume run from checkpoint (if available)
exit_code: int
The code which the same exits on if it hasn't finished sampling
"""
default_kwargs = dict(bound='multi', sample='rwalk',
verbose=True, periodic=None, reflective=None,
@@ -99,7 +100,7 @@ class Dynesty(NestedSampler):
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
check_point=True, check_point_plot=True, n_check_point=None,
check_point_delta_t=600, resume=True, **kwargs):
check_point_delta_t=600, resume=True, exit_code=130, **kwargs):
super(Dynesty, self).__init__(likelihood=likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio,
plot=plot, skip_import_verification=skip_import_verification,
@@ -111,19 +112,16 @@ class Dynesty(NestedSampler):
self._periodic = list()
self._reflective = list()
self._apply_dynesty_boundaries()
if self.n_check_point is None:
# If the log_likelihood_eval_time is not calculable then
# check_point is set to False.
if np.isnan(self._log_likelihood_eval_time):
self.check_point = False
n_check_point_raw = (check_point_delta_t / self._log_likelihood_eval_time)
n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
self.n_check_point = n_check_point_rnd
logger.info("Checkpoint every n_check_point = {}".format(self.n_check_point))
if self.n_check_point is None:
self.n_check_point = 1000
self.check_point_delta_t = check_point_delta_t
logger.info("Checkpoint every check_point_delta_t = {}s"
.format(check_point_delta_t))
self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
self.sampling_time = datetime.timedelta()
self.exit_code = exit_code
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
@@ -132,7 +130,8 @@ class Dynesty(NestedSampler):
def __getstate__(self):
""" For pickle: remove external_sampler, which can be an unpicklable "module" """
state = self.__dict__.copy()
del state['external_sampler']
if "external_sampler" in state:
del state['external_sampler']
return state
@property
@@ -335,8 +334,13 @@ class Dynesty(NestedSampler):
break
old_ncall = self.sampler.ncall
self.write_current_state()
self.plot_current_state()
if os.path.isfile(self.resume_file):
last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
else:
last_checkpoint_s = (datetime.datetime.now() - self.start_time).total_seconds()
if last_checkpoint_s > self.check_point_delta_t:
self.write_current_state()
self.plot_current_state()
if self.sampler.added_live:
self.sampler._remove_live_points()
@@ -366,10 +370,21 @@ class Dynesty(NestedSampler):
Whether the run is continuing or terminating, if True, the loaded
state is mostly written back to disk.
"""
import dynesty
if os.path.isfile(self.resume_file):
logger.info("Reading resume file {}".format(self.resume_file))
with open(self.resume_file, 'rb') as file:
self.sampler = dill.load(file)
sampler = dill.load(file)
if isinstance(sampler, dynesty.nestedsamplers.MultiEllipsoidSampler) is False:
logger.warning(
"The resume file {} is corrupted or the version of "
"bilby has changed between runs. This resume file will "
"be ignored."
.format(self.resume_file))
return False
self.sampler = sampler
if self.sampler.added_live and continuing:
self.sampler._remove_live_points()
self.sampler.nqueue = -1
@@ -378,14 +393,21 @@ class Dynesty(NestedSampler):
self.sampling_time = self.sampler.kwargs.pop("sampling_time")
return True
else:
logger.debug(
logger.info(
"Resume file {} does not exist.".format(self.resume_file))
return False
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
if signum == 14:
logger.info(
"Run interrupted by alarm signal {}: checkpoint and exit on {}"
.format(signum, self.exit_code))
else:
logger.info(
"Run interrupted by signal {}: checkpoint and exit on {}"
.format(signum, self.exit_code))
self.write_current_state()
sys.exit(130)
os._exit(self.exit_code)
def write_current_state(self):
"""
@@ -398,6 +420,7 @@ class Dynesty(NestedSampler):
when using pytest. Hopefully, this message won't be triggered during
normal running.
"""
check_directory_exists_and_if_not_mkdir(self.outdir)
end_time = datetime.datetime.now()
if hasattr(self, 'start_time'):
@@ -406,8 +429,8 @@ class Dynesty(NestedSampler):
self.sampler.kwargs["sampling_time"] = self.sampling_time
self.sampler.kwargs["start_time"] = self.start_time
if dill.pickles(self.sampler):
with open(self.resume_file, 'wb') as file:
dill.dump(self.sampler, file)
safe_file_dump(self.sampler, self.resume_file, dill)
logger.info("Written checkpoint file {}".format(self.resume_file))
else:
logger.warning(
"Cannot write pickle resume file! "
@@ -418,8 +441,8 @@ class Dynesty(NestedSampler):
if self.check_point_plot:
import dynesty.plotting as dyplot
labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
try:
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(filename)
@@ -428,8 +451,8 @@ class Dynesty(NestedSampler):
logger.warning('Failed to create dynesty state plot at checkpoint')
finally:
plt.close("all")
filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
try:
filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
fig, axs = dyplot.runplot(self.sampler.results)
fig.tight_layout()
plt.savefig(filename)
@@ -438,15 +461,20 @@ class Dynesty(NestedSampler):
logger.warning('Failed to create dynesty run plot at checkpoint')
finally:
plt.close('all')
filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(self.sampler, f"saved_{name}"), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig.tight_layout()
plt.savefig(filename)
plt.close('all')
try:
filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(self.sampler, f"saved_{name}"), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, ValueError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty stats plot at checkpoint')
finally:
plt.close('all')
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
Loading