Commit adb8484c authored by Gregory Ashton's avatar Gregory Ashton Committed by Colm Talbot

General tidy and refactor of plot_corner

parent 7a1d4bc6
......@@ -295,16 +295,15 @@ class Result(dict):
string = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
return string.format(fmt(median), fmt(lower), fmt(upper))
def plot_corner(self, parameters=None, priors=False, titles=True,
save=True, filename=None, dpi=300, **kwargs):
""" Plot a corner-plot using corner
See https://corner.readthedocs.io/en/latest/ for a detailed API.
def plot_corner(self, parameters=None, priors=None, titles=True, save=True,
filename=None, dpi=300, **kwargs):
""" Plot a corner-plot
Parameters
----------
parameters: list, optional
If given, a list of the parameter names to include
parameters: (list, dict), optional
If given, either a list of the parameter names to include, or a
dictionary of parameter names and their "true" values to plot.
priors: {bool (False), tupak.core.prior.PriorSet}
If true, add the stored prior probability density functions to the
one-dimensional marginal distributions. If instead a PriorSet
......@@ -325,20 +324,28 @@ class Result(dict):
defaults to improve the basic look and feel, but these can all be
overridden.
Notes
-----
The generation of the corner plot themselves is done by the corner
python module, see https://corner.readthedocs.io for more
information.
Returns
-------
fig:
A matplotlib figure instance
"""
# If in testing mode, not corner plots are generated
if utils.command_line_args.test:
return
# tupak default corner kwargs. Overwritten by anything passed to kwargs
defaults_kwargs = dict(
bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
title_kwargs=dict(fontsize=16), color='#0072C1',
truth_color='tab:orange',
quantiles=[0.16, 0.84],
truth_color='tab:orange', quantiles=[0.16, 0.84],
levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)),
plot_density=False, plot_datapoints=True, fill_contours=True,
max_n_ticks=3)
......@@ -348,41 +355,58 @@ class Result(dict):
else:
defaults_kwargs['hist_kwargs'] = dict(density=True)
if 'lionize' in kwargs and kwargs['lionize'] is True:
defaults_kwargs['truth_color'] = 'tab:blue'
defaults_kwargs['color'] = '#FF8C00'
defaults_kwargs.update(kwargs)
kwargs = defaults_kwargs
if 'truth' in kwargs:
kwargs['truths'] = kwargs.pop('truth')
# If injection parameters where stored, use these as truth values
if getattr(self, 'injection_parameters', None) is not None:
injection_parameters = [self.injection_parameters.get(key, None)
for key in self.search_parameter_keys]
kwargs['truths'] = kwargs.get('truths', injection_parameters)
if parameters is None:
parameters = self.search_parameter_keys
if 'lionize' in kwargs and kwargs['lionize'] is True:
defaults_kwargs['truth_color'] = 'tab:blue'
defaults_kwargs['color'] = '#FF8C00'
# Handle if truths was passed in
if 'truth' in kwargs:
kwargs['truths'] = kwargs.pop('truth')
if kwargs.get('truths'):
truths = kwargs.get('truths')
if isinstance(parameters, list) and isinstance(truths, list):
if len(parameters) != len(truths):
raise ValueError(
"Length of parameters and truths don't match")
elif isinstance(truths, dict) and parameters is None:
parameters = kwargs.pop('truths')
else:
raise ValueError(
"Combination of parameters and truths not understood")
# If parameters is a dictionary, use the keys to determine which
# parameters to plot and the values as truths.
if isinstance(parameters, dict):
plot_parameter_keys = list(parameters.keys())
kwargs['truths'] = list(parameters.values())
elif parameters is None:
plot_parameter_keys = self.search_parameter_keys
else:
plot_parameter_keys = list(parameters)
xs = self.posterior[parameters].values
# Get latex formatted strings for the plot labels
kwargs['labels'] = kwargs.get(
'labels', self.get_latex_labels_from_parameter_keys(
parameters))
if type(kwargs.get('truths')) == dict:
truths = [kwargs['truths'][k] for k in parameters]
kwargs['truths'] = truths
plot_parameter_keys))
# Create the data array to plot and pass everything to corner
xs = self.posterior[plot_parameter_keys].values
fig = corner.corner(xs, **kwargs)
axes = fig.get_axes()
# Add the titles
if titles and kwargs.get('quantiles', None) is not None:
for i, par in enumerate(parameters):
ax = axes[i + i * len(parameters)]
for i, par in enumerate(plot_parameter_keys):
ax = axes[i + i * len(plot_parameter_keys)]
if ax.title.get_text() == '':
ax.set_title(self.get_one_dimensional_median_and_error_bar(
par, quantiles=kwargs['quantiles']),
......@@ -392,8 +416,8 @@ class Result(dict):
if priors is True:
priors = getattr(self, 'priors', False)
if isinstance(priors, dict):
for i, par in enumerate(parameters):
ax = axes[i + i * len(parameters)]
for i, par in enumerate(plot_parameter_keys):
ax = axes[i + i * len(plot_parameter_keys)]
theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300)
ax.plot(theta, priors[par].prob(theta), color='C2')
elif priors in [False, None]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment