diff --git a/tupak/core/result.py b/tupak/core/result.py index ea6a9353c0c45bb00aa45b165bcd317b0d5cbad0..f2b296a02be0d61d0088187f95e9ff0b816309d2 100644 --- a/tupak/core/result.py +++ b/tupak/core/result.py @@ -4,6 +4,7 @@ import numpy as np import deepdish import pandas as pd import corner +import matplotlib def result_file_name(outdir, label): @@ -320,3 +321,69 @@ class Result(dict): elif typeA in [np.ndarray]: return np.all(A == B) return False + + +def plot_multiple(results, filename=None, labels=None, colours=None, + save=True, **kwargs): + """ Generate a corner plot overlaying two sets of results + + Parameters + ---------- + results: list + A list of `tupak.core.result.Result` objects containing the samples to + plot. + filename: str + File name to save the figure to. If None (default), a filename is + constructed from the outdir of the first element of results and then + the labels for all the result files. + labels: list + List of strings to use when generating a legend. If None (default), the + `label` attribute of each result in `results` is used. + colours: list + The colours for each result. If None, default styles are applied. + save: bool + If true, save the figure + kwargs: dict + All other keyword arguments are passed to `result.plot_corner`. + However, `show_titles` and `truths` are ignored since they would be + ambiguous on such a plot. + + Returns + ------- + fig: + A matplotlib figure instance + + """ + + kwargs['show_titles'] = False + kwargs['truths'] = None + + fig = results[0].plot_corner(save=False, **kwargs) + default_filename = '{}/{}'.format(results[0].outdir, results[0].label) + lines = [] + default_labels = [] + for i, result in enumerate(results): + if colours: + c = colours[i] + else: + c = 'C{}'.format(i) + fig = result.plot_corner(fig=fig, save=False, color=c, **kwargs) + default_filename += '_{}'.format(result.label) + lines.append(matplotlib.lines.Line2D([0], [0], color=c)) + default_labels.append(result.label) + + if labels is None: + labels = default_labels + + axes = fig.get_axes() + ndim = int(np.sqrt(len(axes))) + axes[ndim-1].legend(lines, labels) + + if filename is None: + filename = default_filename + + if save: + fig.savefig(filename) + return fig + +