Skip to content
Snippets Groups Projects
Commit 2a4e9ed8 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Deprecate walks and dist and update corner to use corner

parent 8d52c679
No related branches found
No related tags found
1 merge request!46Move to corner
......@@ -3,14 +3,7 @@ import os
import numpy as np
import deepdish
import pandas as pd
try:
from chainconsumer import ChainConsumer
except ImportError:
def ChainConsumer():
logging.warning(
"You do not have the optional module chainconsumer installed"
" unable to generate a corner plot")
import corner
def result_file_name(outdir, label):
......@@ -102,11 +95,13 @@ class Result(dict):
.format(k))
return return_list
def plot_corner(self, save=True, **kwargs):
def plot_corner(self, parameters=None, save=True, **kwargs):
""" Plot a corner-plot using chain-consumer
Parameters
----------
parameters: list
If given, a list of the parameter names to include
save: bool
If true, save the image using the given label and outdir
......@@ -116,14 +111,7 @@ class Result(dict):
A matplotlib figure instance
"""
# Set some defaults (unless already set)
kwargs['figsize'] = kwargs.get('figsize', 'GROW')
if save:
filename = '{}/{}_corner.png'.format(self.outdir, self.label)
kwargs['filename'] = kwargs.get('filename', filename)
logging.info('Saving corner plot to {}'.format(kwargs['filename']))
if getattr(self, 'injection_parameters', None) is not None:
# If no truth argument given, set these to the injection params
injection_parameters = [self.injection_parameters[key]
for key in self.search_parameter_keys]
kwargs['truth'] = kwargs.get('truth', injection_parameters)
......@@ -133,72 +121,39 @@ class Result(dict):
new_keys = self.get_latex_labels_from_parameter_keys(old_keys)
for old, new in zip(old_keys, new_keys):
kwargs['truth'][new] = kwargs['truth'].pop(old)
if 'parameters' in kwargs:
kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
kwargs['parameters'])
# Check all parameter_labels are a valid string
for i, label in enumerate(self.parameter_labels):
if label is None:
self.parameter_labels[i] = 'Unknown'
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels,
name=self.label)
fig = c.plotter.plot(**kwargs)
return fig
def plot_walks(self, save=True, **kwargs):
""" Plot the chain walks using chain-consumer
if 'truth' in kwargs:
kwargs['truths'] = kwargs.pop('truth')
Parameters
----------
save: bool
If true, save the image using the given label and outdir
if parameters:
xs = self.posterior[parameters].values
kwargs['labels'] = kwargs.get(
'labels', self.get_latex_labels_from_parameter_keys(
parameters))
else:
xs = self.posterior[self.search_parameter_keys]
kwargs['labels'] = kwargs.get(
'labels', self.get_latex_labels_from_parameter_keys(
self.search_parameter_keys))
Returns
-------
fig:
A matplotlib figure instance
"""
fig = corner.corner(xs, **kwargs)
# Set some defaults (unless already set)
if save:
kwargs['filename'] = '{}/{}_walks.png'.format(self.outdir, self.label)
logging.info('Saving walker plot to {}'.format(kwargs['filename']))
if getattr(self, 'injection_parameters', None) is not None:
kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels)
fig = c.plotter.plot_walks(**kwargs)
return fig
def plot_distributions(self, save=True, **kwargs):
""" Plot the chain walks using chain-consumer
filename = '{}/{}_corner.png'.format(self.outdir, self.label)
logging.info('Saving corner plot to {}'.format(filename))
fig.savefig(filename)
Parameters
----------
save: bool
If true, save the image using the given label and outdir
return fig
Returns
-------
fig:
A matplotlib figure instance
def plot_walks(self, save=True, **kwargs):
"""
"""
logging.warning("plot_walks deprecated")
# Set some defaults (unless already set)
if save:
kwargs['filename'] = '{}/{}_distributions.png'.format(self.outdir, self.label)
logging.info('Saving distributions plot to {}'.format(kwargs['filename']))
if getattr(self, 'injection_parameters', None) is not None:
kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels)
fig = c.plotter.plot_distributions(**kwargs)
return fig
def plot_distributions(self, save=True, **kwargs):
"""
"""
logging.warning("plot_distributions deprecated")
def write_prior_to_file(self, outdir):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment