Skip to content
Snippets Groups Projects
Commit 325e0499 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'resolve-174' into 'master'

Add function to plot data with draws from the posterior and max like

See merge request Monash/tupak!176
parents d2bd4726 1b30b8aa
No related branches found
No related tags found
1 merge request!176Add function to plot data with draws from the posterior and max like
Pipeline #30645 passed
......@@ -6,6 +6,7 @@ import pandas as pd
import corner
import matplotlib
import matplotlib.pyplot as plt
from collections import OrderedDict
from tupak.core import utils
from tupak.core.utils import logger
......@@ -235,7 +236,8 @@ class Result(dict):
"""
return self.posterior_volume / self.prior_volume(priors)
def plot_corner(self, parameters=None, save=True, dpi=300, **kwargs):
def plot_corner(self, parameters=None, save=True, priors=None, dpi=300,
**kwargs):
""" Plot a corner-plot using corner
See https://corner.readthedocs.io/en/latest/ for a detailed API.
......@@ -342,6 +344,66 @@ class Result(dict):
utils.check_directory_exists_and_if_not_mkdir(self.outdir)
fig.savefig(filename)
def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000,
xlabel=None, ylabel=None, data_label='data',
data_fmt='o', draws_label=None, filename=None,
maxl_label='max likelihood', dpi=300):
""" Generate a figure showing the data and fits to the data
Parameters
----------
model: function
A python function which when called as `model(x, **kwargs)` returns
the model prediction (here `kwargs` is a dictionary of key-value
pairs of the model parameters.
x, y: np.ndarray
The independent and dependent data to plot
ndraws: int
Number of draws from the posterior to plot
npoints: int
Number of points used to plot the smoothed fit to the data
xlabel, ylabel: str
Labels for the axes
data_label, draws_label, maxl_label: str
Label for the data, draws, and max likelihood legend
data_fmt: str
Matpltolib fmt code, defaults to `'-o'`
dpi: int
Passed to `plt.savefig`
filename: str
If given, the filename to use. Otherwise, the filename is generated
from the outdir and label attributes.
"""
xsmooth = np.linspace(np.min(x), np.max(x), npoints)
fig, ax = plt.subplots()
logger.info('Plotting {} draws'.format(ndraws))
for _ in range(ndraws):
s = self.posterior.sample().to_dict('records')[0]
ax.plot(xsmooth, model(xsmooth, **s), alpha=0.25, lw=0.1, color='r',
label=draws_label)
if all(~np.isnan(self.posterior.log_likelihood)):
logger.info('Plotting maximum likelihood')
s = self.posterior.ix[self.posterior.log_likelihood.idxmax()]
ax.plot(xsmooth, model(xsmooth, **s), lw=1, color='k',
label=maxl_label)
ax.plot(x, y, data_fmt, markersize=2, label=data_label)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
handles, labels = plt.gca().get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys())
ax.legend(numpoints=3)
fig.tight_layout()
if filename is None:
filename = '{}/{}_plot_with_data'.format(self.outdir, self.label)
fig.savefig(filename, dpi=dpi)
def samples_to_posterior(self, likelihood=None, priors=None,
conversion_function=None):
"""
......@@ -362,6 +424,9 @@ class Result(dict):
self.samples, columns=self.search_parameter_keys)
data_frame['log_likelihood'] = getattr(
self, 'log_likelihood_evaluations', np.nan)
for key in priors:
if getattr(priors[key], 'is_fixed', False):
data_frame[key] = priors[key].peak
# We save the samples in the posterior and remove the array of samples
del self.samples
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment