Skip to content
Snippets Groups Projects
Commit a9405708 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch '365-define-credible-interval-object' into 'master'

Resolve "Define credible interval object"

Closes #365

See merge request lscsoft/bilby!455
parents 0b0e497b b29d0552
No related branches found
No related tags found
1 merge request!455Resolve "Define credible interval object"
Pipeline #60744 passed
......@@ -703,29 +703,96 @@ def logtrapzexp(lnf, dx):
return np.log(dx / 2.) + logsumexp([logsumexp(lnf[:-1]), logsumexp(lnf[1:])])
def credible_interval(samples, confidence_level=.9, lower=True):
"""
Return location of lower or upper confidence levels
Based on lalinference.bayespputils.cred_interval
class SamplesSummary(object):
""" Object to store a set of samples and calculate summary statistics
Parameters
----------
x: List of samples.
cl: Confidence level to return the bound of.
lower: If ``True``, return the lower bound, otherwise return the upper bound.
Returns
-------
float: the upper or lower confidence level
samples: array_like
Array of samples
average: str {'median', 'mean'}
Use either a median average or mean average when calculating relative
uncertainties
level: float
The default confidence interval level, defaults t0 0.9
"""
def cred_level(cl, x):
return np.sort(x, axis=0)[int(cl * len(x))]
if lower:
return cred_level((1. - confidence_level) / 2, samples)
else:
return cred_level((1. + confidence_level) / 2, samples)
def __init__(self, samples, average='median', confidence_level=.9):
self.samples = samples
self.average = average
self.confidence_level = confidence_level
@property
def samples(self):
return self._samples
@samples.setter
def samples(self, samples):
self._samples = samples
@property
def confidence_level(self):
return self._confidence_level
@confidence_level.setter
def confidence_level(self, confidence_level):
if 0 < confidence_level and confidence_level < 1:
self._confidence_level = confidence_level
else:
raise ValueError("Confidence level must be between 0 and 1")
@property
def average(self):
if self._average == 'mean':
return self.mean
elif self._average == 'median':
return self.median
@average.setter
def average(self, average):
allowed_averages = ['mean', 'median']
if average in allowed_averages:
self._average = average
else:
raise ValueError("Average {} not in allowed averages".format(average))
@property
def median(self):
return np.median(self.samples, axis=0)
@property
def mean(self):
return np.mean(self.samples, axis=0)
@property
def _lower_level(self):
""" The credible interval lower quantile value """
return (1 - self.confidence_level) / 2.
@property
def _upper_level(self):
""" The credible interval upper quantile value """
return (1 + self.confidence_level) / 2.
@property
def lower_absolute_credible_interval(self):
""" Absolute lower value of the credible interval """
return np.quantile(self.samples, self._lower_level, axis=0)
@property
def upper_absolute_credible_interval(self):
""" Absolute upper value of the credible interval """
return np.quantile(self.samples, self._upper_level, axis=0)
@property
def lower_relative_credible_interval(self):
""" Relative (to average) lower value of the credible interval """
return self.lower_absolute_credible_interval - self.average
@property
def upper_relative_credible_interval(self):
""" Relative (to average) upper value of the credible interval """
return self.upper_absolute_credible_interval - self.average
def run_commandline(cl, log_level=20, raise_error=True, return_output=True):
......
......@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
from ..core.utils import (gps_time_to_gmst, ra_dec_to_theta_phi,
speed_of_light, logger, run_commandline,
check_directory_exists_and_if_not_mkdir,
credible_interval)
SamplesSummary)
try:
from gwpy.timeseries import TimeSeries
......@@ -830,10 +830,12 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=
else:
scaled_samples = xform(samples)
mu = np.mean(scaled_samples, axis=0)
lower_confidence_level = mu - credible_interval(scaled_samples, level, lower=True)
upper_confidence_level = credible_interval(scaled_samples, level, lower=False) - mu
plt.errorbar(freq_points, mu, yerr=[lower_confidence_level, upper_confidence_level],
scaled_samples_summary = SamplesSummary(scaled_samples, average='mean')
data_summary = SamplesSummary(data, average='mean')
plt.errorbar(freq_points, scaled_samples_summary.average,
yerr=[-scaled_samples_summary.lower_relative_credible_interval,
scaled_samples_summary.upper_relative_credible_interval],
fmt='.', color=color, lw=4, alpha=0.5, capsize=0)
for i, sample in enumerate(samples):
......@@ -845,6 +847,7 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=
line, = plt.plot(freqs, np.mean(data, axis=0), color=color, label=label)
color = line.get_color()
plt.fill_between(freqs, credible_interval(data, level), credible_interval(data, level, lower=False),
plt.fill_between(freqs, data_summary.lower_absolute_credible_interval,
data_summary.upper_absolute_credible_interval,
color=color, alpha=.1, linewidth=0.1)
plt.xlim(freq_points.min() - .5, freq_points.max() + 50)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment