Skip to content
Snippets Groups Projects

Move to corner

Merged Gregory Ashton requested to merge move-to-corner into master
3 files
+ 56
75
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 47
74
@@ -3,14 +3,7 @@ import os
import numpy as np
import deepdish
import pandas as pd
try:
from chainconsumer import ChainConsumer
except ImportError:
def ChainConsumer():
logging.warning(
"You do not have the optional module chainconsumer installed"
" unable to generate a corner plot")
import corner
def result_file_name(outdir, label):
@@ -102,28 +95,41 @@ class Result(dict):
.format(k))
return return_list
def plot_corner(self, save=True, **kwargs):
""" Plot a corner-plot using chain-consumer
def plot_corner(self, parameters=None, save=True, dpi=300, **kwargs):
""" Plot a corner-plot using corner
See https://corner.readthedocs.io/en/latest/ for a detailed API.
Parameters
----------
parameters: list
If given, a list of the parameter names to include
save: bool
If true, save the image using the given label and outdir
**kwargs:
Other keyword arguments are passed to `corner.corner`. We set some
defaults to improve the basic look and feel, but these can all be
overridden.
Returns
-------
fig:
A matplotlib figure instance
"""
# Set some defaults (unless already set)
kwargs['figsize'] = kwargs.get('figsize', 'GROW')
if save:
filename = '{}/{}_corner.png'.format(self.outdir, self.label)
kwargs['filename'] = kwargs.get('filename', filename)
logging.info('Saving corner plot to {}'.format(kwargs['filename']))
defaults_kwargs = dict(
bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
title_kwargs=dict(fontsize=16), color='#0072C1',
truth_color='tab:orange', show_titles=True,
quantiles=[0.025, 0.975], levels=(0.39,0.8,0.97),
plot_density=False, plot_datapoints=True, fill_contours=True,
max_n_ticks=3)
defaults_kwargs.update(kwargs)
kwargs = defaults_kwargs
if getattr(self, 'injection_parameters', None) is not None:
# If no truth argument given, set these to the injection params
injection_parameters = [self.injection_parameters[key]
for key in self.search_parameter_keys]
kwargs['truth'] = kwargs.get('truth', injection_parameters)
@@ -133,72 +139,39 @@ class Result(dict):
new_keys = self.get_latex_labels_from_parameter_keys(old_keys)
for old, new in zip(old_keys, new_keys):
kwargs['truth'][new] = kwargs['truth'].pop(old)
if 'parameters' in kwargs:
kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
kwargs['parameters'])
# Check all parameter_labels are a valid string
for i, label in enumerate(self.parameter_labels):
if label is None:
self.parameter_labels[i] = 'Unknown'
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels,
name=self.label)
fig = c.plotter.plot(**kwargs)
return fig
def plot_walks(self, save=True, **kwargs):
""" Plot the chain walks using chain-consumer
if 'truth' in kwargs:
kwargs['truths'] = kwargs.pop('truth')
Parameters
----------
save: bool
If true, save the image using the given label and outdir
if parameters:
xs = self.posterior[parameters].values
kwargs['labels'] = kwargs.get(
'labels', self.get_latex_labels_from_parameter_keys(
parameters))
else:
xs = self.posterior[self.search_parameter_keys]
kwargs['labels'] = kwargs.get(
'labels', self.get_latex_labels_from_parameter_keys(
self.search_parameter_keys))
Returns
-------
fig:
A matplotlib figure instance
"""
fig = corner.corner(xs, **kwargs)
# Set some defaults (unless already set)
if save:
kwargs['filename'] = '{}/{}_walks.png'.format(self.outdir, self.label)
logging.info('Saving walker plot to {}'.format(kwargs['filename']))
if getattr(self, 'injection_parameters', None) is not None:
kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels)
fig = c.plotter.plot_walks(**kwargs)
return fig
filename = '{}/{}_corner.png'.format(self.outdir, self.label)
logging.info('Saving corner plot to {}'.format(filename))
fig.savefig(filename, dpi=dpi)
def plot_distributions(self, save=True, **kwargs):
""" Plot the chain walks using chain-consumer
return fig
Parameters
----------
save: bool
If true, save the image using the given label and outdir
Returns
-------
fig:
A matplotlib figure instance
def plot_walks(self, save=True, **kwargs):
"""
"""
logging.warning("plot_walks deprecated")
# Set some defaults (unless already set)
if save:
kwargs['filename'] = '{}/{}_distributions.png'.format(self.outdir, self.label)
logging.info('Saving distributions plot to {}'.format(kwargs['filename']))
if getattr(self, 'injection_parameters', None) is not None:
kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels)
fig = c.plotter.plot_distributions(**kwargs)
return fig
def plot_distributions(self, save=True, **kwargs):
"""
"""
logging.warning("plot_distributions deprecated")
def write_prior_to_file(self, outdir):
"""
Loading