Skip to content
Snippets Groups Projects
Commit 31f18abe authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch '260-result-plot_corner-should-still-work-ootb-if-i-move-the-hdf5-file' into 'master'

Resolve "`Result.plot_corner` should still work OOTB if I move the hdf5 file"

Closes #260

See merge request !319
parents 265dc536 c724b30e
No related branches found
No related tags found
No related merge requests found
...@@ -30,6 +30,8 @@ ...@@ -30,6 +30,8 @@
- Renamed "prior" to "priors" in bilby.gw.likelihood.GravtitationalWaveTransient - Renamed "prior" to "priors" in bilby.gw.likelihood.GravtitationalWaveTransient
for consistency with bilby.core. **WARNING**: This will break scripts which for consistency with bilby.core. **WARNING**: This will break scripts which
use marginalization. use marginalization.
- Added `outdir` kwarg for plotting methods in `bilby.core.result.Result`. This makes plotting
into custom destinations easier.
- Fixed definition of matched_filter_snr, the interferometer method has become `ifo.inner_product`. - Fixed definition of matched_filter_snr, the interferometer method has become `ifo.inner_product`.
### Added ### Added
......
...@@ -11,6 +11,7 @@ import corner ...@@ -11,6 +11,7 @@ import corner
import scipy.stats import scipy.stats
import matplotlib import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib import lines as mpllines
from . import utils from . import utils
from .utils import (logger, infer_parameters_from_function, from .utils import (logger, infer_parameters_from_function,
...@@ -119,9 +120,10 @@ class Result(object): ...@@ -119,9 +120,10 @@ class Result(object):
Version information for software used to generate the result. Note, Version information for software used to generate the result. Note,
this information is generated when the result object is initialized this information is generated when the result object is initialized
Note: Note
All sampling output parameters, e.g. the samples themselves are ---------
typically not given at initialisation, but set at a later stage. All sampling output parameters, e.g. the samples themselves are
typically not given at initialisation, but set at a later stage.
""" """
...@@ -151,6 +153,9 @@ class Result(object): ...@@ -151,6 +153,9 @@ class Result(object):
self.version = version self.version = version
self.max_autocorrelation_time = max_autocorrelation_time self.max_autocorrelation_time = max_autocorrelation_time
self.prior_values = None
self._kde = None
def __str__(self): def __str__(self):
"""Print a summary """ """Print a summary """
if getattr(self, 'posterior', None) is not None: if getattr(self, 'posterior', None) is not None:
...@@ -285,7 +290,7 @@ class Result(object): ...@@ -285,7 +290,7 @@ class Result(object):
pass pass
return dictionary return dictionary
def save_to_file(self, overwrite=False): def save_to_file(self, overwrite=False, outdir=None):
""" """
Writes the Result to a deepdish h5 file Writes the Result to a deepdish h5 file
...@@ -294,9 +299,12 @@ class Result(object): ...@@ -294,9 +299,12 @@ class Result(object):
overwrite: bool, optional overwrite: bool, optional
Whether or not to overwrite an existing result file. Whether or not to overwrite an existing result file.
default=False default=False
outdir: str, optional
Path to the outdir. Default is the one stored in the result object.
""" """
file_name = result_file_name(self.outdir, self.label) outdir = self._safe_outdir_creation(outdir, self.save_to_file)
utils.check_directory_exists_and_if_not_mkdir(self.outdir) file_name = result_file_name(outdir, self.label)
if os.path.isfile(file_name): if os.path.isfile(file_name):
if overwrite: if overwrite:
logger.debug('Removing existing file {}'.format(file_name)) logger.debug('Removing existing file {}'.format(file_name))
...@@ -326,10 +334,10 @@ class Result(object): ...@@ -326,10 +334,10 @@ class Result(object):
logger.error("\n\n Saving the data has failed with the " logger.error("\n\n Saving the data has failed with the "
"following message:\n {} \n\n".format(e)) "following message:\n {} \n\n".format(e))
def save_posterior_samples(self): def save_posterior_samples(self, outdir=None):
"""Saves posterior samples to a file""" """Saves posterior samples to a file"""
filename = '{}/{}_posterior_samples.txt'.format(self.outdir, self.label) outdir = self._safe_outdir_creation(outdir, self.save_posterior_samples)
utils.check_directory_exists_and_if_not_mkdir(self.outdir) filename = '{}/{}_posterior_samples.txt'.format(outdir, self.label)
self.posterior.to_csv(filename, index=False, header=True) self.posterior.to_csv(filename, index=False, header=True)
def get_latex_labels_from_parameter_keys(self, keys): def get_latex_labels_from_parameter_keys(self, keys):
...@@ -389,7 +397,7 @@ class Result(object): ...@@ -389,7 +397,7 @@ class Result(object):
return self.posterior_volume / self.prior_volume(priors) return self.posterior_volume / self.prior_volume(priors)
def get_one_dimensional_median_and_error_bar(self, key, fmt='.2f', def get_one_dimensional_median_and_error_bar(self, key, fmt='.2f',
quantiles=[0.16, 0.84]): quantiles=(0.16, 0.84)):
""" Calculate the median and error bar for a given key """ Calculate the median and error bar for a given key
Parameters Parameters
...@@ -398,8 +406,8 @@ class Result(object): ...@@ -398,8 +406,8 @@ class Result(object):
The parameter key for which to calculate the median and error bar The parameter key for which to calculate the median and error bar
fmt: str, ('.2f') fmt: str, ('.2f')
A format string A format string
quantiles: list quantiles: list, tuple
A length-2 list of the lower and upper-quantiles to calculate A length-2 tuple of the lower and upper-quantiles to calculate
the errors bars for. the errors bars for.
Returns Returns
...@@ -428,8 +436,8 @@ class Result(object): ...@@ -428,8 +436,8 @@ class Result(object):
def plot_single_density(self, key, prior=None, cumulative=False, def plot_single_density(self, key, prior=None, cumulative=False,
title=None, truth=None, save=True, title=None, truth=None, save=True,
file_base_name=None, bins=50, label_fontsize=16, file_base_name=None, bins=50, label_fontsize=16,
title_fontsize=16, quantiles=[0.16, 0.84], dpi=300): title_fontsize=16, quantiles=(0.16, 0.84), dpi=300):
""" Plot a 1D marginal density, either probablility or cumulative. """ Plot a 1D marginal density, either probability or cumulative.
Parameters Parameters
---------- ----------
...@@ -458,8 +466,8 @@ class Result(object): ...@@ -458,8 +466,8 @@ class Result(object):
The number of histogram bins The number of histogram bins
label_fontsize, title_fontsize: int label_fontsize, title_fontsize: int
The fontsizes for the labels and titles The fontsizes for the labels and titles
quantiles: list quantiles: tuple
A length-2 list of the lower and upper-quantiles to calculate A length-2 tuple of the lower and upper-quantiles to calculate
the errors bars for. the errors bars for.
dpi: int dpi: int
Dots per inch resolution of the plot Dots per inch resolution of the plot
...@@ -493,7 +501,7 @@ class Result(object): ...@@ -493,7 +501,7 @@ class Result(object):
if isinstance(prior, Prior): if isinstance(prior, Prior):
theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300) theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300)
ax.plot(theta, Prior.prob(theta), color='C2') ax.plot(theta, prior.prob(theta), color='C2')
if save: if save:
fig.tight_layout() fig.tight_layout()
...@@ -508,7 +516,8 @@ class Result(object): ...@@ -508,7 +516,8 @@ class Result(object):
def plot_marginals(self, parameters=None, priors=None, titles=True, def plot_marginals(self, parameters=None, priors=None, titles=True,
file_base_name=None, bins=50, label_fontsize=16, file_base_name=None, bins=50, label_fontsize=16,
title_fontsize=16, quantiles=[0.16, 0.84], dpi=300): title_fontsize=16, quantiles=(0.16, 0.84), dpi=300,
outdir=None):
""" Plot 1D marginal distributions """ Plot 1D marginal distributions
Parameters Parameters
...@@ -531,12 +540,14 @@ class Result(object): ...@@ -531,12 +540,14 @@ class Result(object):
bins: int bins: int
The number of histogram bins The number of histogram bins
label_fontsize, title_fontsize: int label_fontsize, title_fontsize: int
The fontsizes for the labels and titles The font sizes for the labels and titles
quantiles: list quantiles: tuple
A length-2 list of the lower and upper-quantiles to calculate A length-2 tuple of the lower and upper-quantiles to calculate
the errors bars for. the errors bars for.
dpi: int dpi: int
Dots per inch resolution of the plot Dots per inch resolution of the plot
outdir: str, optional
Path to the outdir. Default is the one store in the result object.
Returns Returns
------- -------
...@@ -558,7 +569,8 @@ class Result(object): ...@@ -558,7 +569,8 @@ class Result(object):
truths = self.injection_parameters truths = self.injection_parameters
if file_base_name is None: if file_base_name is None:
file_base_name = '{}/{}_1d/'.format(self.outdir, self.label) outdir = self._safe_outdir_creation(outdir, self.plot_marginals)
file_base_name = '{}/{}_1d/'.format(outdir, self.label)
check_directory_exists_and_if_not_mkdir(file_base_name) check_directory_exists_and_if_not_mkdir(file_base_name)
if priors is True: if priors is True:
...@@ -609,7 +621,8 @@ class Result(object): ...@@ -609,7 +621,8 @@ class Result(object):
**kwargs: **kwargs:
Other keyword arguments are passed to `corner.corner`. We set some Other keyword arguments are passed to `corner.corner`. We set some
defaults to improve the basic look and feel, but these can all be defaults to improve the basic look and feel, but these can all be
overridden. overridden. Also optional an 'outdir' argument which can be used
to override the outdir set by the absolute path of the result object.
Notes Notes
----- -----
...@@ -720,8 +733,8 @@ class Result(object): ...@@ -720,8 +733,8 @@ class Result(object):
if save: if save:
if filename is None: if filename is None:
utils.check_directory_exists_and_if_not_mkdir(self.outdir) outdir = self._safe_outdir_creation(kwargs.get('outdir'), self.plot_corner)
filename = '{}/{}_corner.png'.format(self.outdir, self.label) filename = '{}/{}_corner.png'.format(outdir, self.label)
logger.debug('Saving corner plot to {}'.format(filename)) logger.debug('Saving corner plot to {}'.format(filename))
fig.savefig(filename, dpi=dpi) fig.savefig(filename, dpi=dpi)
plt.close(fig) plt.close(fig)
...@@ -752,16 +765,16 @@ class Result(object): ...@@ -752,16 +765,16 @@ class Result(object):
ax.set_ylabel(self.parameter_labels[i]) ax.set_ylabel(self.parameter_labels[i])
fig.tight_layout() fig.tight_layout()
filename = '{}/{}_walkers.png'.format(self.outdir, self.label) outdir = self._safe_outdir_creation(kwargs.get('outdir'), self.plot_walkers)
filename = '{}/{}_walkers.png'.format(outdir, self.label)
logger.debug('Saving walkers plot to {}'.format('filename')) logger.debug('Saving walkers plot to {}'.format('filename'))
utils.check_directory_exists_and_if_not_mkdir(self.outdir)
fig.savefig(filename) fig.savefig(filename)
plt.close(fig) plt.close(fig)
def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000, def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000,
xlabel=None, ylabel=None, data_label='data', xlabel=None, ylabel=None, data_label='data',
data_fmt='o', draws_label=None, filename=None, data_fmt='o', draws_label=None, filename=None,
maxl_label='max likelihood', dpi=300): maxl_label='max likelihood', dpi=300, outdir=None):
""" Generate a figure showing the data and fits to the data """ Generate a figure showing the data and fits to the data
Parameters Parameters
...@@ -787,6 +800,8 @@ class Result(object): ...@@ -787,6 +800,8 @@ class Result(object):
filename: str filename: str
If given, the filename to use. Otherwise, the filename is generated If given, the filename to use. Otherwise, the filename is generated
from the outdir and label attributes. from the outdir and label attributes.
outdir: str, optional
Path to the outdir. Default is the one store in the result object.
""" """
...@@ -825,8 +840,8 @@ class Result(object): ...@@ -825,8 +840,8 @@ class Result(object):
ax.legend(numpoints=3) ax.legend(numpoints=3)
fig.tight_layout() fig.tight_layout()
if filename is None: if filename is None:
utils.check_directory_exists_and_if_not_mkdir(self.outdir) outdir = self._safe_outdir_creation(outdir, self.plot_with_data)
filename = '{}/{}_plot_with_data'.format(self.outdir, self.label) filename = '{}/{}_plot_with_data'.format(outdir, self.label)
fig.savefig(filename, dpi=dpi) fig.savefig(filename, dpi=dpi)
plt.close(fig) plt.close(fig)
...@@ -944,20 +959,20 @@ class Result(object): ...@@ -944,20 +959,20 @@ class Result(object):
bool: True if attribute name matches with an attribute of other_object, False otherwise bool: True if attribute name matches with an attribute of other_object, False otherwise
""" """
A = getattr(self, name, False) a = getattr(self, name, False)
B = getattr(other_object, name, False) b = getattr(other_object, name, False)
logger.debug('Checking {} value: {}=={}'.format(name, A, B)) logger.debug('Checking {} value: {}=={}'.format(name, a, b))
if (A is not False) and (B is not False): if (a is not False) and (b is not False):
typeA = type(A) type_a = type(a)
typeB = type(B) type_b = type(b)
if typeA == typeB: if type_a == type_b:
if typeA in [str, float, int, dict, list]: if type_a in [str, float, int, dict, list]:
try: try:
return A == B return a == b
except ValueError: except ValueError:
return False return False
elif typeA in [np.ndarray]: elif type_a in [np.ndarray]:
return np.all(A == B) return np.all(a == b)
return False return False
@property @property
...@@ -966,9 +981,9 @@ class Result(object): ...@@ -966,9 +981,9 @@ class Result(object):
Uses `scipy.stats.gaussian_kde` to generate the kernel density Uses `scipy.stats.gaussian_kde` to generate the kernel density
""" """
try: if self._kde:
return self._kde return self._kde
except AttributeError: else:
self._kde = scipy.stats.gaussian_kde( self._kde = scipy.stats.gaussian_kde(
self.posterior[self.search_parameter_keys].values.T) self.posterior[self.search_parameter_keys].values.T)
return self._kde return self._kde
...@@ -998,6 +1013,18 @@ class Result(object): ...@@ -998,6 +1013,18 @@ class Result(object):
for s in sample] for s in sample]
return self.kde(ordered_sample) return self.kde(ordered_sample)
def _safe_outdir_creation(self, outdir=None, caller_func=None):
if outdir is None:
outdir = self.outdir
try:
utils.check_directory_exists_and_if_not_mkdir(outdir)
except PermissionError:
raise FileMovedError("Can not write in the out directory.\n"
"Did you move the here file from another system?\n"
"Try calling " + caller_func.__name__ + " with the 'outdir' "
"keyword argument, e.g. " + caller_func.__name__ + "(outdir='.')")
return outdir
def plot_multiple(results, filename=None, labels=None, colours=None, def plot_multiple(results, filename=None, labels=None, colours=None,
save=True, evidences=False, **kwargs): save=True, evidences=False, **kwargs):
...@@ -1050,7 +1077,7 @@ def plot_multiple(results, filename=None, labels=None, colours=None, ...@@ -1050,7 +1077,7 @@ def plot_multiple(results, filename=None, labels=None, colours=None,
hist_kwargs['color'] = c hist_kwargs['color'] = c
fig = result.plot_corner(fig=fig, save=False, color=c, **kwargs) fig = result.plot_corner(fig=fig, save=False, color=c, **kwargs)
default_filename += '_{}'.format(result.label) default_filename += '_{}'.format(result.label)
lines.append(matplotlib.lines.Line2D([0], [0], color=c)) lines.append(mpllines.Line2D([0], [0], color=c))
default_labels.append(result.label) default_labels.append(result.label)
# Rescale the axes # Rescale the axes
...@@ -1100,7 +1127,7 @@ def make_pp_plot(results, filename=None, save=True, **kwargs): ...@@ -1100,7 +1127,7 @@ def make_pp_plot(results, filename=None, save=True, **kwargs):
Returns Returns
------- -------
fig: fig:
Matplotlib figure matplotlib figure
""" """
fig = plt.figure() fig = plt.figure()
credible_levels = pd.DataFrame() credible_levels = pd.DataFrame()
...@@ -1122,3 +1149,11 @@ def make_pp_plot(results, filename=None, save=True, **kwargs): ...@@ -1122,3 +1149,11 @@ def make_pp_plot(results, filename=None, save=True, **kwargs):
filename = 'outdir/pp.png' filename = 'outdir/pp.png'
plt.savefig(filename) plt.savefig(filename)
return fig return fig
class ResultError(Exception):
""" Base exception for all Result related errors """
class FileMovedError(ResultError):
""" Exceptions that occur when files have been moved """
...@@ -25,9 +25,9 @@ class TestResult(unittest.TestCase): ...@@ -25,9 +25,9 @@ class TestResult(unittest.TestCase):
injection_parameters=dict(x=0.5, y=0.5), injection_parameters=dict(x=0.5, y=0.5),
meta_data=dict(test='test')) meta_data=dict(test='test'))
N = 100 n = 100
posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, N), posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, n),
y=np.random.normal(0, 1, N))) y=np.random.normal(0, 1, n)))
result.posterior = posterior result.posterior = posterior
result.log_evidence = 10 result.log_evidence = 10
result.log_evidence_err = 11 result.log_evidence_err = 11
...@@ -66,7 +66,7 @@ class TestResult(unittest.TestCase): ...@@ -66,7 +66,7 @@ class TestResult(unittest.TestCase):
injection_parameters=dict(x=0.5, y=0.5), injection_parameters=dict(x=0.5, y=0.5),
meta_data=dict(test='test')) meta_data=dict(test='test'))
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
result.priors _ = result.priors
self.assertEqual(result.parameter_labels, result.search_parameter_keys) self.assertEqual(result.parameter_labels, result.search_parameter_keys)
self.assertEqual(result.parameter_labels_with_unit, result.search_parameter_keys) self.assertEqual(result.parameter_labels_with_unit, result.search_parameter_keys)
...@@ -102,14 +102,14 @@ class TestResult(unittest.TestCase): ...@@ -102,14 +102,14 @@ class TestResult(unittest.TestCase):
def test_unset_posterior(self): def test_unset_posterior(self):
self.result.posterior = None self.result.posterior = None
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.result.posterior _ = self.result.posterior
def test_save_and_load(self): def test_save_and_load(self):
self.result.save_to_file() self.result.save_to_file()
loaded_result = bilby.core.result.read_in_result( loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label) outdir=self.result.outdir, label=self.result.label)
self.assertTrue( self.assertTrue(pd.DataFrame.equals
all(self.result.posterior == loaded_result.posterior)) (self.result.posterior, loaded_result.posterior))
self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys) self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys)
self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys) self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys)
self.assertEqual(self.result.meta_data, loaded_result.meta_data) self.assertEqual(self.result.meta_data, loaded_result.meta_data)
...@@ -146,31 +146,28 @@ class TestResult(unittest.TestCase): ...@@ -146,31 +146,28 @@ class TestResult(unittest.TestCase):
filename = '{}/{}_posterior_samples.txt'.format(self.result.outdir, self.result.label) filename = '{}/{}_posterior_samples.txt'.format(self.result.outdir, self.result.label)
self.assertTrue(os.path.isfile(filename)) self.assertTrue(os.path.isfile(filename))
df = pd.read_csv(filename) df = pd.read_csv(filename)
self.assertTrue(all(self.result.posterior == df)) self.assertTrue(np.allclose(self.result.posterior.values, df.values))
def test_samples_to_posterior(self): def test_samples_to_posterior(self):
self.result.posterior = None self.result.posterior = None
x = [1, 2, 3] x = [1, 2, 3]
y = [4, 6, 8] y = [4, 6, 8]
log_likelihood = [6, 7, 8] log_likelihood = np.array([6, 7, 8])
self.result.samples = np.array([x, y]).T self.result.samples = np.array([x, y]).T
self.result.log_likelihood_evaluations = log_likelihood self.result.log_likelihood_evaluations = log_likelihood
self.result.samples_to_posterior(priors=self.result.priors) self.result.samples_to_posterior(priors=self.result.priors)
self.assertTrue(all(self.result.posterior['x'] == x)) self.assertTrue(all(self.result.posterior['x'] == x))
self.assertTrue(all(self.result.posterior['y'] == y)) self.assertTrue(all(self.result.posterior['y'] == y))
self.assertTrue( self.assertTrue(np.array_equal(self.result.posterior.log_likelihood.values, log_likelihood))
all(self.result.posterior['log_likelihood'] == log_likelihood)) self.assertTrue(all(self.result.posterior.c.values == self.result.priors['c'].peak))
self.assertTrue( self.assertTrue(all(self.result.posterior.d.values == self.result.priors['d'].peak))
all(self.result.posterior['c'] == self.result.priors['c'].peak))
self.assertTrue(
all(self.result.posterior['d'] == self.result.priors['d'].peak))
def test_calculate_prior_values(self): def test_calculate_prior_values(self):
self.result.calculate_prior_values(priors=self.result.priors) self.result.calculate_prior_values(priors=self.result.priors)
self.assertEqual(len(self.result.posterior), len(self.result.prior_values)) self.assertEqual(len(self.result.posterior), len(self.result.prior_values))
def test_plot_multiple(self): def test_plot_multiple(self):
filename='multiple.png'.format(self.result.outdir) filename = 'multiple.png'.format(self.result.outdir)
bilby.core.result.plot_multiple([self.result, self.result], bilby.core.result.plot_multiple([self.result, self.result],
filename=filename) filename=filename)
self.assertTrue(os.path.isfile(filename)) self.assertTrue(os.path.isfile(filename))
...@@ -188,8 +185,8 @@ class TestResult(unittest.TestCase): ...@@ -188,8 +185,8 @@ class TestResult(unittest.TestCase):
x = np.linspace(0, 1, 10) x = np.linspace(0, 1, 10)
y = np.linspace(0, 1, 10) y = np.linspace(0, 1, 10)
def model(x): def model(xx):
return x return xx
self.result.plot_with_data(model, x, y, ndraws=10) self.result.plot_with_data(model, x, y, ndraws=10)
self.assertTrue( self.assertTrue(
os.path.isfile('{}/{}_plot_with_data.png'.format( os.path.isfile('{}/{}_plot_with_data.png'.format(
...@@ -260,9 +257,8 @@ class TestResult(unittest.TestCase): ...@@ -260,9 +257,8 @@ class TestResult(unittest.TestCase):
sample = [dict(x=0, y=0.1), dict(x=0.8, y=0)] sample = [dict(x=0, y=0.1), dict(x=0.8, y=0)]
self.assertTrue( self.assertTrue(
isinstance(self.result.posterior_probability(sample), np.ndarray)) isinstance(self.result.posterior_probability(sample), np.ndarray))
self.assertTrue( self.assertTrue(np.array_equal(self.result.posterior_probability(sample),
all(self.result.posterior_probability(sample) self.result.kde([[0, 0.1], [0.8, 0]])))
== self.result.kde([[0, 0.1], [0.8, 0]])))
if __name__ == '__main__': if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment