diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 2b6dc72c39ba1042036e8c45560aaf113e60dc77..1e9d95656b266b58c70c29c2769faa1753b6ec57 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -60,6 +60,8 @@ class Dynesty(NestedSampler): If true, print information information about the convergence during check_point: bool, If true, use check pointing. + check_point_plot: bool, + If true, generate a trace plot along with the check-point check_point_delta_t: float (600) The approximate checkpoint period (in seconds). Should the run be interrupted, it can be resumed from the last checkpoint. Set to @@ -85,15 +87,18 @@ class Dynesty(NestedSampler): logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False) - def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, - skip_import_verification=False, check_point=True, n_check_point=None, check_point_delta_t=600, - resume=True, **kwargs): - NestedSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label, - use_ratio=use_ratio, plot=plot, + def __init__(self, likelihood, priors, outdir='outdir', label='label', + use_ratio=False, plot=False, skip_import_verification=False, + check_point=True, check_point_plot=False, n_check_point=None, + check_point_delta_t=600, resume=True, **kwargs): + NestedSampler.__init__(self, likelihood=likelihood, priors=priors, + outdir=outdir, label=label, use_ratio=use_ratio, + plot=plot, skip_import_verification=skip_import_verification, **kwargs) self.n_check_point = n_check_point self.check_point = check_point + self.check_point_plot = check_point_plot self.resume = resume self._periodic = list() self._reflective = list() @@ -320,7 +325,7 @@ class Dynesty(NestedSampler): self.sampler.live_it = saved['live_it'] self.sampler.added_live = saved['added_live'] if continuing: - self.write_current_state() + self.write_current_state(plot=False) return True else: @@ -330,10 +335,10 @@ class Dynesty(NestedSampler): def write_current_state_and_exit(self, signum=None, frame=None): logger.warning("Run terminated with signal {}".format(signum)) - self.write_current_state() + self.write_current_state(plot=False) sys.exit(130) - def write_current_state(self): + def write_current_state(self, plot=True): """ Write the current state of the sampler to disk. @@ -390,6 +395,14 @@ class Dynesty(NestedSampler): with open(self.resume_file, 'wb') as file: pickle.dump(current_state, file) + if plot and self.check_point_plot: + import dynesty.plotting as dyplot + labels = self.search_parameter_keys + fn = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label) + fig = dyplot.traceplot(self.sampler.results, labels=labels)[0] + fig.tight_layout() + fig.savefig(fn) + def generate_trace_plots(self, dynesty_results): check_directory_exists_and_if_not_mkdir(self.outdir) filename = '{}/{}_trace.png'.format(self.outdir, self.label)