Commit d0abea42 authored by Gregory Ashton's avatar Gregory Ashton

Adds a new function to enable quick plotting of 1D marginals

- Also make the `get_one_dimensional_median_and_error_bar return a
  namedtuple which is more generally useful
parent 6e2d238c
Pipeline #34513 passed with stage
in 8 minutes and 50 seconds
......@@ -6,7 +6,7 @@ import pandas as pd
import corner
import matplotlib
import matplotlib.pyplot as plt
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from . import utils
from .utils import logger, infer_parameters_from_function
......@@ -285,22 +285,112 @@ class Result(dict):
string: str
A string of latex-formatted text of the mean and 1-sigma quantiles
summary: namedtuple
An object with attributes, median, lower, upper and string
summary = namedtuple('summary', ['median', 'lower', 'upper', 'string'])
if len(quantiles) != 2:
raise ValueError("quantiles must be of length 2")
quants_to_compute = np.array([quantiles[0], 0.5, quantiles[1]])
quants = np.percentile(self.posterior[key], quants_to_compute * 100)
median = quants[1]
upper = quants[2] - median
lower = median - quants[0]
summary.median = quants[1] = quants[2] - summary.median
summary.minus = summary.median - quants[0]
fmt = "{{0:{0}}}".format(fmt).format
string = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
return string.format(fmt(median), fmt(lower), fmt(upper))
string_template = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
summary.string = string_template.format(
fmt(summary.median), fmt(summary.minus), fmt(
return summary
def plot_marginals(self, parameters=None, priors=None, titles=True,
file_base_name=None, bins=50, label_fontsize=16,
title_fontsize=16, quantiles=[0.16, 0.84], dpi=300):
""" Plot 1D marginal distributions
parameters: (list, dict), optional
If given, either a list of the parameter names to include, or a
dictionary of parameter names and their "true" values to plot.
priors: {bool (False), bilby.core.prior.PriorSet}
If true, add the stored prior probability density functions to the
one-dimensional marginal distributions. If instead a PriorSet
is provided, this will be plotted.
titles: bool
If true, add 1D titles of the median and (by default 1-sigma)
error bars. To change the error bars, pass in the quantiles kwarg.
See method `get_one_dimensional_median_and_error_bar` for further
details). If `quantiles=None` is passed in, no title is added.
file_base_name: str, optional
If given, the base file name to use (by default `outdir/label_` is
bins: int
The number of histogram bins
label_fontsize, title_fontsize: int
The fontsizes for the labels and titles
quantiles: list
A length-2 list of the lower and upper-quantiles to calculate
the errors bars for.
dpi: int
Dots per inch resolution of the plot
figures: dictionary
A dictionary of the matplotlib figures
if isinstance(parameters, dict):
plot_parameter_keys = list(parameters.keys())
truths = list(parameters.values())
elif parameters is None:
plot_parameter_keys = self.search_parameter_keys
truths = None
plot_parameter_keys = list(parameters)
truths = None
labels = self.get_latex_labels_from_parameter_keys(plot_parameter_keys)
if file_base_name is None:
file_base_name = '{}/{}_'.format(self.outdir, self.label)
if priors is True:
priors = getattr(self, 'priors', False)
elif isinstance(priors, (dict)) or priors in [False, None]:
raise ValueError('Input priors={} not understood'.format(priors))
figures = dict()
for i, key in enumerate(plot_parameter_keys):
fig, ax = plt.subplots()
ax.hist(self.posterior[key].values, bins=bins, density=True,
ax.set_xlabel(labels[i], fontsize=label_fontsize)
if truths is not None:
ax.axvline(truths[i], ls='--', color='orange')
summary = self.get_one_dimensional_median_and_error_bar(
key, quantiles=quantiles)
ax.axvline(summary.median - summary.minus, ls='--', color='C0')
ax.axvline(summary.median +, ls='--', color='C0')
if titles:
ax.set_title(summary.string, fontsize=title_fontsize)
if isinstance(priors, dict):
theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300)
ax.plot(theta, priors[key].prob(theta), color='C2')
fig.savefig(file_base_name + key)
figures[key] = fig
return figures
def plot_corner(self, parameters=None, priors=None, titles=True, save=True,
filename=None, dpi=300, **kwargs):
......@@ -418,7 +508,7 @@ class Result(dict):
ax = axes[i + i * len(plot_parameter_keys)]
if ax.title.get_text() == '':
par, quantiles=kwargs['quantiles']),
par, quantiles=kwargs['quantiles']).string,
# Add priors to the 1D plots
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment