Skip to content
Snippets Groups Projects

Pickle dump entire sampler in dynesty

Merged Colm Talbot requested to merge improve-dynesty-checkpointing into master
+ 61
110
from __future__ import absolute_import
import datetime
import dill
import os
import sys
import pickle
@@ -334,8 +335,10 @@ class Dynesty(NestedSampler):
break
old_ncall = self.sampler.ncall
self.sampler._remove_live_points()
self.write_current_state()
self.plot_current_state()
if self.sampler.added_live:
self.sampler._remove_live_points()
sampler_kwargs['add_live'] = True
self._run_nested_wrapper(sampler_kwargs)
@@ -348,141 +351,70 @@ class Dynesty(NestedSampler):
def read_saved_state(self, continuing=False):
"""
Read a saved state of the sampler to disk.
Read a pickled saved state of the sampler to disk.
The required information to reconstruct the state of the run is read
from a pickle file.
This currently adds the whole chain to the sampler.
We then remove the old checkpoint and write all unnecessary items back
to disk.
FIXME: Load only the necessary quantities, rather than read/write?
If the live points are present and the run is continuing
they are removed.
The random state must be reset, as this isn't saved by the pickle.
`nqueue` is set to a negative number to trigger the queue to be
refilled before the first iteration.
The previous run time is set to self.
Parameters
----------
sampler: `dynesty.NestedSampler`
NestedSampler instance to reconstruct from the saved state.
continuing: bool
Whether the run is continuing or terminating, if True, the loaded
state is mostly written back to disk.
"""
+2
logger.info("Reading resume file {}".format(self.resume_file))
if os.path.isfile(self.resume_file):
logger.info("Reading resume file {}".format(self.resume_file))
try:
with open(self.resume_file, 'rb') as file:
saved = pickle.load(file)
logger.info(
"Succesfuly read resume file {}".format(self.resume_file))
except EOFError as e:
logger.warning("Resume file reading failed with error {}".format(e))
return False
self.sampler.saved_u = list(saved['unit_cube_samples'])
self.sampler.saved_v = list(saved['physical_samples'])
self.sampler.saved_logl = list(saved['sample_likelihoods'])
self.sampler.saved_logvol = list(saved['sample_log_volume'])
self.sampler.saved_logwt = list(saved['sample_log_weights'])
self.sampler.saved_logz = list(saved['cumulative_log_evidence'])
self.sampler.saved_logzvar = list(saved['cumulative_log_evidence_error'])
self.sampler.saved_id = list(saved['id'])
self.sampler.saved_it = list(saved['it'])
self.sampler.saved_nc = list(saved['nc'])
self.sampler.saved_boundidx = list(saved['boundidx'])
self.sampler.saved_bounditer = list(saved['bounditer'])
self.sampler.saved_scale = list(saved['scale'])
self.sampler.saved_h = list(saved['cumulative_information'])
self.sampler.ncall = saved['ncall']
self.sampler.live_logl = list(saved['live_logl'])
self.sampler.it = saved['iteration'] + 1
self.sampler.live_u = saved['live_u']
self.sampler.live_v = saved['live_v']
self.sampler.nlive = saved['nlive']
self.sampler.live_bound = saved['live_bound']
self.sampler.live_it = saved['live_it']
self.sampler.added_live = saved['added_live']
self.sampler.bound = saved['bound']
self.sampler.nbound = saved['nbound']
self.sampling_time += datetime.timedelta(seconds=saved['sampling_time'])
return True
with open(self.resume_file, 'rb') as file:
self.sampler = dill.load(file)
if self.sampler.added_live and continuing:
self.sampler._remove_live_points()
self.sampler.nqueue = -1
self.sampler.rstate = np.random
self.start_time = self.sampler.kwargs.pop("start_time")
self.sampling_time = self.sampler.kwargs.pop("sampling_time")
else:
logger.debug(
"No resume file {}".format(self.resume_file))
"Resume file {} does not exist.".format(self.resume_file))
return False
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state(plot=False)
self.write_current_state()
sys.exit(130)
def write_current_state(self, plot=True):
def write_current_state(self):
"""
Write the current state of the sampler to disk.
The required information to reconstruct the state of the run are written
to an hdf5 file.
All but the most recent removed live point in the chain are removed from
the sampler to reduce memory usage.
This means it is necessary to not append the first live point to the
file if updating a previous checkpoint.
The sampler is pickle dumped using `dill`.
The sampling time is also stored to get the full CPU time for the run.
Parameters
----------
sampler: `dynesty.NestedSampler`
NestedSampler to write to disk.
The check of whether the sampler is picklable is to catch an error
when using pytest. Hopefully, this message won't be triggered during
normal running.
"""
check_directory_exists_and_if_not_mkdir(self.outdir)
print("")
logger.info("Writing checkpoint file {}".format(self.resume_file))
end_time = datetime.datetime.now()
if hasattr(self, 'start_time'):
self.sampling_time += end_time - self.start_time
self.start_time = end_time
self.sampler.kwargs["sampling_time"] = self.sampling_time
self.sampler.kwargs["start_time"] = self.start_time
if dill.pickles(self.sampler):
with open(self.resume_file, 'wb') as file:
dill.dump(self.sampler, file)
else:
logger.warning(
"Cannot write pickle resume file! "
"Job will not resume if interrupted."
)
current_state = dict(
unit_cube_samples=self.sampler.saved_u,
physical_samples=self.sampler.saved_v,
sample_likelihoods=self.sampler.saved_logl,
sample_log_volume=self.sampler.saved_logvol,
sample_log_weights=self.sampler.saved_logwt,
cumulative_log_evidence=self.sampler.saved_logz,
cumulative_log_evidence_error=self.sampler.saved_logzvar,
cumulative_information=self.sampler.saved_h,
id=self.sampler.saved_id,
it=self.sampler.saved_it,
nc=self.sampler.saved_nc,
bound=self.sampler.bound,
nbound=self.sampler.nbound,
boundidx=self.sampler.saved_boundidx,
bounditer=self.sampler.saved_bounditer,
scale=self.sampler.saved_scale,
sampling_time=self.sampling_time.total_seconds()
)
current_state.update(
ncall=self.sampler.ncall, live_logl=self.sampler.live_logl,
iteration=self.sampler.it - 1, live_u=self.sampler.live_u,
live_v=self.sampler.live_v, nlive=self.sampler.nlive,
live_bound=self.sampler.live_bound, live_it=self.sampler.live_it,
added_live=self.sampler.added_live
)
try:
weights = np.exp(current_state['sample_log_weights'] -
current_state['cumulative_log_evidence'][-1])
from dynesty.utils import resample_equal
current_state['posterior'] = resample_equal(
np.array(current_state['physical_samples']), weights)
current_state['search_parameter_keys'] = self.search_parameter_keys
except ValueError:
logger.debug("Unable to create posterior")
with open(self.resume_file, 'wb') as file:
pickle.dump(current_state, file)
if plot and self.check_point_plot:
def plot_current_state(self):
if self.check_point_plot:
import dynesty.plotting as dyplot
labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
@@ -490,10 +422,30 @@ class Dynesty(NestedSampler):
fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(filename)
plt.close('all')
except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty state plot at checkpoint')
finally:
plt.close("all")
filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
try:
fig, axs = dyplot.runplot(self.sampler.results)
fig.tight_layout()
plt.savefig(filename)
except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty run plot at checkpoint')
finally:
plt.close('all')
filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
fig, axs = plt.subplots(nrows=3, sharex=True)
for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
ax.plot(getattr(self.sampler, f"saved_{name}"), color="C0")
ax.set_ylabel(name)
axs[-1].set_xlabel("iteration")
fig.tight_layout()
plt.savefig(filename)
plt.close('all')
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
@@ -573,7 +525,6 @@ def sample_rwalk_bilby(args):
logl_list = [loglikelihood(v_list[-1])]
max_walk_warning = True
drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan
while len(u_list) < nact * act:
# Propose a direction on the unit n-sphere.
Loading