Skip to content
Snippets Groups Projects
Commit ce294077 authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner
Browse files

Adds a simple check-point plot option to see progress

parent 286878be
No related branches found
No related tags found
No related merge requests found
...@@ -60,6 +60,8 @@ class Dynesty(NestedSampler): ...@@ -60,6 +60,8 @@ class Dynesty(NestedSampler):
If true, print information information about the convergence during If true, print information information about the convergence during
check_point: bool, check_point: bool,
If true, use check pointing. If true, use check pointing.
check_point_plot: bool,
If true, generate a trace plot along with the check-point
check_point_delta_t: float (600) check_point_delta_t: float (600)
The approximate checkpoint period (in seconds). Should the run be The approximate checkpoint period (in seconds). Should the run be
interrupted, it can be resumed from the last checkpoint. Set to interrupted, it can be resumed from the last checkpoint. Set to
...@@ -85,15 +87,18 @@ class Dynesty(NestedSampler): ...@@ -85,15 +87,18 @@ class Dynesty(NestedSampler):
logl_max=np.inf, add_live=True, print_progress=True, logl_max=np.inf, add_live=True, print_progress=True,
save_bounds=False) save_bounds=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, def __init__(self, likelihood, priors, outdir='outdir', label='label',
skip_import_verification=False, check_point=True, n_check_point=None, check_point_delta_t=600, use_ratio=False, plot=False, skip_import_verification=False,
resume=True, **kwargs): check_point=True, check_point_plot=False, n_check_point=None,
NestedSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label, check_point_delta_t=600, resume=True, **kwargs):
use_ratio=use_ratio, plot=plot, NestedSampler.__init__(self, likelihood=likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio,
plot=plot,
skip_import_verification=skip_import_verification, skip_import_verification=skip_import_verification,
**kwargs) **kwargs)
self.n_check_point = n_check_point self.n_check_point = n_check_point
self.check_point = check_point self.check_point = check_point
self.check_point_plot = check_point_plot
self.resume = resume self.resume = resume
self._periodic = list() self._periodic = list()
self._reflective = list() self._reflective = list()
...@@ -320,7 +325,7 @@ class Dynesty(NestedSampler): ...@@ -320,7 +325,7 @@ class Dynesty(NestedSampler):
self.sampler.live_it = saved['live_it'] self.sampler.live_it = saved['live_it']
self.sampler.added_live = saved['added_live'] self.sampler.added_live = saved['added_live']
if continuing: if continuing:
self.write_current_state() self.write_current_state(plot=False)
return True return True
else: else:
...@@ -330,10 +335,10 @@ class Dynesty(NestedSampler): ...@@ -330,10 +335,10 @@ class Dynesty(NestedSampler):
def write_current_state_and_exit(self, signum=None, frame=None): def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum)) logger.warning("Run terminated with signal {}".format(signum))
self.write_current_state() self.write_current_state(plot=False)
sys.exit(130) sys.exit(130)
def write_current_state(self): def write_current_state(self, plot=True):
""" """
Write the current state of the sampler to disk. Write the current state of the sampler to disk.
...@@ -390,6 +395,14 @@ class Dynesty(NestedSampler): ...@@ -390,6 +395,14 @@ class Dynesty(NestedSampler):
with open(self.resume_file, 'wb') as file: with open(self.resume_file, 'wb') as file:
pickle.dump(current_state, file) pickle.dump(current_state, file)
if plot and self.check_point_plot:
import dynesty.plotting as dyplot
labels = self.search_parameter_keys
fn = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(fn)
def generate_trace_plots(self, dynesty_results): def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir) check_directory_exists_and_if_not_mkdir(self.outdir)
filename = '{}/{}_trace.png'.format(self.outdir, self.label) filename = '{}/{}_trace.png'.format(self.outdir, self.label)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment