diff --git a/bilby/core/utils.py b/bilby/core/utils.py
index b9a9dba3b6f82c97c39aa8983ee9b51410bc9293..df9df99a8f262eba06e9a509ec02969fce32b62e 100644
--- a/bilby/core/utils.py
+++ b/bilby/core/utils.py
@@ -703,29 +703,108 @@ def logtrapzexp(lnf, dx):
     return np.log(dx / 2.) + logsumexp([logsumexp(lnf[:-1]), logsumexp(lnf[1:])])
 
 
-def credible_interval(samples, confidence_level=.9, lower=True):
-    """
-    Return location of lower or upper confidence levels
-    Based on lalinference.bayespputils.cred_interval
+class SamplesSummary(object):
+    """ Object to store a set of samples and calculate summary statistics
 
     Parameters
     ----------
-    x: List of samples.
-    cl: Confidence level to return the bound of.
-    lower: If ``True``, return the lower bound, otherwise return the upper bound.
-
-    Returns
-    -------
-    float: the upper or lower confidence level
+    samples: array_like
+        Array of samples
+    average: str {'median', 'mean'}
+        Use either a median average or mean average when calculating relative
+        uncertainties
+    confidence_level: float
+        The default confidence interval level, defaults to 0.9. Must lie
+        strictly between 0 and 1
+
+    Raises
+    ------
+    ValueError
+        If `average` is not 'mean' or 'median', or if `confidence_level` is
+        not strictly between 0 and 1
 
     """
-    def cred_level(cl, x):
-        return np.sort(x, axis=0)[int(cl * len(x))]
-
-    if lower:
-        return cred_level((1. - confidence_level) / 2, samples)
-    else:
-        return cred_level((1. + confidence_level) / 2, samples)
+    def __init__(self, samples, average='median', confidence_level=.9):
+        self.samples = samples
+        self.average = average
+        self.confidence_level = confidence_level
+
+    @property
+    def samples(self):
+        """ The stored samples, returned exactly as they were passed in """
+        return self._samples
+
+    @samples.setter
+    def samples(self, samples):
+        self._samples = samples
+
+    @property
+    def confidence_level(self):
+        """ Credible interval level, a fraction strictly between 0 and 1 """
+        return self._confidence_level
+
+    @confidence_level.setter
+    def confidence_level(self, confidence_level):
+        if 0 < confidence_level < 1:
+            self._confidence_level = confidence_level
+        else:
+            raise ValueError("Confidence level must be between 0 and 1")
+
+    @property
+    def average(self):
+        """ Mean or median of the samples, whichever was selected """
+        if self._average == 'mean':
+            return self.mean
+        elif self._average == 'median':
+            return self.median
+
+    @average.setter
+    def average(self, average):
+        allowed_averages = ['mean', 'median']
+        if average in allowed_averages:
+            self._average = average
+        else:
+            raise ValueError("Average {} not in allowed averages".format(average))
+
+    @property
+    def median(self):
+        """ Median of the samples taken along the first axis """
+        return np.median(self.samples, axis=0)
+
+    @property
+    def mean(self):
+        """ Mean of the samples taken along the first axis """
+        return np.mean(self.samples, axis=0)
+
+    @property
+    def _lower_level(self):
+        """ The credible interval lower quantile value """
+        return (1 - self.confidence_level) / 2.
+
+    @property
+    def _upper_level(self):
+        """ The credible interval upper quantile value """
+        return (1 + self.confidence_level) / 2.
+
+    @property
+    def lower_absolute_credible_interval(self):
+        """ Absolute lower value of the credible interval """
+        return np.quantile(self.samples, self._lower_level, axis=0)
+
+    @property
+    def upper_absolute_credible_interval(self):
+        """ Absolute upper value of the credible interval """
+        return np.quantile(self.samples, self._upper_level, axis=0)
+
+    @property
+    def lower_relative_credible_interval(self):
+        """ Lower credible bound minus the average (a signed offset) """
+        return self.lower_absolute_credible_interval - self.average
+
+    @property
+    def upper_relative_credible_interval(self):
+        """ Upper credible bound minus the average (a signed offset) """
+        return self.upper_absolute_credible_interval - self.average
 
 
 def run_commandline(cl, log_level=20, raise_error=True, return_output=True):
diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py
index 82ac038db39edd2ef0eb3f483413300a64046ec2..ef2eda91590cdeda061124d3b76d49458c95b654 100644
--- a/bilby/gw/utils.py
+++ b/bilby/gw/utils.py
@@ -9,7 +9,7 @@ import matplotlib.pyplot as plt
 from ..core.utils import (gps_time_to_gmst, ra_dec_to_theta_phi,
                           speed_of_light, logger, run_commandline,
                           check_directory_exists_and_if_not_mkdir,
-                          credible_interval)
+                          SamplesSummary)
 
 try:
     from gwpy.timeseries import TimeSeries
@@ -830,10 +830,11 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=
     else:
         scaled_samples = xform(samples)
 
-    mu = np.mean(scaled_samples, axis=0)
-    lower_confidence_level = mu - credible_interval(scaled_samples, level, lower=True)
-    upper_confidence_level = credible_interval(scaled_samples, level, lower=False) - mu
-    plt.errorbar(freq_points, mu, yerr=[lower_confidence_level, upper_confidence_level],
+    scaled_samples_summary = SamplesSummary(scaled_samples, average='mean', confidence_level=level)
+
+    plt.errorbar(freq_points, scaled_samples_summary.average,
+                 yerr=[-scaled_samples_summary.lower_relative_credible_interval,
+                       scaled_samples_summary.upper_relative_credible_interval],
                  fmt='.', color=color, lw=4, alpha=0.5, capsize=0)
 
     for i, sample in enumerate(samples):
@@ -845,6 +846,8 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=
 
     line, = plt.plot(freqs, np.mean(data, axis=0), color=color, label=label)
     color = line.get_color()
-    plt.fill_between(freqs, credible_interval(data, level), credible_interval(data, level, lower=False),
+    data_summary = SamplesSummary(data, average='mean', confidence_level=level)
+    plt.fill_between(freqs, data_summary.lower_absolute_credible_interval,
+                     data_summary.upper_absolute_credible_interval,
                      color=color, alpha=.1, linewidth=0.1)
     plt.xlim(freq_points.min() - .5, freq_points.max() + 50)