Skip to content
Snippets Groups Projects
Commit 3eb8f017 authored by Colm Talbot's avatar Colm Talbot
Browse files

make pickle names consistent

parent 6e249e8a
No related branches found
No related tags found
1 merge request!385make pickle names consistent
Pipeline #50612 passed
......@@ -104,6 +104,8 @@ class Dynesty(NestedSampler):
n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
self.n_check_point = n_check_point_rnd
self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
signal.signal(signal.SIGINT, self.write_current_state_and_exit)
......@@ -240,15 +242,15 @@ class Dynesty(NestedSampler):
def _remove_checkpoint(self):
"""Remove checkpointed state"""
if os.path.isfile('{}/{}_resume.h5'.format(self.outdir, self.label)):
os.remove('{}/{}_resume.h5'.format(self.outdir, self.label))
if os.path.isfile(self.resume_file):
os.remove(self.resume_file)
def read_saved_state(self, continuing=False):
"""
Read a saved state of the sampler to disk.
The required information to reconstruct the state of the run is read
from an hdf5 file.
from a pickle file.
This currently adds the whole chain to the sampler.
We then remove the old checkpoint and write all unnecessary items back
to disk.
......@@ -262,13 +264,14 @@ class Dynesty(NestedSampler):
Whether the run is continuing or terminating, if True, the loaded
state is mostly written back to disk.
"""
resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
logger.debug("Reading resume file {}".format(resume_file))
if os.path.isfile(resume_file):
with open(resume_file, 'rb') as file:
logger.debug("Reading resume file {}".format(self.resume_file))
if os.path.isfile(self.resume_file):
with open(self.resume_file, 'rb') as file:
saved = pickle.load(file)
logger.debug("Succesfuly read resume file {}".format(resume_file))
logger.debug(
"Succesfuly read resume file {}".format(self.resume_file))
self.sampler.saved_u = list(saved['unit_cube_samples'])
self.sampler.saved_v = list(saved['physical_samples'])
......@@ -299,6 +302,8 @@ class Dynesty(NestedSampler):
return True
else:
logger.debug(
"Failed to read resume file {}".format(self.resume_file))
return False
def write_current_state_and_exit(self, signum=None, frame=None):
......@@ -323,57 +328,23 @@ class Dynesty(NestedSampler):
NestedSampler to write to disk.
"""
check_directory_exists_and_if_not_mkdir(self.outdir)
resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
if os.path.isfile(resume_file):
with open(resume_file, 'rb') as file:
saved = pickle.load(file)
current_state = dict(
unit_cube_samples=np.vstack([
saved['unit_cube_samples'], self.sampler.saved_u[1:]]),
physical_samples=np.vstack([
saved['physical_samples'], self.sampler.saved_v[1:]]),
sample_likelihoods=np.concatenate([
saved['sample_likelihoods'], self.sampler.saved_logl[1:]]),
sample_log_volume=np.concatenate([
saved['sample_log_volume'], self.sampler.saved_logvol[1:]]),
sample_log_weights=np.concatenate([
saved['sample_log_weights'], self.sampler.saved_logwt[1:]]),
cumulative_log_evidence=np.concatenate([
saved['cumulative_log_evidence'], self.sampler.saved_logz[1:]]),
cumulative_log_evidence_error=np.concatenate([
saved['cumulative_log_evidence_error'],
self.sampler.saved_logzvar[1:]]),
cumulative_information=np.concatenate([
saved['cumulative_information'], self.sampler.saved_h[1:]]),
id=np.concatenate([saved['id'], self.sampler.saved_id[1:]]),
it=np.concatenate([saved['it'], self.sampler.saved_it[1:]]),
nc=np.concatenate([saved['nc'], self.sampler.saved_nc[1:]]),
boundidx=np.concatenate([
saved['boundidx'], self.sampler.saved_boundidx[1:]]),
bounditer=np.concatenate([
saved['bounditer'], self.sampler.saved_bounditer[1:]]),
scale=np.concatenate([saved['scale'], self.sampler.saved_scale[1:]]),
)
else:
current_state = dict(
unit_cube_samples=self.sampler.saved_u,
physical_samples=self.sampler.saved_v,
sample_likelihoods=self.sampler.saved_logl,
sample_log_volume=self.sampler.saved_logvol,
sample_log_weights=self.sampler.saved_logwt,
cumulative_log_evidence=self.sampler.saved_logz,
cumulative_log_evidence_error=self.sampler.saved_logzvar,
cumulative_information=self.sampler.saved_h,
id=self.sampler.saved_id,
it=self.sampler.saved_it,
nc=self.sampler.saved_nc,
boundidx=self.sampler.saved_boundidx,
bounditer=self.sampler.saved_bounditer,
scale=self.sampler.saved_scale,
)
current_state = dict(
unit_cube_samples=self.sampler.saved_u,
physical_samples=self.sampler.saved_v,
sample_likelihoods=self.sampler.saved_logl,
sample_log_volume=self.sampler.saved_logvol,
sample_log_weights=self.sampler.saved_logwt,
cumulative_log_evidence=self.sampler.saved_logz,
cumulative_log_evidence_error=self.sampler.saved_logzvar,
cumulative_information=self.sampler.saved_h,
id=self.sampler.saved_id,
it=self.sampler.saved_it,
nc=self.sampler.saved_nc,
boundidx=self.sampler.saved_boundidx,
bounditer=self.sampler.saved_bounditer,
scale=self.sampler.saved_scale,
)
current_state.update(
ncall=self.sampler.ncall, live_logl=self.sampler.live_logl,
......@@ -392,24 +363,9 @@ class Dynesty(NestedSampler):
except ValueError:
logger.debug("Unable to create posterior")
with open(resume_file, 'wb') as file:
with open(self.resume_file, 'wb') as file:
pickle.dump(current_state, file)
self.sampler.saved_id = [self.sampler.saved_id[-1]]
self.sampler.saved_u = [self.sampler.saved_u[-1]]
self.sampler.saved_v = [self.sampler.saved_v[-1]]
self.sampler.saved_logl = [self.sampler.saved_logl[-1]]
self.sampler.saved_logvol = [self.sampler.saved_logvol[-1]]
self.sampler.saved_logwt = [self.sampler.saved_logwt[-1]]
self.sampler.saved_logz = [self.sampler.saved_logz[-1]]
self.sampler.saved_logzvar = [self.sampler.saved_logzvar[-1]]
self.sampler.saved_h = [self.sampler.saved_h[-1]]
self.sampler.saved_nc = [self.sampler.saved_nc[-1]]
self.sampler.saved_boundidx = [self.sampler.saved_boundidx[-1]]
self.sampler.saved_it = [self.sampler.saved_it[-1]]
self.sampler.saved_bounditer = [self.sampler.saved_bounditer[-1]]
self.sampler.saved_scale = [self.sampler.saved_scale[-1]]
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
filename = '{}/{}_trace.png'.format(self.outdir, self.label)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment