diff --git a/requirements.txt b/requirements.txt index f71732b9408956a5cac6709a0079630222d20fde..fe8a629b9a049938e04c73d9c49034b19bffc522 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ future dynesty corner -numpy +numpy>=1.9 matplotlib>=2.0 scipy gwpy diff --git a/tupak/result.py b/tupak/result.py index 07023b03cace71d61d721e7fffcd2817e547c1b0..f54f9dfaa78a7abe21157dbc900c0e43726b8429 100644 --- a/tupak/result.py +++ b/tupak/result.py @@ -3,14 +3,7 @@ import os import numpy as np import deepdish import pandas as pd - -try: - from chainconsumer import ChainConsumer -except ImportError: - def ChainConsumer(): - logging.warning( - "You do not have the optional module chainconsumer installed" - " unable to generate a corner plot") +import corner def result_file_name(outdir, label): @@ -102,28 +95,41 @@ class Result(dict): .format(k)) return return_list - def plot_corner(self, save=True, **kwargs): - """ Plot a corner-plot using chain-consumer + def plot_corner(self, parameters=None, save=True, dpi=300, **kwargs): + """ Plot a corner-plot using corner + + See https://corner.readthedocs.io/en/latest/ for a detailed API. Parameters ---------- + parameters: list + If given, a list of the parameter names to include save: bool If true, save the image using the given label and outdir + **kwargs: + Other keyword arguments are passed to `corner.corner`. We set some + defaults to improve the basic look and feel, but these can all be + overridden. Returns ------- fig: A matplotlib figure instance + """ - # Set some defaults (unless already set) - kwargs['figsize'] = kwargs.get('figsize', 'GROW') - if save: - filename = '{}/{}_corner.png'.format(self.outdir, self.label) - kwargs['filename'] = kwargs.get('filename', filename) - logging.info('Saving corner plot to {}'.format(kwargs['filename'])) + defaults_kwargs = dict( + bins=50, smooth=0.9, label_kwargs=dict(fontsize=16), + title_kwargs=dict(fontsize=16), color='#0072C1', + truth_color='tab:orange', show_titles=True, + quantiles=[0.025, 0.975], levels=(0.39,0.8,0.97), + plot_density=False, plot_datapoints=True, fill_contours=True, + max_n_ticks=3) + + defaults_kwargs.update(kwargs) + kwargs = defaults_kwargs + if getattr(self, 'injection_parameters', None) is not None: - # If no truth argument given, set these to the injection params injection_parameters = [self.injection_parameters[key] for key in self.search_parameter_keys] kwargs['truth'] = kwargs.get('truth', injection_parameters) @@ -133,72 +139,39 @@ class Result(dict): new_keys = self.get_latex_labels_from_parameter_keys(old_keys) for old, new in zip(old_keys, new_keys): kwargs['truth'][new] = kwargs['truth'].pop(old) - if 'parameters' in kwargs: - kwargs['parameters'] = self.get_latex_labels_from_parameter_keys( - kwargs['parameters']) - - # Check all parameter_labels are a valid string - for i, label in enumerate(self.parameter_labels): - if label is None: - self.parameter_labels[i] = 'Unknown' - c = ChainConsumer() - if c: - c.add_chain(self.samples, parameters=self.parameter_labels, - name=self.label) - fig = c.plotter.plot(**kwargs) - return fig - def plot_walks(self, save=True, **kwargs): - """ Plot the chain walks using chain-consumer + if 'truth' in kwargs: + kwargs['truths'] = kwargs.pop('truth') - Parameters - ---------- - save: bool - If true, save the image using the given label and outdir + if parameters: + xs = self.posterior[parameters].values + kwargs['labels'] = kwargs.get( + 'labels', self.get_latex_labels_from_parameter_keys( + parameters)) + else: + xs = self.posterior[self.search_parameter_keys] + kwargs['labels'] = kwargs.get( + 'labels', self.get_latex_labels_from_parameter_keys( + self.search_parameter_keys)) - Returns - ------- - fig: - A matplotlib figure instance - """ + fig = corner.corner(xs, **kwargs) - # Set some defaults (unless already set) if save: - kwargs['filename'] = '{}/{}_walks.png'.format(self.outdir, self.label) - logging.info('Saving walker plot to {}'.format(kwargs['filename'])) - if getattr(self, 'injection_parameters', None) is not None: - kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys] - c = ChainConsumer() - if c: - c.add_chain(self.samples, parameters=self.parameter_labels) - fig = c.plotter.plot_walks(**kwargs) - return fig + filename = '{}/{}_corner.png'.format(self.outdir, self.label) + logging.info('Saving corner plot to {}'.format(filename)) + fig.savefig(filename, dpi=dpi) - def plot_distributions(self, save=True, **kwargs): - """ Plot the chain walks using chain-consumer + return fig - Parameters - ---------- - save: bool - If true, save the image using the given label and outdir - - Returns - ------- - fig: - A matplotlib figure instance + def plot_walks(self, save=True, **kwargs): """ + """ + logging.warning("plot_walks deprecated") - # Set some defaults (unless already set) - if save: - kwargs['filename'] = '{}/{}_distributions.png'.format(self.outdir, self.label) - logging.info('Saving distributions plot to {}'.format(kwargs['filename'])) - if getattr(self, 'injection_parameters', None) is not None: - kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys] - c = ChainConsumer() - if c: - c.add_chain(self.samples, parameters=self.parameter_labels) - fig = c.plotter.plot_distributions(**kwargs) - return fig + def plot_distributions(self, save=True, **kwargs): + """ + """ + logging.warning("plot_distributions deprecated") def write_prior_to_file(self, outdir): """ diff --git a/tupak/sampler.py b/tupak/sampler.py index 16f7a34fe8d2ca72c66f1c67b238911224e03bd6..527f0a811407599fb3dc9a96b8667424f822f777 100644 --- a/tupak/sampler.py +++ b/tupak/sampler.py @@ -325,8 +325,16 @@ class Dynesty(Sampler): out.samples, weights) self.result.logz = out.logz[-1] self.result.logzerr = out.logzerr[-1] + self.generate_trace_plots(out) return self.result + def generate_trace_plots(self, dynesty_results): + from dynesty import plotting as dyplot + fig, axes = dyplot.traceplot(dynesty_results, + labels=self.result.parameter_labels) + fig.tight_layout() + fig.savefig('{}/{}_trace.png'.format(self.outdir, self.label)) + def _run_test(self): dynesty = self.external_sampler nested_sampler = dynesty.NestedSampler(