Skip to content
Snippets Groups Projects
Commit af1e374b authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add function to plot data with draws from the posterior and max like

Fix 174
parent 8d8ddd57
No related branches found
No related tags found
1 merge request!176Add function to plot data with draws from the posterior and max like
Pipeline #30257 passed
......@@ -6,6 +6,7 @@ import pandas as pd
import corner
import matplotlib
import matplotlib.pyplot as plt
from collections import OrderedDict
from tupak.core import utils
from tupak.core.utils import logger
......@@ -342,6 +343,66 @@ class Result(dict):
utils.check_directory_exists_and_if_not_mkdir(self.outdir)
fig.savefig(filename)
def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000,
xlabel=None, ylabel=None, data_label='data',
data_fmt='o', draws_label=None, filename=None,
maxl_label='max likelihood', dpi=300):
""" Generate a figure showing the data and fits to the data
Parameters
----------
model: function
A python function which when called as `model(x, **kwargs)` returns
the model prediction (here `kwargs` is a dictionary of key-value
pairs of the model parameters.
x, y: np.ndarray
The independent and dependent data to plot
ndraws: int
Number of draws from the posterior to plot
npoints: int
Number of points used to plot the smoothed fit to the data
xlabel, ylabel: str
Labels for the axes
data_label, draws_label, maxl_label: str
Label for the data, draws, and max likelihood legend
data_fmt: str
Matpltolib fmt code, defaults to `'-o'`
dpi: int
Passed to `plt.savefig`
filename: str
If given, the filename to use. Otherwise, the filename is generated
from the outdir and label attributes.
"""
xsmooth = np.linspace(np.min(x), np.max(x), npoints)
fig, ax = plt.subplots()
logger.info('Plotting {} draws'.format(ndraws))
for _ in range(ndraws):
s = self.posterior.sample().to_dict('records')[0]
ax.plot(xsmooth, model(xsmooth, **s), alpha=0.25, lw=0.1, color='r',
label=draws_label)
if all(~np.isnan(self.posterior.log_likelihood)):
logger.info('Plotting maximum likelihood')
s = self.posterior.ix[self.posterior.log_likelihood.idxmax()]
ax.plot(xsmooth, model(xsmooth, **s), lw=1, color='k',
label=maxl_label)
ax.plot(x, y, data_fmt, markersize=2, label=data_label)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
handles, labels = plt.gca().get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys())
ax.legend(numpoints=3)
fig.tight_layout()
if filename is None:
filename = '{}/{}_plot_with_data'.format(self.outdir, self.label)
fig.savefig(filename, dpi=dpi)
def samples_to_posterior(self, likelihood=None, priors=None,
conversion_function=None):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment