Commit 2a4e9ed8 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Deprecate walks and dist and update corner to use corner

parent 8d52c679
...@@ -3,14 +3,7 @@ import os ...@@ -3,14 +3,7 @@ import os
import numpy as np import numpy as np
import deepdish import deepdish
import pandas as pd import pandas as pd
import corner
try:
from chainconsumer import ChainConsumer
except ImportError:
def ChainConsumer():
logging.warning(
"You do not have the optional module chainconsumer installed"
" unable to generate a corner plot")
def result_file_name(outdir, label): def result_file_name(outdir, label):
...@@ -102,11 +95,13 @@ class Result(dict): ...@@ -102,11 +95,13 @@ class Result(dict):
.format(k)) .format(k))
return return_list return return_list
def plot_corner(self, save=True, **kwargs): def plot_corner(self, parameters=None, save=True, **kwargs):
""" Plot a corner-plot using chain-consumer """ Plot a corner-plot using chain-consumer
Parameters Parameters
---------- ----------
parameters: list
If given, a list of the parameter names to include
save: bool save: bool
If true, save the image using the given label and outdir If true, save the image using the given label and outdir
...@@ -116,14 +111,7 @@ class Result(dict): ...@@ -116,14 +111,7 @@ class Result(dict):
A matplotlib figure instance A matplotlib figure instance
""" """
# Set some defaults (unless already set)
kwargs['figsize'] = kwargs.get('figsize', 'GROW')
if save:
filename = '{}/{}_corner.png'.format(self.outdir, self.label)
kwargs['filename'] = kwargs.get('filename', filename)
logging.info('Saving corner plot to {}'.format(kwargs['filename']))
if getattr(self, 'injection_parameters', None) is not None: if getattr(self, 'injection_parameters', None) is not None:
# If no truth argument given, set these to the injection params
injection_parameters = [self.injection_parameters[key] injection_parameters = [self.injection_parameters[key]
for key in self.search_parameter_keys] for key in self.search_parameter_keys]
kwargs['truth'] = kwargs.get('truth', injection_parameters) kwargs['truth'] = kwargs.get('truth', injection_parameters)
...@@ -133,72 +121,39 @@ class Result(dict): ...@@ -133,72 +121,39 @@ class Result(dict):
new_keys = self.get_latex_labels_from_parameter_keys(old_keys) new_keys = self.get_latex_labels_from_parameter_keys(old_keys)
for old, new in zip(old_keys, new_keys): for old, new in zip(old_keys, new_keys):
kwargs['truth'][new] = kwargs['truth'].pop(old) kwargs['truth'][new] = kwargs['truth'].pop(old)
if 'parameters' in kwargs:
kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
kwargs['parameters'])
# Check all parameter_labels are a valid string
for i, label in enumerate(self.parameter_labels):
if label is None:
self.parameter_labels[i] = 'Unknown'
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels,
name=self.label)
fig = c.plotter.plot(**kwargs)
return fig
def plot_walks(self, save=True, **kwargs): if 'truth' in kwargs:
""" Plot the chain walks using chain-consumer kwargs['truths'] = kwargs.pop('truth')
Parameters if parameters:
---------- xs = self.posterior[parameters].values
save: bool kwargs['labels'] = kwargs.get(
If true, save the image using the given label and outdir 'labels', self.get_latex_labels_from_parameter_keys(
parameters))
else:
xs = self.posterior[self.search_parameter_keys]
kwargs['labels'] = kwargs.get(
'labels', self.get_latex_labels_from_parameter_keys(
self.search_parameter_keys))
Returns fig = corner.corner(xs, **kwargs)
-------
fig:
A matplotlib figure instance
"""
# Set some defaults (unless already set)
if save: if save:
kwargs['filename'] = '{}/{}_walks.png'.format(self.outdir, self.label) filename = '{}/{}_corner.png'.format(self.outdir, self.label)
logging.info('Saving walker plot to {}'.format(kwargs['filename'])) logging.info('Saving corner plot to {}'.format(filename))
if getattr(self, 'injection_parameters', None) is not None: fig.savefig(filename)
kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels)
fig = c.plotter.plot_walks(**kwargs)
return fig
def plot_distributions(self, save=True, **kwargs):
""" Plot the chain walks using chain-consumer
Parameters return fig
----------
save: bool
If true, save the image using the given label and outdir
Returns def plot_walks(self, save=True, **kwargs):
-------
fig:
A matplotlib figure instance
""" """
"""
logging.warning("plot_walks deprecated")
# Set some defaults (unless already set) def plot_distributions(self, save=True, **kwargs):
if save: """
kwargs['filename'] = '{}/{}_distributions.png'.format(self.outdir, self.label) """
logging.info('Saving distributions plot to {}'.format(kwargs['filename'])) logging.warning("plot_distributions deprecated")
if getattr(self, 'injection_parameters', None) is not None:
kwargs['truth'] = [self.injection_parameters[key] for key in self.search_parameter_keys]
c = ChainConsumer()
if c:
c.add_chain(self.samples, parameters=self.parameter_labels)
fig = c.plotter.plot_distributions(**kwargs)
return fig
def write_prior_to_file(self, outdir): def write_prior_to_file(self, outdir):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment