Commit 2f179803 authored by Shanika Galaudage

Merge branch 'slabspikeificator' into 'master'

Generic Slab and spike priors for bilby

See merge request !857
parents fe308ee3 6b14e6b1
Pipeline #164409 passed
@@ -4,3 +4,4 @@ from .conditional import *
from .dict import *
from .interpolated import *
from .joint import *
from .slabspike import *
@@ -618,6 +618,7 @@ class TruncatedGaussian(Prior):
/ self.sigma / self.normalisation * self.is_in_prior_range(val)
def cdf(self, val):
val = np.atleast_1d(val)
_cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf(
(self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation
_cdf[val > self.maximum] = 1
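The hunk above vectorises TruncatedGaussian.cdf via np.atleast_1d. As a quick numerical sanity check (not part of this merge request; the parameter values below are arbitrary), the same expression can be compared against scipy.stats.truncnorm:

import numpy as np
from scipy.special import erf
from scipy.stats import truncnorm

mu, sigma, minimum, maximum = 0.0, 1.0, -1.0, 2.0
# mirrors TruncatedGaussian.normalisation: the fraction of the untruncated
# Gaussian that falls inside [minimum, maximum]
normalisation = (erf((maximum - mu) / 2 ** 0.5 / sigma)
                 - erf((minimum - mu) / 2 ** 0.5 / sigma)) / 2
val = np.linspace(minimum, maximum, 5)
_cdf = (erf((val - mu) / 2 ** 0.5 / sigma) - erf(
    (minimum - mu) / 2 ** 0.5 / sigma)) / 2 / normalisation
# scipy expects the truncation bounds in standardised units
reference = truncnorm.cdf(val, a=(minimum - mu) / sigma, b=(maximum - mu) / sigma,
                          loc=mu, scale=sigma)
assert np.allclose(_cdf, reference)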
import numpy as np
from bilby.core.prior.base import Prior
from bilby.core.utils import logger
class SlabSpikePrior(Prior):
def __init__(self, slab, spike_location=None, spike_height=0):
"""'Slab-and-spike' prior, see e.g. https://arxiv.org/abs/1812.07259
This prior is composed of a `slab`, i.e. any common prior distribution,
and a Dirac spike at a fixed location. This can effectively be used
to emulate sampling in the number of dimensions (similar to reversible-
jump MCMC).
`SymmetricLogUniform` and `FermiDirac` are currently not supported.
Parameters
----------
slab: Prior
Any instance of a bilby prior class. All general prior attributes
from the slab are copied into the SlabSpikePrior.
Note that this hasn't been tested for conditional priors.
spike_location: float, optional
Location of the Dirac spike. Must be between minimum and maximum
of the slab. Defaults to the minimum of the slab.
spike_height: float, optional
Relative weight of the spike compared to the slab. Must be
between 0 and 1. Defaults to 0, i.e. the prior is just the slab.
"""
self.slab = slab
super().__init__(name=self.slab.name, latex_label=self.slab.latex_label, unit=self.slab.unit,
minimum=self.slab.minimum, maximum=self.slab.maximum,
check_range_nonzero=self.slab.check_range_nonzero, boundary=self.slab.boundary)
self.spike_location = spike_location
self.spike_height = spike_height
try:
self.inverse_cdf_below_spike = self._find_inverse_cdf_fraction_before_spike()
except Exception as e:
logger.warning("Disregard the following warning when running tests:\n {}".format(e))
@property
def spike_location(self):
return self._spike_loc
@spike_location.setter
def spike_location(self, spike_loc):
if spike_loc is None:
spike_loc = self.minimum
if not self.minimum <= spike_loc <= self.maximum:
raise ValueError("Spike location {} not within prior domain ".format(spike_loc))
self._spike_loc = spike_loc
@property
def spike_height(self):
return self._spike_height
@spike_height.setter
def spike_height(self, spike_height):
if 0 <= spike_height <= 1:
self._spike_height = spike_height
else:
raise ValueError("Spike height must be between 0 and 1, but is {}".format(spike_height))
@property
def slab_fraction(self):
""" Relative prior weight of the slab. """
return 1 - self.spike_height
def _find_inverse_cdf_fraction_before_spike(self):
return float(self.slab.cdf(self.spike_location)) * self.slab_fraction
def rescale(self, val):
"""
'Rescale' a sample from the unit line element to the prior.
Parameters
----------
val: Union[float, int, array_like]
A random number between 0 and 1
Returns
-------
array_like: Prior value associated with the input value.
"""
val = np.atleast_1d(val)
lower_indices = np.where(val < self.inverse_cdf_below_spike)[0]
intermediate_indices = np.where(np.logical_and(
self.inverse_cdf_below_spike <= val,
val <= self.inverse_cdf_below_spike + self.spike_height))[0]
higher_indices = np.where(val > self.inverse_cdf_below_spike + self.spike_height)[0]
res = np.zeros(len(val))
res[lower_indices] = self._contracted_rescale(val[lower_indices])
res[intermediate_indices] = self.spike_location
res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height)
return res
def _contracted_rescale(self, val):
"""
Contracted version of the rescale function that implements the `rescale` function
on the pure slab part of the prior.
Parameters
----------
val: Union[float, int, array_like]
A random number between 0 and self.slab_fraction
Returns
-------
array_like: Prior value associated with the input value.
"""
return self.slab.rescale(val / self.slab_fraction)
def prob(self, val):
"""Return the prior probability of val.
Returns np.inf for the spike location
Parameters
----------
val: Union[float, int, array_like]
Returns
-------
array_like: Prior probability of val
"""
res = self.slab.prob(val) * self.slab_fraction
res = np.atleast_1d(res)
res[np.where(val == self.spike_location)] = np.inf
return res
def ln_prob(self, val):
Return the log prior probability of val.
Returns np.inf for the spike location
Parameters
----------
val: Union[float, int, array_like]
Returns
-------
array_like: Log prior probability of val
"""
res = self.slab.ln_prob(val) + np.log(self.slab_fraction)
res = np.atleast_1d(res)
res[np.where(val == self.spike_location)] = np.inf
return res
def cdf(self, val):
""" Return the CDF of the prior.
This calls the slab CDF and adds a discrete step
at the spike location.
Parameters
----------
val: Union[float, int, array_like]
Returns
-------
array_like: CDF value of val
"""
res = self.slab.cdf(val) * self.slab_fraction
res = np.atleast_1d(res)
indices_above_spike = np.where(val > self.spike_location)[0]
res[indices_above_spike] += self.spike_height
return res
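A minimal usage sketch for the class above (illustration only, not part of the diff; the Uniform slab and the numbers are arbitrary): samples pushed through rescale should land exactly on spike_location with probability spike_height, and cdf should jump by spike_height there.

import numpy as np
import bilby

slab = bilby.core.prior.Uniform(minimum=-1, maximum=1, name='x')
prior = bilby.core.prior.SlabSpikePrior(slab=slab, spike_location=0, spike_height=0.2)

# map unit-interval samples through the prior; roughly 20% land exactly on the spike
samples = prior.rescale(np.random.uniform(0, 1, 100000))
print("fraction on spike:", np.mean(samples == 0.0))  # approximately 0.2

# the CDF picks up a step of size spike_height at the spike location
print(prior.cdf(np.array([-1e-6, 1e-6])))  # approximately [0.4, 0.6]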
#!/usr/bin/env python
"""
An example of how to use slab-and-spike priors in bilby.
In this example we look at a simple example with the sum
of two Gaussian distributions, and we try to fit with
up to three Gaussians.
"""
import bilby
import numpy as np
import matplotlib.pyplot as plt
outdir = 'outdir'
label = 'slabspike'
bilby.utils.check_directory_exists_and_if_not_mkdir(outdir)
# Here we define our model. We want to inject two Gaussians and recover with up to three.
def gaussian(xs, amplitude, mu, sigma):
return amplitude / np.sqrt(2 * np.pi * sigma**2) * np.exp(-0.5 * (xs - mu)**2 / sigma**2)
def triple_gaussian(xs, amplitude_0, amplitude_1, amplitude_2, mu_0, mu_1, mu_2, sigma_0, sigma_1, sigma_2, **kwargs):
return \
gaussian(xs, amplitude_0, mu_0, sigma_0) + \
gaussian(xs, amplitude_1, mu_1, sigma_1) + \
gaussian(xs, amplitude_2, mu_2, sigma_2)
# Let's create our data set. We create 200 points on a grid.
xs = np.linspace(-5, 5, 200)
dx = xs[1] - xs[0]
# Note for our injection parameters we set the amplitude of the second component to 0.
injection_params = dict(amplitude_0=-3, mu_0=-4, sigma_0=4,
amplitude_1=0, mu_1=0, sigma_1=1,
amplitude_2=4, mu_2=3, sigma_2=3)
# We calculate the injected curve and add some Gaussian noise to the data points.
sigma = 0.02
p = bilby.core.prior.Gaussian(mu=0, sigma=sigma)
ys = triple_gaussian(xs=xs, **injection_params) + p.sample(len(xs))
plt.errorbar(xs, ys, yerr=sigma, fmt=".k", capsize=0, label='Injected data')
plt.plot(xs, triple_gaussian(xs=xs, **injection_params), label='True signal')
plt.legend()
plt.savefig(f'{outdir}/{label}_injected_data')
plt.clf()
# Now we want to set up our priors.
priors = bilby.core.prior.PriorDict()
# For the slab-and-spike prior, we first need to define the 'slab' part, which is just a regular bilby prior.
amplitude_slab_0 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_0', latex_label='$A_0$')
amplitude_slab_1 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_1', latex_label='$A_1$')
amplitude_slab_2 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_2', latex_label='$A_2$')
# We do the following to create the slab-and-spike prior. The spike height is somewhat arbitrary and can
# be corrected in post-processing.
priors['amplitude_0'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_0, spike_location=0, spike_height=0.1)
priors['amplitude_1'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_1, spike_location=0, spike_height=0.1)
priors['amplitude_2'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_2, spike_location=0, spike_height=0.1)
# Our problem has a degeneracy in the ordering. In general, this problem is somewhat difficult to resolve properly.
# See e.g. https://github.com/GregoryAshton/kookaburra/blob/master/src/priors.py#L72 for an implementation.
# We resolve this by not letting the priors overlap in this case.
priors['mu_0'] = bilby.core.prior.Uniform(minimum=-5, maximum=-2, name='mu_0', latex_label=r'$\mu_0$')
priors['mu_1'] = bilby.core.prior.Uniform(minimum=-2, maximum=2, name='mu_1', latex_label=r'$\mu_1$')
priors['mu_2'] = bilby.core.prior.Uniform(minimum=2, maximum=5, name='mu_2', latex_label=r'$\mu_2$')
priors['sigma_0'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_0', latex_label=r'$\sigma_0$')
priors['sigma_1'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_1', latex_label=r'$\sigma_1$')
priors['sigma_2'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_2', latex_label=r'$\sigma_2$')
# Setting up the likelihood and running the samplers works the same as elsewhere.
likelihood = bilby.core.likelihood.GaussianLikelihood(x=xs, y=ys, func=triple_gaussian, sigma=sigma)
result = bilby.run_sampler(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
sampler='dynesty', nlive=400)
result.plot_corner(truths=injection_params)
# Let's also plot the maximum likelihood fit along with the data.
max_like_params = result.posterior.iloc[-1]
plt.errorbar(xs, ys, yerr=sigma, fmt=".k", capsize=0, label='Injected data')
plt.plot(xs, triple_gaussian(xs=xs, **injection_params), label='True signal')
plt.plot(xs, triple_gaussian(xs=xs, **max_like_params), label='Max likelihood fit')
plt.legend()
plt.savefig(f'{outdir}/{label}_max_likelihood_recovery')
plt.clf()
# Finally, we can check what fraction of amplitude samples are exactly on the spike.
spike_samples_0 = len(np.where(result.posterior['amplitude_0'] == 0.0)[0]) / len(result.posterior)
spike_samples_1 = len(np.where(result.posterior['amplitude_1'] == 0.0)[0]) / len(result.posterior)
spike_samples_2 = len(np.where(result.posterior['amplitude_2'] == 0.0)[0]) / len(result.posterior)
print(f"{spike_samples_0 * 100:.2f}% of amplitude_0 samples are exactly 0.0")
print(f"{spike_samples_1 * 100:.2f}% of amplitude_1 samples are exactly 0.0")
print(f"{spike_samples_2 * 100:.2f}% of amplitude_2 samples are exactly 0.0")
import numpy as np
import unittest
import bilby
from bilby.core.prior.slabspike import SlabSpikePrior
from bilby.core.prior.analytical import Uniform, PowerLaw, LogUniform, TruncatedGaussian, \
Beta, Gaussian, Cosine, Sine, HalfGaussian, LogNormal, Exponential, StudentT, Logistic, \
Cauchy, Gamma, ChiSquared
class TestSlabSpikePrior(unittest.TestCase):
def setUp(self):
self.minimum = 0
self.maximum = 1
self.spike_loc = 0.5
self.spike_height = 0.3
self.slab = bilby.core.prior.Prior(minimum=self.minimum, maximum=self.maximum)
self.prior = SlabSpikePrior(
slab=self.slab, spike_location=self.spike_loc, spike_height=self.spike_height)
def tearDown(self):
del self.minimum
del self.maximum
del self.spike_loc
del self.spike_height
del self.prior
del self.slab
def test_slab_fraction(self):
expected = 1 - self.spike_height
self.assertEqual(expected, self.prior.slab_fraction)
def test_spike_loc(self):
self.assertEqual(self.spike_loc, self.prior.spike_location)
def test_set_spike_loc_none(self):
self.prior.spike_location = None
self.assertEqual(self.prior.minimum, self.prior.spike_location)
def test_set_spike_loc_outside_domain(self):
with self.assertRaises(ValueError):
self.prior.spike_location = 1.5
def test_set_spike_loc_maximum(self):
self.prior.spike_location = self.maximum
self.assertEqual(self.maximum, self.prior.spike_location)
def test_class_name(self):
expected = "SlabSpikePrior"
self.assertEqual(expected, self.prior.__class__.__name__)
self.assertEqual(expected, self.prior.__class__.__qualname__)
def test_set_spike_height_outside_domain(self):
with self.assertRaises(ValueError):
self.prior.spike_height = 1.5
def test_set_spike_height_domain_edge(self):
self.prior.spike_height = 0
self.prior.spike_height = 1
class TestSlabSpikeClasses(unittest.TestCase):
def setUp(self):
self.minimum = 0.4
self.maximum = 2.4
self.spike_loc = 1.5
self.spike_height = 0.3
self.slabs = [
Uniform(minimum=self.minimum, maximum=self.maximum),
PowerLaw(minimum=self.minimum, maximum=self.maximum, alpha=2),
LogUniform(minimum=self.minimum, maximum=self.maximum),
TruncatedGaussian(minimum=self.minimum, maximum=self.maximum, mu=0, sigma=1),
Beta(minimum=self.minimum, maximum=self.maximum, alpha=1, beta=1),
Gaussian(mu=0, sigma=1),
Cosine(),
Sine(),
HalfGaussian(sigma=1),
LogNormal(mu=1, sigma=2),
Exponential(mu=2),
StudentT(df=2),
Logistic(mu=2, scale=1),
Cauchy(alpha=1, beta=2),
Gamma(k=1, theta=1.),
ChiSquared(nu=2)]
self.slab_spikes = [SlabSpikePrior(slab, spike_height=self.spike_height, spike_location=self.spike_loc)
for slab in self.slabs]
self.test_nodes_finite_support = np.linspace(self.minimum, self.maximum, 1000)
self.test_nodes_infinite_support = np.linspace(-10, 10, 1000)
self.test_nodes = [self.test_nodes_infinite_support
if np.isinf(slab.minimum) or np.isinf(slab.maximum)
else self.test_nodes_finite_support for slab in self.slabs]
def tearDown(self):
del self.minimum
del self.maximum
del self.spike_loc
del self.spike_height
del self.slabs
del self.test_nodes_finite_support
del self.test_nodes_infinite_support
def test_prob_on_slab(self):
for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes):
expected = slab.prob(test_nodes) * slab_spike.slab_fraction
actual = slab_spike.prob(test_nodes)
self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
def test_prob_on_spike(self):
for slab_spike in self.slab_spikes:
self.assertEqual(np.inf, slab_spike.prob(self.spike_loc))
def test_ln_prob_on_slab(self):
for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes):
expected = slab.ln_prob(test_nodes) + np.log(slab_spike.slab_fraction)
actual = slab_spike.ln_prob(test_nodes)
self.assertTrue(np.array_equal(expected, actual))
def test_ln_prob_on_spike(self):
for slab_spike in self.slab_spikes:
self.assertEqual(np.inf, slab_spike.ln_prob(self.spike_loc))
def test_inverse_cdf_below_spike_with_spike_at_minimum(self):
for slab in self.slabs:
slab_spike = SlabSpikePrior(slab=slab, spike_height=0.4, spike_location=slab.minimum)
self.assertEqual(0, slab_spike.inverse_cdf_below_spike)
def test_inverse_cdf_below_spike_with_spike_at_maximum(self):
for slab in self.slabs:
slab_spike = SlabSpikePrior(slab=slab, spike_height=0.4, spike_location=slab.maximum)
expected = 1 - slab_spike.spike_height
actual = slab_spike.inverse_cdf_below_spike
self.assertEqual(expected, actual)
def test_inverse_cdf_below_spike_arbitrary_position(self):
pass
def test_cdf_below_spike(self):
for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes):
test_nodes = test_nodes[np.where(test_nodes < self.spike_loc)]
expected = slab.cdf(test_nodes) * slab_spike.slab_fraction
actual = slab_spike.cdf(test_nodes)
self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
def test_cdf_at_spike(self):
for slab, slab_spike in zip(self.slabs, self.slab_spikes):
expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction
actual = slab_spike.cdf(self.spike_loc)
self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
def test_cdf_above_spike(self):
for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes):
test_nodes = test_nodes[np.where(test_nodes > self.spike_loc)]
expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height
actual = slab_spike.cdf(test_nodes)
self.assertTrue(np.array_equal(expected, actual))
def test_cdf_at_minimum(self):
for slab_spike in self.slab_spikes:
expected = 0
actual = slab_spike.cdf(slab_spike.minimum)
self.assertEqual(expected, actual)
def test_cdf_at_maximum(self):
for slab_spike in self.slab_spikes:
expected = 1
actual = slab_spike.cdf(slab_spike.maximum)
self.assertEqual(expected, actual)
def test_rescale_no_spike(self):
for slab in self.slabs:
slab_spike = SlabSpikePrior(slab=slab, spike_height=0, spike_location=slab.minimum)
vals = np.linspace(0, 1, 1000)
expected = slab.rescale(vals)
actual = slab_spike.rescale(vals)
print(slab)
self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
def test_rescale_below_spike(self):
for slab, slab_spike in zip(self.slabs, self.slab_spikes):
vals = np.linspace(0, slab_spike.inverse_cdf_below_spike, 1000)
expected = slab.rescale(vals / slab_spike.slab_fraction)
actual = slab_spike.rescale(vals)
self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
def test_rescale_at_spike(self):
for slab, slab_spike in zip(self.slabs, self.slab_spikes):
vals = np.linspace(slab_spike.inverse_cdf_below_spike,
slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000)
expected = np.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction)
actual = slab_spike.rescale(vals)
self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
def test_rescale_above_spike(self):
for slab, slab_spike in zip(self.slabs, self.slab_spikes):
vals = np.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000)
expected = np.ones(len(vals)) * slab.rescale(
(vals - self.spike_height) / slab_spike.slab_fraction)
actual = slab_spike.rescale(vals)
self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
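The placeholder test_inverse_cdf_below_spike_arbitrary_position above is left empty; a possible body (a sketch, not part of the merge request) would check the quantity computed in _find_inverse_cdf_fraction_before_spike directly:

def test_inverse_cdf_below_spike_arbitrary_position(self):
    # the unit-interval mass mapped below the spike equals slab.cdf(spike_loc) * slab_fraction
    for slab, slab_spike in zip(self.slabs, self.slab_spikes):
        expected = float(slab.cdf(self.spike_loc)) * slab_spike.slab_fraction
        self.assertAlmostEqual(expected, slab_spike.inverse_cdf_below_spike)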