diff --git a/tupak/sampler.py b/tupak/sampler.py index e47cecbfc6e1c57c325143631c98a10fc477d31d..0b6c0d316722affc0552513facb204a9e00e81db 100644 --- a/tupak/sampler.py +++ b/tupak/sampler.py @@ -36,8 +36,10 @@ class Sampler(object): """ - def __init__(self, likelihood, priors, external_sampler='nestle', outdir='outdir', label='label', use_ratio=False, - **kwargs): + def __init__( + self, likelihood, priors, external_sampler='nestle', + outdir='outdir', label='label', use_ratio=False, plot=False, + **kwargs): self.likelihood = likelihood self.priors = priors self.label = label @@ -45,6 +47,7 @@ class Sampler(object): self.use_ratio = use_ratio self.external_sampler = external_sampler self.external_sampler_function = None + self.plot = plot self.__search_parameter_keys = [] self.__fixed_parameter_keys = [] @@ -327,15 +330,19 @@ class Dynesty(Sampler): out.samples, weights) self.result.logz = out.logz[-1] self.result.logzerr = out.logzerr[-1] - self.generate_trace_plots(out) + + if self.plot: + self.generate_trace_plots(out) return self.result def generate_trace_plots(self, dynesty_results): + filename = '{}/{}_trace.png'.format(self.outdir, self.label) + logging.info("Writing trace plot to {}".format(filename)) from dynesty import plotting as dyplot fig, axes = dyplot.traceplot(dynesty_results, labels=self.result.parameter_labels) fig.tight_layout() - fig.savefig('{}/{}_trace.png'.format(self.outdir, self.label)) + fig.savefig(filename) def _run_test(self): dynesty = self.external_sampler @@ -444,7 +451,7 @@ class Ptemcee(Sampler): def run_sampler(likelihood, priors=None, label='label', outdir='outdir', sampler='nestle', use_ratio=True, injection_parameters=None, - conversion_function=None, **kwargs): + conversion_function=None, plot=False, **kwargs): """ The primary interface to easy parameter estimation @@ -469,7 +476,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', injection_parameters: dict A dictionary of injection parameters used in creating the data (if using simulated data). Appended to the result object and saved. - + plot: bool + If true, generate a corner plot and, if applicable diagnostic plots conversion_function: function, optional Function to apply to posterior to generate additional parameters. **kwargs: @@ -492,7 +500,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', if implemented_samplers.__contains__(sampler.title()): sampler_class = globals()[sampler.title()] sampler = sampler_class(likelihood, priors, sampler, outdir=outdir, - label=label, use_ratio=use_ratio, + label=label, use_ratio=use_ratio, plot=plot, **kwargs) if sampler.cached_result: @@ -519,6 +527,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', result.samples_to_data_frame(likelihood=likelihood, priors=priors, conversion_function=conversion_function) result.kwargs = sampler.kwargs result.save_to_file(outdir=outdir, label=label) + if plot: + result.plot_corner() return result else: raise ValueError(