Skip to content
Snippets Groups Projects

Pickle dump entire sampler in dynesty

Merged Colm Talbot requested to merge improve-dynesty-checkpointing into master
All threads resolved!
@@ -129,6 +129,12 @@ class Dynesty(NestedSampler):
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
def __getstate__(self):
""" For pickle: remove external_sampler, which can be an unpicklable "module" """
state = self.__dict__.copy()
del state['external_sampler']
return state
@property
def sampler_function_kwargs(self):
keys = ['dlogz', 'print_progress', 'print_func', 'maxiter',
@@ -331,6 +337,7 @@ class Dynesty(NestedSampler):
old_ncall = self.sampler.ncall
self.write_current_state()
self.plot_current_state()
if self.sampler.added_live:
self.sampler._remove_live_points()
@@ -377,16 +384,20 @@ class Dynesty(NestedSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state(plot=False)
self.write_current_state()
sys.exit(130)
def write_current_state(self, plot=True):
def write_current_state(self):
"""
Write the current state of the sampler to disk.
The sampler is pickle dumped using `dill`.
The sampling time is also stored to get the full CPU time for the run.
The check of whether the sampler is picklable is to catch an error
when using pytest. Hopefully, this message won't be triggered during
normal running.
Parameters
----------
plot: bool
@@ -399,10 +410,17 @@ class Dynesty(NestedSampler):
self.start_time = end_time
self.sampler.kwargs["sampling_time"] = self.sampling_time
self.sampler.kwargs["start_time"] = self.start_time
with open(self.resume_file, 'wb') as file:
dill.dump(self.sampler, file)
if dill.pickles(self.sampler):
with open(self.resume_file, 'wb') as file:
dill.dump(self.sampler, file)
else:
logger.warning(
"Cannot write pickle resume file! "
"Job will not resume if interrupted."
)
if plot and self.check_point_plot:
def plot_current_state(self):
if self.check_point_plot:
import dynesty.plotting as dyplot
labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
Loading