diff --git a/bilby/core/result.py b/bilby/core/result.py index a4630475fad32a85d70dd3d0becc946d948c8e87..968c8df724b25a63f4666a470ebc3fe17fbdaa5a 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -3,6 +3,7 @@ from __future__ import division import os from distutils.version import LooseVersion from collections import OrderedDict, namedtuple +from itertools import product import numpy as np import pandas as pd @@ -1051,20 +1052,28 @@ class Result(object): self.prior_values[key]\ = priors[key].prob(self.posterior[key].values) - def get_all_injection_credible_levels(self): + def get_all_injection_credible_levels(self, keys=None): """ - Get credible levels for all parameters in self.injection_parameters + Get credible levels for all parameters + + Parameters + ---------- + keys: list, optional + A list of keys for which return the credible levels, if None, + defaults to search_parameter_keys Returns ------- credible_levels: dict The credible levels at which the injected parameters are found. """ + if keys is None: + keys = self.search_parameter_keys if self.injection_parameters is None: raise(TypeError, "Result object has no 'injection_parameters'. " - "Cannot copmute credible levels.") + "Cannot compute credible levels.") credible_levels = {key: self.get_injection_credible_level(key) - for key in self.search_parameter_keys + for key in keys if isinstance(self.injection_parameters[key], float)} return credible_levels @@ -1259,7 +1268,8 @@ def plot_multiple(results, filename=None, labels=None, colours=None, return fig -def make_pp_plot(results, filename=None, save=True, **kwargs): +def make_pp_plot(results, filename=None, save=True, confidence_interval=0.9, + lines=None, legend_fontsize=9, keys=None, **kwargs): """ Make a P-P plot for a set of runs with injected signals. @@ -1271,6 +1281,15 @@ def make_pp_plot(results, filename=None, save=True, **kwargs): The name of the file to save, the default is "outdir/pp.png" save: bool, optional Whether to save the file, default=True + confidence_interval: float, optional + The confidence interval to be plotted, defaulting to 0.9 (90%) + lines: list + If given, a list of matplotlib line formats to use, must be greater + than the number of parameters. + legend_fontsize: float + The font size for the legend + keys: list + A list of keys to use, if None defaults to search_parameter_keys kwargs: Additional kwargs to pass to matplotlib.pyplot.plot @@ -1279,25 +1298,47 @@ def make_pp_plot(results, filename=None, save=True, **kwargs): fig: matplotlib figure """ - fig = plt.figure() + credible_levels = pd.DataFrame() for result in results: credible_levels = credible_levels.append( - result.get_all_injection_credible_levels(), ignore_index=True) - n_parameters = len(credible_levels.keys()) - x_values = np.linspace(0, 1, 101) - for key in credible_levels: - plt.plot(x_values, [sum(credible_levels[key].values < xx) / - len(credible_levels) for xx in x_values], - color='k', alpha=min([1, 4 / n_parameters]), **kwargs) - plt.plot([0, 1], [0, 1], linestyle='--', color='r') - plt.xlim(0, 1) - plt.ylim(0, 1) - plt.tight_layout() + result.get_all_injection_credible_levels(keys), ignore_index=True) + + if lines is None: + colors = ["C{}".format(i) for i in range(8)] + linestyles = ["-", "--", ":"] + lines = ["{}{}".format(a, b) for a, b in product(linestyles, colors)] + if len(lines) < len(credible_levels.keys()): + raise ValueError("Larger number of parameters than unique linestyles") + + x_values = np.linspace(0, 1, 1001) + + # Putting in the confidence bands + N = len(credible_levels) + edge_of_bound = (1. - confidence_interval) / 2. + lower = scipy.stats.binom.ppf(1 - edge_of_bound, N, x_values) / N + upper = scipy.stats.binom.ppf(edge_of_bound, N, x_values) / N + # The binomial point percent function doesn't always return 0 @ 0, + # so set those bounds explicitly to be sure + lower[0] = 0 + upper[0] = 0 + fig, ax = plt.subplots() + + ax.fill_between(x_values, lower, upper, alpha=0.2, color='k') + + for ii, key in enumerate(credible_levels): + pp = np.array([sum(credible_levels[key].values < xx) / + len(credible_levels) for xx in x_values]) + plt.plot(x_values, pp, lines[ii], label=key, **kwargs) + + ax.legend(fontsize=legend_fontsize) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + fig.tight_layout() if save: if filename is None: filename = 'outdir/pp.png' - plt.savefig(filename) + fig.savefig(filename, dpi=500) return fig