Skip to content
Snippets Groups Projects

Pickle dump entire sampler in dynesty

Merged Colm Talbot requested to merge improve-dynesty-checkpointing into master
All threads resolved!
@@ -370,20 +370,35 @@ class Dynesty(NestedSampler):
Whether the run is continuing or terminating, if True, the loaded
state is mostly written back to disk.
"""
import dynesty
from ... import __version__ as bilby_version
from dynesty import __version__ as dynesty_version
versions = dict(bilby=bilby_versions, dynesty=dynesty_version)
if os.path.isfile(self.resume_file):
logger.info("Reading resume file {}".format(self.resume_file))
with open(self.resume_file, 'rb') as file:
sampler = dill.load(file)
if isinstance(sampler, dynesty.nestedsamplers.MultiEllipsoidSampler) is False:
if not hasattr(sampler, "versions"):
logger.warning(
"The resume file {} is corrupted or the version of "
"bilby has changed between runs. This resume file will "
"be ignored."
.format(self.resume_file))
return False
version_warning = (
"The {code} version has changed between runs. "
"This may cause unpredictable behaviour and/or failure. "
"Old version = {old}, new version = {new}".
)
for code in versions:
if not versions[code] == sampler.versions.get(code, None):
logger.warning(version_message.format(
code=code,
old=sampler.versions.get(code, "None"),
new=versions[code]
))
del sampler.versions
self.sampler = sampler
if self.sampler.added_live and continuing:
self.sampler._remove_live_points()
@@ -421,6 +436,8 @@ class Dynesty(NestedSampler):
normal running.
"""
from ... import __version__ as bilby_version
from dynesty import __version__ as dynesty_version
check_directory_exists_and_if_not_mkdir(self.outdir)
end_time = datetime.datetime.now()
if hasattr(self, 'start_time'):
@@ -428,6 +445,9 @@ class Dynesty(NestedSampler):
self.start_time = end_time
self.sampler.kwargs["sampling_time"] = self.sampling_time
self.sampler.kwargs["start_time"] = self.start_time
self.sampler.versions = dict(
bilby=bilby_version, dynesty=dynesty_version
)
if dill.pickles(self.sampler):
safe_file_dump(self.sampler, self.resume_file, dill)
logger.info("Written checkpoint file {}".format(self.resume_file))
Loading