Skip to content
Snippets Groups Projects
Commit c933a88c authored by Paul Lasky's avatar Paul Lasky
Browse files

Merge branch 'fix-dynesty-cleanup' into 'master'

make pickle names consistent

See merge request lscsoft/bilby!385
parents a9a16d22 3eb8f017
No related branches found
No related tags found
No related merge requests found
......@@ -104,6 +104,8 @@ class Dynesty(NestedSampler):
n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
self.n_check_point = n_check_point_rnd
self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
......@@ -240,15 +242,15 @@ class Dynesty(NestedSampler):
def _remove_checkpoint(self):
"""Remove checkpointed state"""
if os.path.isfile('{}/{}_resume.h5'.format(self.outdir, self.label)):
os.remove('{}/{}_resume.h5'.format(self.outdir, self.label))
if os.path.isfile(self.resume_file):
os.remove(self.resume_file)
def read_saved_state(self, continuing=False):
"""
Read a saved state of the sampler to disk.
The required information to reconstruct the state of the run is read
from an hdf5 file.
from a pickle file.
This currently adds the whole chain to the sampler.
We then remove the old checkpoint and write all unnecessary items back
to disk.
......@@ -262,13 +264,14 @@ class Dynesty(NestedSampler):
Whether the run is continuing or terminating, if True, the loaded
state is mostly written back to disk.
"""
resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
logger.debug("Reading resume file {}".format(resume_file))
if os.path.isfile(resume_file):
with open(resume_file, 'rb') as file:
logger.debug("Reading resume file {}".format(self.resume_file))
if os.path.isfile(self.resume_file):
with open(self.resume_file, 'rb') as file:
saved = pickle.load(file)
logger.debug("Succesfuly read resume file {}".format(resume_file))
logger.debug(
"Succesfuly read resume file {}".format(self.resume_file))
self.sampler.saved_u = list(saved['unit_cube_samples'])
self.sampler.saved_v = list(saved['physical_samples'])
......@@ -299,6 +302,8 @@ class Dynesty(NestedSampler):
return True
else:
logger.debug(
"Failed to read resume file {}".format(self.resume_file))
return False
def write_current_state_and_exit(self, signum=None, frame=None):
......@@ -323,57 +328,23 @@ class Dynesty(NestedSampler):
NestedSampler to write to disk.
"""
check_directory_exists_and_if_not_mkdir(self.outdir)
resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
if os.path.isfile(resume_file):
with open(resume_file, 'rb') as file:
saved = pickle.load(file)
current_state = dict(
unit_cube_samples=np.vstack([
saved['unit_cube_samples'], self.sampler.saved_u[1:]]),
physical_samples=np.vstack([
saved['physical_samples'], self.sampler.saved_v[1:]]),
sample_likelihoods=np.concatenate([
saved['sample_likelihoods'], self.sampler.saved_logl[1:]]),
sample_log_volume=np.concatenate([
saved['sample_log_volume'], self.sampler.saved_logvol[1:]]),
sample_log_weights=np.concatenate([
saved['sample_log_weights'], self.sampler.saved_logwt[1:]]),
cumulative_log_evidence=np.concatenate([
saved['cumulative_log_evidence'], self.sampler.saved_logz[1:]]),
cumulative_log_evidence_error=np.concatenate([
saved['cumulative_log_evidence_error'],
self.sampler.saved_logzvar[1:]]),
cumulative_information=np.concatenate([
saved['cumulative_information'], self.sampler.saved_h[1:]]),
id=np.concatenate([saved['id'], self.sampler.saved_id[1:]]),
it=np.concatenate([saved['it'], self.sampler.saved_it[1:]]),
nc=np.concatenate([saved['nc'], self.sampler.saved_nc[1:]]),
boundidx=np.concatenate([
saved['boundidx'], self.sampler.saved_boundidx[1:]]),
bounditer=np.concatenate([
saved['bounditer'], self.sampler.saved_bounditer[1:]]),
scale=np.concatenate([saved['scale'], self.sampler.saved_scale[1:]]),
)
else:
current_state = dict(
unit_cube_samples=self.sampler.saved_u,
physical_samples=self.sampler.saved_v,
sample_likelihoods=self.sampler.saved_logl,
sample_log_volume=self.sampler.saved_logvol,
sample_log_weights=self.sampler.saved_logwt,
cumulative_log_evidence=self.sampler.saved_logz,
cumulative_log_evidence_error=self.sampler.saved_logzvar,
cumulative_information=self.sampler.saved_h,
id=self.sampler.saved_id,
it=self.sampler.saved_it,
nc=self.sampler.saved_nc,
boundidx=self.sampler.saved_boundidx,
bounditer=self.sampler.saved_bounditer,
scale=self.sampler.saved_scale,
)
current_state = dict(
unit_cube_samples=self.sampler.saved_u,
physical_samples=self.sampler.saved_v,
sample_likelihoods=self.sampler.saved_logl,
sample_log_volume=self.sampler.saved_logvol,
sample_log_weights=self.sampler.saved_logwt,
cumulative_log_evidence=self.sampler.saved_logz,
cumulative_log_evidence_error=self.sampler.saved_logzvar,
cumulative_information=self.sampler.saved_h,
id=self.sampler.saved_id,
it=self.sampler.saved_it,
nc=self.sampler.saved_nc,
boundidx=self.sampler.saved_boundidx,
bounditer=self.sampler.saved_bounditer,
scale=self.sampler.saved_scale,
)
current_state.update(
ncall=self.sampler.ncall, live_logl=self.sampler.live_logl,
......@@ -392,24 +363,9 @@ class Dynesty(NestedSampler):
except ValueError:
logger.debug("Unable to create posterior")
with open(resume_file, 'wb') as file:
with open(self.resume_file, 'wb') as file:
pickle.dump(current_state, file)
self.sampler.saved_id = [self.sampler.saved_id[-1]]
self.sampler.saved_u = [self.sampler.saved_u[-1]]
self.sampler.saved_v = [self.sampler.saved_v[-1]]
self.sampler.saved_logl = [self.sampler.saved_logl[-1]]
self.sampler.saved_logvol = [self.sampler.saved_logvol[-1]]
self.sampler.saved_logwt = [self.sampler.saved_logwt[-1]]
self.sampler.saved_logz = [self.sampler.saved_logz[-1]]
self.sampler.saved_logzvar = [self.sampler.saved_logzvar[-1]]
self.sampler.saved_h = [self.sampler.saved_h[-1]]
self.sampler.saved_nc = [self.sampler.saved_nc[-1]]
self.sampler.saved_boundidx = [self.sampler.saved_boundidx[-1]]
self.sampler.saved_it = [self.sampler.saved_it[-1]]
self.sampler.saved_bounditer = [self.sampler.saved_bounditer[-1]]
self.sampler.saved_scale = [self.sampler.saved_scale[-1]]
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
filename = '{}/{}_trace.png'.format(self.outdir, self.label)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment