diff --git a/tupak/core/result.py b/tupak/core/result.py index fab2beb65037a3727ef30c00a13089c138b00f1b..20f7e1ba0029fba8677b49ad790899522786aa9f 100644 --- a/tupak/core/result.py +++ b/tupak/core/result.py @@ -6,6 +6,7 @@ import pandas as pd import corner import matplotlib import matplotlib.pyplot as plt +from collections import OrderedDict from tupak.core import utils from tupak.core.utils import logger @@ -235,7 +236,8 @@ class Result(dict): """ return self.posterior_volume / self.prior_volume(priors) - def plot_corner(self, parameters=None, save=True, dpi=300, **kwargs): + def plot_corner(self, parameters=None, save=True, priors=None, dpi=300, + **kwargs): """ Plot a corner-plot using corner See https://corner.readthedocs.io/en/latest/ for a detailed API. @@ -342,6 +344,66 @@ class Result(dict): utils.check_directory_exists_and_if_not_mkdir(self.outdir) fig.savefig(filename) + def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000, + xlabel=None, ylabel=None, data_label='data', + data_fmt='o', draws_label=None, filename=None, + maxl_label='max likelihood', dpi=300): + """ Generate a figure showing the data and fits to the data + + Parameters + ---------- + model: function + A python function which when called as `model(x, **kwargs)` returns + the model prediction (here `kwargs` is a dictionary of key-value + pairs of the model parameters. + x, y: np.ndarray + The independent and dependent data to plot + ndraws: int + Number of draws from the posterior to plot + npoints: int + Number of points used to plot the smoothed fit to the data + xlabel, ylabel: str + Labels for the axes + data_label, draws_label, maxl_label: str + Label for the data, draws, and max likelihood legend + data_fmt: str + Matpltolib fmt code, defaults to `'-o'` + dpi: int + Passed to `plt.savefig` + filename: str + If given, the filename to use. Otherwise, the filename is generated + from the outdir and label attributes. + + """ + xsmooth = np.linspace(np.min(x), np.max(x), npoints) + fig, ax = plt.subplots() + logger.info('Plotting {} draws'.format(ndraws)) + for _ in range(ndraws): + s = self.posterior.sample().to_dict('records')[0] + ax.plot(xsmooth, model(xsmooth, **s), alpha=0.25, lw=0.1, color='r', + label=draws_label) + if all(~np.isnan(self.posterior.log_likelihood)): + logger.info('Plotting maximum likelihood') + s = self.posterior.ix[self.posterior.log_likelihood.idxmax()] + ax.plot(xsmooth, model(xsmooth, **s), lw=1, color='k', + label=maxl_label) + + ax.plot(x, y, data_fmt, markersize=2, label=data_label) + + if xlabel is not None: + ax.set_xlabel(xlabel) + if ylabel is not None: + ax.set_ylabel(ylabel) + + handles, labels = plt.gca().get_legend_handles_labels() + by_label = OrderedDict(zip(labels, handles)) + plt.legend(by_label.values(), by_label.keys()) + ax.legend(numpoints=3) + fig.tight_layout() + if filename is None: + filename = '{}/{}_plot_with_data'.format(self.outdir, self.label) + fig.savefig(filename, dpi=dpi) + def samples_to_posterior(self, likelihood=None, priors=None, conversion_function=None): """ @@ -362,6 +424,9 @@ class Result(dict): self.samples, columns=self.search_parameter_keys) data_frame['log_likelihood'] = getattr( self, 'log_likelihood_evaluations', np.nan) + for key in priors: + if getattr(priors[key], 'is_fixed', False): + data_frame[key] = priors[key].peak # We save the samples in the posterior and remove the array of samples del self.samples else: