Skip to content
Snippets Groups Projects
Commit 733a52e6 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adds plotting argument

- This will turn off by default the trace plots
parent c31c4482
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -36,8 +36,10 @@ class Sampler(object):
"""
def __init__(self, likelihood, priors, external_sampler='nestle', outdir='outdir', label='label', use_ratio=False,
**kwargs):
def __init__(
self, likelihood, priors, external_sampler='nestle',
outdir='outdir', label='label', use_ratio=False, plot=False,
**kwargs):
self.likelihood = likelihood
self.priors = priors
self.label = label
......@@ -45,6 +47,7 @@ class Sampler(object):
self.use_ratio = use_ratio
self.external_sampler = external_sampler
self.external_sampler_function = None
self.plot = plot
self.__search_parameter_keys = []
self.__fixed_parameter_keys = []
......@@ -327,15 +330,19 @@ class Dynesty(Sampler):
out.samples, weights)
self.result.logz = out.logz[-1]
self.result.logzerr = out.logzerr[-1]
self.generate_trace_plots(out)
if self.plot:
self.generate_trace_plots(out)
return self.result
def generate_trace_plots(self, dynesty_results):
filename = '{}/{}_trace.png'.format(self.outdir, self.label)
logging.info("Writing trace plot to {}".format(filename))
from dynesty import plotting as dyplot
fig, axes = dyplot.traceplot(dynesty_results,
labels=self.result.parameter_labels)
fig.tight_layout()
fig.savefig('{}/{}_trace.png'.format(self.outdir, self.label))
fig.savefig(filename)
def _run_test(self):
dynesty = self.external_sampler
......@@ -444,7 +451,7 @@ class Ptemcee(Sampler):
def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler='nestle', use_ratio=True, injection_parameters=None,
conversion_function=None, **kwargs):
conversion_function=None, plot=False, **kwargs):
"""
The primary interface to easy parameter estimation
......@@ -469,7 +476,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
injection_parameters: dict
A dictionary of injection parameters used in creating the data (if
using simulated data). Appended to the result object and saved.
plot: bool
If true, generate a corner plot and, if applicable diagnostic plots
conversion_function: function, optional
Function to apply to posterior to generate additional parameters.
**kwargs:
......@@ -492,7 +500,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
if implemented_samplers.__contains__(sampler.title()):
sampler_class = globals()[sampler.title()]
sampler = sampler_class(likelihood, priors, sampler, outdir=outdir,
label=label, use_ratio=use_ratio,
label=label, use_ratio=use_ratio, plot=plot,
**kwargs)
if sampler.cached_result:
......@@ -519,6 +527,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
result.samples_to_data_frame(likelihood=likelihood, priors=priors, conversion_function=conversion_function)
result.kwargs = sampler.kwargs
result.save_to_file(outdir=outdir, label=label)
if plot:
result.plot_corner()
return result
else:
raise ValueError(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment