diff --git a/tupak/core/result.py b/tupak/core/result.py index 3912a5987ac432e24b3b6960e6c10441afc54eec..1a935110e09e734c6dfb87a1b53a24794705f2ff 100644 --- a/tupak/core/result.py +++ b/tupak/core/result.py @@ -295,16 +295,15 @@ class Result(dict): string = r"${{{0}}}_{{-{1}}}^{{+{2}}}$" return string.format(fmt(median), fmt(lower), fmt(upper)) - def plot_corner(self, parameters=None, priors=False, titles=True, - save=True, filename=None, dpi=300, **kwargs): - """ Plot a corner-plot using corner - - See https://corner.readthedocs.io/en/latest/ for a detailed API. + def plot_corner(self, parameters=None, priors=None, titles=True, save=True, + filename=None, dpi=300, **kwargs): + """ Plot a corner-plot Parameters ---------- - parameters: list, optional - If given, a list of the parameter names to include + parameters: (list, dict), optional + If given, either a list of the parameter names to include, or a + dictionary of parameter names and their "true" values to plot. priors: {bool (False), tupak.core.prior.PriorSet} If true, add the stored prior probability density functions to the one-dimensional marginal distributions. If instead a PriorSet @@ -325,20 +324,28 @@ class Result(dict): defaults to improve the basic look and feel, but these can all be overridden. + Notes + ----- + The generation of the corner plot themselves is done by the corner + python module, see https://corner.readthedocs.io for more + information. + Returns ------- fig: A matplotlib figure instance """ + + # If in testing mode, not corner plots are generated if utils.command_line_args.test: return + # tupak default corner kwargs. Overwritten by anything passed to kwargs defaults_kwargs = dict( bins=50, smooth=0.9, label_kwargs=dict(fontsize=16), title_kwargs=dict(fontsize=16), color='#0072C1', - truth_color='tab:orange', - quantiles=[0.16, 0.84], + truth_color='tab:orange', quantiles=[0.16, 0.84], levels=(1 - np.exp(-0.5), 1 - np.exp(-2), 1 - np.exp(-9 / 2.)), plot_density=False, plot_datapoints=True, fill_contours=True, max_n_ticks=3) @@ -348,41 +355,58 @@ class Result(dict): else: defaults_kwargs['hist_kwargs'] = dict(density=True) + if 'lionize' in kwargs and kwargs['lionize'] is True: + defaults_kwargs['truth_color'] = 'tab:blue' + defaults_kwargs['color'] = '#FF8C00' + defaults_kwargs.update(kwargs) kwargs = defaults_kwargs - if 'truth' in kwargs: - kwargs['truths'] = kwargs.pop('truth') - + # If injection parameters where stored, use these as truth values if getattr(self, 'injection_parameters', None) is not None: injection_parameters = [self.injection_parameters.get(key, None) for key in self.search_parameter_keys] kwargs['truths'] = kwargs.get('truths', injection_parameters) - if parameters is None: - parameters = self.search_parameter_keys - - if 'lionize' in kwargs and kwargs['lionize'] is True: - defaults_kwargs['truth_color'] = 'tab:blue' - defaults_kwargs['color'] = '#FF8C00' + # Handle if truths was passed in + if 'truth' in kwargs: + kwargs['truths'] = kwargs.pop('truth') + if kwargs.get('truths'): + truths = kwargs.get('truths') + if isinstance(parameters, list) and isinstance(truths, list): + if len(parameters) != len(truths): + raise ValueError( + "Length of parameters and truths don't match") + elif isinstance(truths, dict) and parameters is None: + parameters = kwargs.pop('truths') + else: + raise ValueError( + "Combination of parameters and truths not understood") + + # If parameters is a dictionary, use the keys to determine which + # parameters to plot and the values as truths. + if isinstance(parameters, dict): + plot_parameter_keys = list(parameters.keys()) + kwargs['truths'] = list(parameters.values()) + elif parameters is None: + plot_parameter_keys = self.search_parameter_keys + else: + plot_parameter_keys = list(parameters) - xs = self.posterior[parameters].values + # Get latex formatted strings for the plot labels kwargs['labels'] = kwargs.get( 'labels', self.get_latex_labels_from_parameter_keys( - parameters)) - - if type(kwargs.get('truths')) == dict: - truths = [kwargs['truths'][k] for k in parameters] - kwargs['truths'] = truths + plot_parameter_keys)) + # Create the data array to plot and pass everything to corner + xs = self.posterior[plot_parameter_keys].values fig = corner.corner(xs, **kwargs) - axes = fig.get_axes() # Add the titles if titles and kwargs.get('quantiles', None) is not None: - for i, par in enumerate(parameters): - ax = axes[i + i * len(parameters)] + for i, par in enumerate(plot_parameter_keys): + ax = axes[i + i * len(plot_parameter_keys)] if ax.title.get_text() == '': ax.set_title(self.get_one_dimensional_median_and_error_bar( par, quantiles=kwargs['quantiles']), @@ -392,8 +416,8 @@ class Result(dict): if priors is True: priors = getattr(self, 'priors', False) if isinstance(priors, dict): - for i, par in enumerate(parameters): - ax = axes[i + i * len(parameters)] + for i, par in enumerate(plot_parameter_keys): + ax = axes[i + i * len(plot_parameter_keys)] theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300) ax.plot(theta, priors[par].prob(theta), color='C2') elif priors in [False, None]: