Skip to content
Snippets Groups Projects
Commit ce294077 authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner
Browse files

Adds a simple check-point plot option to see progress

parent 286878be
No related branches found
No related tags found
No related merge requests found
......@@ -60,6 +60,8 @@ class Dynesty(NestedSampler):
If true, print information information about the convergence during
check_point: bool,
If true, use check pointing.
check_point_plot: bool,
If true, generate a trace plot along with the check-point
check_point_delta_t: float (600)
The approximate checkpoint period (in seconds). Should the run be
interrupted, it can be resumed from the last checkpoint. Set to
......@@ -85,15 +87,18 @@ class Dynesty(NestedSampler):
logl_max=np.inf, add_live=True, print_progress=True,
save_bounds=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False,
skip_import_verification=False, check_point=True, n_check_point=None, check_point_delta_t=600,
resume=True, **kwargs):
NestedSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label,
use_ratio=use_ratio, plot=plot,
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
check_point=True, check_point_plot=False, n_check_point=None,
check_point_delta_t=600, resume=True, **kwargs):
NestedSampler.__init__(self, likelihood=likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio,
plot=plot,
skip_import_verification=skip_import_verification,
**kwargs)
self.n_check_point = n_check_point
self.check_point = check_point
self.check_point_plot = check_point_plot
self.resume = resume
self._periodic = list()
self._reflective = list()
......@@ -320,7 +325,7 @@ class Dynesty(NestedSampler):
self.sampler.live_it = saved['live_it']
self.sampler.added_live = saved['added_live']
if continuing:
self.write_current_state()
self.write_current_state(plot=False)
return True
else:
......@@ -330,10 +335,10 @@ class Dynesty(NestedSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state()
self.write_current_state(plot=False)
sys.exit(130)
def write_current_state(self):
def write_current_state(self, plot=True):
"""
Write the current state of the sampler to disk.
......@@ -390,6 +395,14 @@ class Dynesty(NestedSampler):
with open(self.resume_file, 'wb') as file:
pickle.dump(current_state, file)
if plot and self.check_point_plot:
import dynesty.plotting as dyplot
labels = self.search_parameter_keys
fn = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(fn)
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
filename = '{}/{}_trace.png'.format(self.outdir, self.label)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment