Skip to content
Snippets Groups Projects
Commit 3684a0d0 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Fix dynesty state plotting

parent 4b38827b
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,8 @@ This breaks backward compatibility.
The options to `boundary` are `{'periodic', 'reflective', None}`.
Periodic boundaries are supported as before.
Reflective boundaries are supported in `dynesty` and `cpnest`.
- Added state plotting for dynesty. Use `check_point_plot=True` in the `run_sampler`
function to create trace plots during the dynesty checkpoints
### Removed
-
......
......@@ -5,6 +5,7 @@ import sys
import pickle
import signal
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
......@@ -89,7 +90,7 @@ class Dynesty(NestedSampler):
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
check_point=True, check_point_plot=False, n_check_point=None,
check_point=True, check_point_plot=True, n_check_point=None,
check_point_delta_t=600, resume=True, **kwargs):
NestedSampler.__init__(self, likelihood=likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio,
......@@ -397,11 +398,20 @@ class Dynesty(NestedSampler):
if plot and self.check_point_plot:
import dynesty.plotting as dyplot
labels = self.search_parameter_keys
fn = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(fn)
labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
try:
truths = None
if self.injection_parameters is not None:
truths = [self.injection_parameters[key] for key in self.search_parameter_keys]
fig = dyplot.traceplot(self.sampler.results, labels=labels,
truths=truths)[0]
fig.tight_layout()
fig.savefig(filename)
plt.close('all')
except (RuntimeError, np.linalg.linalg.LinAlgError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty state plot at checkpoint')
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment