From 6b14e6b1379d25c030023554f0079b4d57d0e041 Mon Sep 17 00:00:00 2001 From: Moritz Huebner <moritz.huebner@ligo.org> Date: Tue, 20 Oct 2020 03:07:55 -0500 Subject: [PATCH] Generic Slab and spike priors for bilby --- bilby/core/prior/__init__.py | 1 + bilby/core/prior/analytical.py | 1 + bilby/core/prior/slabspike.py | 169 ++++++++++++++++ examples/core_examples/slabspike_example.py | 97 ++++++++++ test/core/prior/slabspike_test.py | 202 ++++++++++++++++++++ 5 files changed, 470 insertions(+) create mode 100644 bilby/core/prior/slabspike.py create mode 100644 examples/core_examples/slabspike_example.py create mode 100644 test/core/prior/slabspike_test.py diff --git a/bilby/core/prior/__init__.py b/bilby/core/prior/__init__.py index 253ad6c9c..fc795c3e1 100644 --- a/bilby/core/prior/__init__.py +++ b/bilby/core/prior/__init__.py @@ -4,3 +4,4 @@ from .conditional import * from .dict import * from .interpolated import * from .joint import * +from .slabspike import * diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index b575b3376..7a7fb2b6b 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -618,6 +618,7 @@ class TruncatedGaussian(Prior): / self.sigma / self.normalisation * self.is_in_prior_range(val) def cdf(self, val): + val = np.atleast_1d(val) _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf( (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation _cdf[val > self.maximum] = 1 diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py new file mode 100644 index 000000000..de7dc48ad --- /dev/null +++ b/bilby/core/prior/slabspike.py @@ -0,0 +1,169 @@ +import numpy as np + +from bilby.core.prior.base import Prior +from bilby.core.utils import logger + + +class SlabSpikePrior(Prior): + + def __init__(self, slab, spike_location=None, spike_height=0): + """'Slab-and-spike' prior, see e.g. https://arxiv.org/abs/1812.07259 + This prior is composed of a `slab`, i.e. any common prior distribution, + and a Dirac spike at a fixed location. This can effectively be used + to emulate sampling in the number of dimensions (similar to reversible- + jump MCMC). + + `SymmetricLogUniform` and `FermiDirac` are currently not supported. + + Parameters + ---------- + slab: Prior + Any instance of a bilby prior class. All general prior attributes + from the slab are copied into the SlabSpikePrior. + Note that this hasn't been tested for conditional priors. + spike_location: float, optional + Location of the Dirac spike. Must be between minimum and maximum + of the slab. Defaults to the minimum of the slab + spike_height: float, optional + Relative weight of the spike compared to the slab. Must be + between 0 and 1. Defaults to 0, i.e. the prior is just the slab. + + """ + self.slab = slab + super().__init__(name=self.slab.name, latex_label=self.slab.latex_label, unit=self.slab.unit, + minimum=self.slab.minimum, maximum=self.slab.maximum, + check_range_nonzero=self.slab.check_range_nonzero, boundary=self.slab.boundary) + self.spike_location = spike_location + self.spike_height = spike_height + try: + self.inverse_cdf_below_spike = self._find_inverse_cdf_fraction_before_spike() + except Exception as e: + logger.warning("Disregard the following warning when running tests:\n {}".format(e)) + + @property + def spike_location(self): + return self._spike_loc + + @spike_location.setter + def spike_location(self, spike_loc): + if spike_loc is None: + spike_loc = self.minimum + if not self.minimum <= spike_loc <= self.maximum: + raise ValueError("Spike location {} not within prior domain ".format(spike_loc)) + self._spike_loc = spike_loc + + @property + def spike_height(self): + return self._spike_height + + @spike_height.setter + def spike_height(self, spike_height): + if 0 <= spike_height <= 1: + self._spike_height = spike_height + else: + raise ValueError("Spike height must be between 0 and 1, but is {}".format(spike_height)) + + @property + def slab_fraction(self): + """ Relative prior weight of the slab. """ + return 1 - self.spike_height + + def _find_inverse_cdf_fraction_before_spike(self): + return float(self.slab.cdf(self.spike_location)) * self.slab_fraction + + def rescale(self, val): + """ + 'Rescale' a sample from the unit line element to the prior. + + Parameters + ---------- + val: Union[float, int, array_like] + A random number between 0 and 1 + + Returns + ------- + array_like: Associated prior value with input value. + """ + val = np.atleast_1d(val) + + lower_indices = np.where(val < self.inverse_cdf_below_spike)[0] + intermediate_indices = np.where(np.logical_and( + self.inverse_cdf_below_spike <= val, + val <= self.inverse_cdf_below_spike + self.spike_height))[0] + higher_indices = np.where(val > self.inverse_cdf_below_spike + self.spike_height)[0] + + res = np.zeros(len(val)) + res[lower_indices] = self._contracted_rescale(val[lower_indices]) + res[intermediate_indices] = self.spike_location + res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height) + return res + + def _contracted_rescale(self, val): + """ + Contracted version of the rescale function that implements the `rescale` function + on the pure slab part of the prior. + + Parameters + ---------- + val: Union[float, int, array_like] + A random number between 0 and self.slab_fraction + + Returns + ------- + array_like: Associated prior value with input value. + """ + return self.slab.rescale(val / self.slab_fraction) + + def prob(self, val): + """Return the prior probability of val. + Returns np.inf for the spike location + + Parameters + ---------- + val: Union[float, int, array_like] + + Returns + ------- + array_like: Prior probability of val + """ + res = self.slab.prob(val) * self.slab_fraction + res = np.atleast_1d(res) + res[np.where(val == self.spike_location)] = np.inf + return res + + def ln_prob(self, val): + """Return the Log prior probability of val. + Returns np.inf for the spike location + + Parameters + ---------- + val: Union[float, int, array_like] + + Returns + ------- + array_like: Prior probability of val + """ + res = self.slab.ln_prob(val) + np.log(self.slab_fraction) + res = np.atleast_1d(res) + res[np.where(val == self.spike_location)] = np.inf + return res + + def cdf(self, val): + """ Return the CDF of the prior. + This calls to the slab CDF and adds a discrete step + at the spike location. + + Parameters + ---------- + val: Union[float, int, array_like] + + Returns + ------- + array_like: CDF value of val + + """ + res = self.slab.cdf(val) * self.slab_fraction + res = np.atleast_1d(res) + indices_above_spike = np.where(val > self.spike_location)[0] + res[indices_above_spike] += self.spike_height + return res diff --git a/examples/core_examples/slabspike_example.py b/examples/core_examples/slabspike_example.py new file mode 100644 index 000000000..d79851235 --- /dev/null +++ b/examples/core_examples/slabspike_example.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python +""" +An example of how to use slab-and-spike priors in bilby. +In this example we look at a simple example with the sum +of two Gaussian distributions, and we try to fit with +up to three Gaussians. + +""" + +import bilby +import numpy as np +import matplotlib.pyplot as plt + +outdir = 'outdir' +label = 'slabspike' +bilby.utils.check_directory_exists_and_if_not_mkdir(outdir) + + +# Here we define our model. We want to inject two Gaussians and recover with up to three. +def gaussian(xs, amplitude, mu, sigma): + return amplitude / np.sqrt(2 * np.pi * sigma**2) * np.exp(-0.5 * (xs - mu)**2 / sigma**2) + + +def triple_gaussian(xs, amplitude_0, amplitude_1, amplitude_2, mu_0, mu_1, mu_2, sigma_0, sigma_1, sigma_2, **kwargs): + return \ + gaussian(xs, amplitude_0, mu_0, sigma_0) + \ + gaussian(xs, amplitude_1, mu_1, sigma_1) + \ + gaussian(xs, amplitude_2, mu_2, sigma_2) + + +# Let's create our data set. We create 200 points on a grid. + +xs = np.linspace(-5, 5, 200) +dx = xs[1] - xs[0] + +# Note for our injection parameters we set the amplitude of the second component to 0. +injection_params = dict(amplitude_0=-3, mu_0=-4, sigma_0=4, + amplitude_1=0, mu_1=0, sigma_1=1, + amplitude_2=4, mu_2=3, sigma_2=3) + +# We calculate the injected curve and add some Gaussian noise on the data points +sigma = 0.02 +p = bilby.core.prior.Gaussian(mu=0, sigma=sigma) +ys = triple_gaussian(xs=xs, **injection_params) + p.sample(len(xs)) + +plt.errorbar(xs, ys, yerr=sigma, fmt=".k", capsize=0, label='Injected data') +plt.plot(xs, triple_gaussian(xs=xs, **injection_params), label='True signal') +plt.legend() +plt.savefig(f'{outdir}/{label}_injected_data') +plt.clf() + + +# Now we want to set up our priors. +priors = bilby.core.prior.PriorDict() +# For the slab-and-spike prior, we first need to define the 'slab' part, which is just a regular bilby prior. +amplitude_slab_0 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_0', latex_label='$A_0$') +amplitude_slab_1 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_1', latex_label='$A_1$') +amplitude_slab_2 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_2', latex_label='$A_2$') +# We do the following to create the slab-and-spike prior. The spike height is somewhat arbitrary and can +# be corrected in post-processing. +priors['amplitude_0'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_0, spike_location=0, spike_height=0.1) +priors['amplitude_1'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_1, spike_location=0, spike_height=0.1) +priors['amplitude_2'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_2, spike_location=0, spike_height=0.1) +# Our problem has a degeneracy in the ordering. In general, this problem is somewhat difficult to resolve properly. +# See e.g. https://github.com/GregoryAshton/kookaburra/blob/master/src/priors.py#L72 for an implementation. +# We resolve this by not letting the priors overlap in this case. +priors['mu_0'] = bilby.core.prior.Uniform(minimum=-5, maximum=-2, name='mu_0', latex_label='$\mu_0$') +priors['mu_1'] = bilby.core.prior.Uniform(minimum=-2, maximum=2, name='mu_1', latex_label='$\mu_1$') +priors['mu_2'] = bilby.core.prior.Uniform(minimum=2, maximum=5, name='mu_2', latex_label='$\mu_2$') +priors['sigma_0'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_0', latex_label='$\sigma_0$') +priors['sigma_1'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_1', latex_label='$\sigma_1$') +priors['sigma_2'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_2', latex_label='$\sigma_2$') + +# Setting up the likelihood and running the samplers works the same as elsewhere. +likelihood = bilby.core.likelihood.GaussianLikelihood(x=xs, y=ys, func=triple_gaussian, sigma=sigma) +result = bilby.run_sampler(likelihood=likelihood, priors=priors, outdir=outdir, label=label, + sampler='dynesty', nlive=400) + +result.plot_corner(truths=injection_params) + + +# Let's also plot the maximum likelihood fit along with the data. +max_like_params = result.posterior.iloc[-1] +plt.errorbar(xs, ys, yerr=sigma, fmt=".k", capsize=0, label='Injected data') +plt.plot(xs, triple_gaussian(xs=xs, **injection_params), label='True signal') +plt.plot(xs, triple_gaussian(xs=xs, **max_like_params), label='Max likelihood fit') +plt.legend() +plt.savefig(f'{outdir}/{label}_max_likelihood_recovery') +plt.clf() + +# Finally, we can check what fraction of amplitude samples are exactly on the spike. +spike_samples_0 = len(np.where(result.posterior['amplitude_0'] == 0.0)[0]) / len(result.posterior) +spike_samples_1 = len(np.where(result.posterior['amplitude_1'] == 0.0)[0]) / len(result.posterior) +spike_samples_2 = len(np.where(result.posterior['amplitude_2'] == 0.0)[0]) / len(result.posterior) +print(f"{spike_samples_0 * 100:.2f}% of amplitude_0 samples are exactly 0.0") +print(f"{spike_samples_1 * 100:.2f}% of amplitude_1 samples are exactly 0.0") +print(f"{spike_samples_2 * 100:.2f}% of amplitude_2 samples are exactly 0.0") diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py new file mode 100644 index 000000000..d2cdcc55a --- /dev/null +++ b/test/core/prior/slabspike_test.py @@ -0,0 +1,202 @@ +import numpy as np +import unittest + +import bilby +from bilby.core.prior.slabspike import SlabSpikePrior +from bilby.core.prior.analytical import Uniform, PowerLaw, LogUniform, TruncatedGaussian, \ + Beta, Gaussian, Cosine, Sine, HalfGaussian, LogNormal, Exponential, StudentT, Logistic, \ + Cauchy, Gamma, ChiSquared + + +class TestSlabSpikePrior(unittest.TestCase): + + def setUp(self): + self.minimum = 0 + self.maximum = 1 + self.spike_loc = 0.5 + self.spike_height = 0.3 + self.slab = bilby.core.prior.Prior(minimum=self.minimum, maximum=self.maximum) + self.prior = SlabSpikePrior( + slab=self.slab, spike_location=self.spike_loc, spike_height=self.spike_height) + + def tearDown(self): + del self.minimum + del self.maximum + del self.spike_loc + del self.spike_height + del self.prior + del self.slab + + def test_slab_fraction(self): + expected = 1 - self.spike_height + self.assertEqual(expected, self.prior.slab_fraction) + + def test_spike_loc(self): + self.assertEqual(self.spike_loc, self.prior.spike_location) + + def test_set_spike_loc_none(self): + self.prior.spike_location = None + self.assertEqual(self.prior.minimum, self.prior.spike_location) + + def test_set_spike_loc_outside_domain(self): + with self.assertRaises(ValueError): + self.prior.spike_location = 1.5 + + def test_set_spike_loc_maximum(self): + self.prior.spike_location = self.maximum + self.assertEqual(self.maximum, self.prior.spike_location) + + def test_class_name(self): + expected = "SlabSpikePrior" + self.assertEqual(expected, self.prior.__class__.__name__) + self.assertEqual(expected, self.prior.__class__.__qualname__) + + def test_set_spike_height_outside_domain(self): + with self.assertRaises(ValueError): + self.prior.spike_height = 1.5 + + def test_set_spike_height_domain_edge(self): + self.prior.spike_height = 0 + self.prior.spike_height = 1 + + +class TestSlabSpikeClasses(unittest.TestCase): + + def setUp(self): + self.minimum = 0.4 + self.maximum = 2.4 + self.spike_loc = 1.5 + self.spike_height = 0.3 + + self.slabs = [ + Uniform(minimum=self.minimum, maximum=self.maximum), + PowerLaw(minimum=self.minimum, maximum=self.maximum, alpha=2), + LogUniform(minimum=self.minimum, maximum=self.maximum), + TruncatedGaussian(minimum=self.minimum, maximum=self.maximum, mu=0, sigma=1), + Beta(minimum=self.minimum, maximum=self.maximum, alpha=1, beta=1), + Gaussian(mu=0, sigma=1), + Cosine(), + Sine(), + HalfGaussian(sigma=1), + LogNormal(mu=1, sigma=2), + Exponential(mu=2), + StudentT(df=2), + Logistic(mu=2, scale=1), + Cauchy(alpha=1, beta=2), + Gamma(k=1, theta=1.), + ChiSquared(nu=2)] + self.slab_spikes = [SlabSpikePrior(slab, spike_height=self.spike_height, spike_location=self.spike_loc) + for slab in self.slabs] + self.test_nodes_finite_support = np.linspace(self.minimum, self.maximum, 1000) + self.test_nodes_infinite_support = np.linspace(-10, 10, 1000) + self.test_nodes = [self.test_nodes_finite_support + if np.isinf(slab.minimum) or np.isinf(slab.maximum) + else self.test_nodes_finite_support for slab in self.slabs] + + def tearDown(self): + del self.minimum + del self.maximum + del self.spike_loc + del self.spike_height + del self.slabs + del self.test_nodes_finite_support + del self.test_nodes_infinite_support + + def test_prob_on_slab(self): + for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): + expected = slab.prob(test_nodes) * slab_spike.slab_fraction + actual = slab_spike.prob(test_nodes) + self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + + def test_prob_on_spike(self): + for slab_spike in self.slab_spikes: + self.assertEqual(np.inf, slab_spike.prob(self.spike_loc)) + + def test_ln_prob_on_slab(self): + for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): + expected = slab.ln_prob(test_nodes) + np.log(slab_spike.slab_fraction) + actual = slab_spike.ln_prob(test_nodes) + self.assertTrue(np.array_equal(expected, actual)) + + def test_ln_prob_on_spike(self): + for slab_spike in self.slab_spikes: + self.assertEqual(np.inf, slab_spike.ln_prob(self.spike_loc)) + + def test_inverse_cdf_below_spike_with_spike_at_minimum(self): + for slab in self.slabs: + slab_spike = SlabSpikePrior(slab=slab, spike_height=0.4, spike_location=slab.minimum) + self.assertEqual(0, slab_spike.inverse_cdf_below_spike) + + def test_inverse_cdf_below_spike_with_spike_at_maximum(self): + for slab in self.slabs: + slab_spike = SlabSpikePrior(slab=slab, spike_height=0.4, spike_location=slab.maximum) + expected = 1 - slab_spike.spike_height + actual = slab_spike.inverse_cdf_below_spike + self.assertEqual(expected, actual) + + def test_inverse_cdf_below_spike_arbitrary_position(self): + pass + + def test_cdf_below_spike(self): + for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): + test_nodes = test_nodes[np.where(test_nodes < self.spike_loc)] + expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + actual = slab_spike.cdf(test_nodes) + self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + + def test_cdf_at_spike(self): + for slab, slab_spike in zip(self.slabs, self.slab_spikes): + expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction + actual = slab_spike.cdf(self.spike_loc) + self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + + def test_cdf_above_spike(self): + for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes): + test_nodes = test_nodes[np.where(test_nodes > self.spike_loc)] + expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height + actual = slab_spike.cdf(test_nodes) + self.assertTrue(np.array_equal(expected, actual)) + + def test_cdf_at_minimum(self): + for slab_spike in self.slab_spikes: + expected = 0 + actual = slab_spike.cdf(slab_spike.minimum) + self.assertEqual(expected, actual) + + def test_cdf_at_maximum(self): + for slab_spike in self.slab_spikes: + expected = 1 + actual = slab_spike.cdf(slab_spike.maximum) + self.assertEqual(expected, actual) + + def test_rescale_no_spike(self): + for slab in self.slabs: + slab_spike = SlabSpikePrior(slab=slab, spike_height=0, spike_location=slab.minimum) + vals = np.linspace(0, 1, 1000) + expected = slab.rescale(vals) + actual = slab_spike.rescale(vals) + print(slab) + self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + + def test_rescale_below_spike(self): + for slab, slab_spike in zip(self.slabs, self.slab_spikes): + vals = np.linspace(0, slab_spike.inverse_cdf_below_spike, 1000) + expected = slab.rescale(vals / slab_spike.slab_fraction) + actual = slab_spike.rescale(vals) + self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + + def test_rescale_at_spike(self): + for slab, slab_spike in zip(self.slabs, self.slab_spikes): + vals = np.linspace(slab_spike.inverse_cdf_below_spike, + slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000) + expected = np.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction) + actual = slab_spike.rescale(vals) + self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) + + def test_rescale_above_spike(self): + for slab, slab_spike in zip(self.slabs, self.slab_spikes): + vals = np.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000) + expected = np.ones(len(vals)) * slab.rescale( + (vals - self.spike_height) / slab_spike.slab_fraction) + actual = slab_spike.rescale(vals) + self.assertTrue(np.allclose(expected, actual, rtol=1e-5)) -- GitLab