Skip to content
Snippets Groups Projects
Commit 400343a2 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'fix_dynesty_state_plotting' into 'master'

Fix dynesty state plotting

See merge request !468
parents 4b38827b 3684a0d0
No related branches found
No related tags found
1 merge request!468Fix dynesty state plotting
Pipeline #60213 canceled
......@@ -11,6 +11,8 @@ This breaks backward compatibility.
The options to `boundary` are `{'periodic', 'reflective', None}`.
Periodic boundaries are supported as before.
Reflective boundaries are supported in `dynesty` and `cpnest`.
- Added state plotting for dynesty. Use `check_point_plot=True` in the `run_sampler`
function to create trace plots during the dynesty checkpoints
### Removed
-
......
......@@ -5,6 +5,7 @@ import sys
import pickle
import signal
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame
......@@ -89,7 +90,7 @@ class Dynesty(NestedSampler):
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
check_point=True, check_point_plot=False, n_check_point=None,
check_point=True, check_point_plot=True, n_check_point=None,
check_point_delta_t=600, resume=True, **kwargs):
NestedSampler.__init__(self, likelihood=likelihood, priors=priors,
outdir=outdir, label=label, use_ratio=use_ratio,
......@@ -397,11 +398,20 @@ class Dynesty(NestedSampler):
if plot and self.check_point_plot:
import dynesty.plotting as dyplot
labels = self.search_parameter_keys
fn = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
fig.tight_layout()
fig.savefig(fn)
labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
try:
truths = None
if self.injection_parameters is not None:
truths = [self.injection_parameters[key] for key in self.search_parameter_keys]
fig = dyplot.traceplot(self.sampler.results, labels=labels,
truths=truths)[0]
fig.tight_layout()
fig.savefig(filename)
plt.close('all')
except (RuntimeError, np.linalg.linalg.LinAlgError) as e:
logger.warning(e)
logger.warning('Failed to create dynesty state plot at checkpoint')
def generate_trace_plots(self, dynesty_results):
check_directory_exists_and_if_not_mkdir(self.outdir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment