Skip to content
Snippets Groups Projects
Commit 325e0499 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'resolve-174' into 'master'

Add function to plot data with draws from the posterior and max like

See merge request Monash/tupak!176
parents d2bd4726 1b30b8aa
No related branches found
No related tags found
1 merge request!176Add function to plot data with draws from the posterior and max like
Pipeline #30645 passed
...@@ -6,6 +6,7 @@ import pandas as pd ...@@ -6,6 +6,7 @@ import pandas as pd
import corner import corner
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from collections import OrderedDict
from tupak.core import utils from tupak.core import utils
from tupak.core.utils import logger from tupak.core.utils import logger
...@@ -235,7 +236,8 @@ class Result(dict): ...@@ -235,7 +236,8 @@ class Result(dict):
""" """
return self.posterior_volume / self.prior_volume(priors) return self.posterior_volume / self.prior_volume(priors)
def plot_corner(self, parameters=None, save=True, dpi=300, **kwargs): def plot_corner(self, parameters=None, save=True, priors=None, dpi=300,
**kwargs):
""" Plot a corner-plot using corner """ Plot a corner-plot using corner
See https://corner.readthedocs.io/en/latest/ for a detailed API. See https://corner.readthedocs.io/en/latest/ for a detailed API.
...@@ -342,6 +344,66 @@ class Result(dict): ...@@ -342,6 +344,66 @@ class Result(dict):
utils.check_directory_exists_and_if_not_mkdir(self.outdir) utils.check_directory_exists_and_if_not_mkdir(self.outdir)
fig.savefig(filename) fig.savefig(filename)
def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000,
xlabel=None, ylabel=None, data_label='data',
data_fmt='o', draws_label=None, filename=None,
maxl_label='max likelihood', dpi=300):
""" Generate a figure showing the data and fits to the data
Parameters
----------
model: function
A python function which when called as `model(x, **kwargs)` returns
the model prediction (here `kwargs` is a dictionary of key-value
pairs of the model parameters.
x, y: np.ndarray
The independent and dependent data to plot
ndraws: int
Number of draws from the posterior to plot
npoints: int
Number of points used to plot the smoothed fit to the data
xlabel, ylabel: str
Labels for the axes
data_label, draws_label, maxl_label: str
Label for the data, draws, and max likelihood legend
data_fmt: str
Matpltolib fmt code, defaults to `'-o'`
dpi: int
Passed to `plt.savefig`
filename: str
If given, the filename to use. Otherwise, the filename is generated
from the outdir and label attributes.
"""
xsmooth = np.linspace(np.min(x), np.max(x), npoints)
fig, ax = plt.subplots()
logger.info('Plotting {} draws'.format(ndraws))
for _ in range(ndraws):
s = self.posterior.sample().to_dict('records')[0]
ax.plot(xsmooth, model(xsmooth, **s), alpha=0.25, lw=0.1, color='r',
label=draws_label)
if all(~np.isnan(self.posterior.log_likelihood)):
logger.info('Plotting maximum likelihood')
s = self.posterior.ix[self.posterior.log_likelihood.idxmax()]
ax.plot(xsmooth, model(xsmooth, **s), lw=1, color='k',
label=maxl_label)
ax.plot(x, y, data_fmt, markersize=2, label=data_label)
if xlabel is not None:
ax.set_xlabel(xlabel)
if ylabel is not None:
ax.set_ylabel(ylabel)
handles, labels = plt.gca().get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys())
ax.legend(numpoints=3)
fig.tight_layout()
if filename is None:
filename = '{}/{}_plot_with_data'.format(self.outdir, self.label)
fig.savefig(filename, dpi=dpi)
def samples_to_posterior(self, likelihood=None, priors=None, def samples_to_posterior(self, likelihood=None, priors=None,
conversion_function=None): conversion_function=None):
""" """
...@@ -362,6 +424,9 @@ class Result(dict): ...@@ -362,6 +424,9 @@ class Result(dict):
self.samples, columns=self.search_parameter_keys) self.samples, columns=self.search_parameter_keys)
data_frame['log_likelihood'] = getattr( data_frame['log_likelihood'] = getattr(
self, 'log_likelihood_evaluations', np.nan) self, 'log_likelihood_evaluations', np.nan)
for key in priors:
if getattr(priors[key], 'is_fixed', False):
data_frame[key] = priors[key].peak
# We save the samples in the posterior and remove the array of samples # We save the samples in the posterior and remove the array of samples
del self.samples del self.samples
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment