diff --git a/CHANGELOG.md b/CHANGELOG.md index fa51d6d092e172dd522098b135ce823e8d0b8f15..4f1cf9a6000a508ca04368553227d8f5fcf97425 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ This breaks backward compatibility. The options to `boundary` are `{'periodic', 'reflective', None}`. Periodic boundaries are supported as before. Reflective boundaries are supported in `dynesty` and `cpnest`. +- Added state plotting for dynesty. Use `check_point_plot=True` in the `run_sampler` +function to create trace plots during the dynesty checkpoints ### Removed - diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 1e9d95656b266b58c70c29c2769faa1753b6ec57..47be5739905643871caaf038aa844e63c1136110 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -5,6 +5,7 @@ import sys import pickle import signal +import matplotlib.pyplot as plt import numpy as np from pandas import DataFrame @@ -89,7 +90,7 @@ class Dynesty(NestedSampler): def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, skip_import_verification=False, - check_point=True, check_point_plot=False, n_check_point=None, + check_point=True, check_point_plot=True, n_check_point=None, check_point_delta_t=600, resume=True, **kwargs): NestedSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label, use_ratio=use_ratio, @@ -397,11 +398,20 @@ class Dynesty(NestedSampler): if plot and self.check_point_plot: import dynesty.plotting as dyplot - labels = self.search_parameter_keys - fn = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label) - fig = dyplot.traceplot(self.sampler.results, labels=labels)[0] - fig.tight_layout() - fig.savefig(fn) + labels = [label.replace('_', ' ') for label in self.search_parameter_keys] + filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label) + try: + truths = None + if self.injection_parameters is not None: + truths = [self.injection_parameters[key] for key in self.search_parameter_keys] + fig = dyplot.traceplot(self.sampler.results, labels=labels, + truths=truths)[0] + fig.tight_layout() + fig.savefig(filename) + plt.close('all') + except (RuntimeError, np.linalg.linalg.LinAlgError) as e: + logger.warning(e) + logger.warning('Failed to create dynesty state plot at checkpoint') def generate_trace_plots(self, dynesty_results): check_directory_exists_and_if_not_mkdir(self.outdir)