Skip to content
Snippets Groups Projects
Commit 5fdbc80b authored by Paul Lasky's avatar Paul Lasky
Browse files

Merge branch 'add-multiple-plot' into 'master'

Adds function to generate multiple-result corner plots

Closes #110

See merge request Monash/tupak!77
parents 46f5556d 33778c18
No related branches found
No related tags found
1 merge request!77Adds function to generate multiple-result corner plots
Pipeline #
......@@ -4,6 +4,7 @@ import numpy as np
import deepdish
import pandas as pd
import corner
import matplotlib
def result_file_name(outdir, label):
......@@ -320,3 +321,69 @@ class Result(dict):
elif typeA in [np.ndarray]:
return np.all(A == B)
return False
def plot_multiple(results, filename=None, labels=None, colours=None,
save=True, **kwargs):
""" Generate a corner plot overlaying two sets of results
Parameters
----------
results: list
A list of `tupak.core.result.Result` objects containing the samples to
plot.
filename: str
File name to save the figure to. If None (default), a filename is
constructed from the outdir of the first element of results and then
the labels for all the result files.
labels: list
List of strings to use when generating a legend. If None (default), the
`label` attribute of each result in `results` is used.
colours: list
The colours for each result. If None, default styles are applied.
save: bool
If true, save the figure
kwargs: dict
All other keyword arguments are passed to `result.plot_corner`.
However, `show_titles` and `truths` are ignored since they would be
ambiguous on such a plot.
Returns
-------
fig:
A matplotlib figure instance
"""
kwargs['show_titles'] = False
kwargs['truths'] = None
fig = results[0].plot_corner(save=False, **kwargs)
default_filename = '{}/{}'.format(results[0].outdir, results[0].label)
lines = []
default_labels = []
for i, result in enumerate(results):
if colours:
c = colours[i]
else:
c = 'C{}'.format(i)
fig = result.plot_corner(fig=fig, save=False, color=c, **kwargs)
default_filename += '_{}'.format(result.label)
lines.append(matplotlib.lines.Line2D([0], [0], color=c))
default_labels.append(result.label)
if labels is None:
labels = default_labels
axes = fig.get_axes()
ndim = int(np.sqrt(len(axes)))
axes[ndim-1].legend(lines, labels)
if filename is None:
filename = default_filename
if save:
fig.savefig(filename)
return fig
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment