From 6b14e6b1379d25c030023554f0079b4d57d0e041 Mon Sep 17 00:00:00 2001
From: Moritz Huebner <moritz.huebner@ligo.org>
Date: Tue, 20 Oct 2020 03:07:55 -0500
Subject: [PATCH] Generic Slab and spike priors for bilby

---
 bilby/core/prior/__init__.py                |   1 +
 bilby/core/prior/analytical.py              |   1 +
 bilby/core/prior/slabspike.py               | 169 ++++++++++++++++
 examples/core_examples/slabspike_example.py |  97 ++++++++++
 test/core/prior/slabspike_test.py           | 202 ++++++++++++++++++++
 5 files changed, 470 insertions(+)
 create mode 100644 bilby/core/prior/slabspike.py
 create mode 100644 examples/core_examples/slabspike_example.py
 create mode 100644 test/core/prior/slabspike_test.py

diff --git a/bilby/core/prior/__init__.py b/bilby/core/prior/__init__.py
index 253ad6c9c..fc795c3e1 100644
--- a/bilby/core/prior/__init__.py
+++ b/bilby/core/prior/__init__.py
@@ -4,3 +4,4 @@ from .conditional import *
 from .dict import *
 from .interpolated import *
 from .joint import *
+from .slabspike import *
diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py
index b575b3376..7a7fb2b6b 100644
--- a/bilby/core/prior/analytical.py
+++ b/bilby/core/prior/analytical.py
@@ -618,6 +618,7 @@ class TruncatedGaussian(Prior):
             / self.sigma / self.normalisation * self.is_in_prior_range(val)
 
     def cdf(self, val):
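+        # Work on an array so that the boundary assignment below also accepts scalar input.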
+        val = np.atleast_1d(val)
         _cdf = (erf((val - self.mu) / 2 ** 0.5 / self.sigma) - erf(
             (self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) / 2 / self.normalisation
         _cdf[val > self.maximum] = 1
diff --git a/bilby/core/prior/slabspike.py b/bilby/core/prior/slabspike.py
new file mode 100644
index 000000000..de7dc48ad
--- /dev/null
+++ b/bilby/core/prior/slabspike.py
@@ -0,0 +1,169 @@
+import numpy as np
+
+from bilby.core.prior.base import Prior
+from bilby.core.utils import logger
+
+
+class SlabSpikePrior(Prior):
+
+    def __init__(self, slab, spike_location=None, spike_height=0):
+        """'Slab-and-spike' prior, see e.g. https://arxiv.org/abs/1812.07259
+        This prior is composed of a `slab`, i.e. any common prior distribution,
+        and a Dirac spike at a fixed location. This can effectively be used
+        to emulate sampling in the number of dimensions (similar to reversible-
+        jump MCMC).
+
+        `SymmetricLogUniform` and `FermiDirac` are currently not supported.
+
+        Parameters
+        ----------
+        slab: Prior
+            Any instance of a bilby prior class. All general prior attributes
+            from the slab are copied into the SlabSpikePrior.
+            Note that this hasn't been tested for conditional priors.
+        spike_location: float, optional
+            Location of the Dirac spike. Must be between minimum and maximum
+            of the slab. Defaults to the minimum of the slab.
+        spike_height: float, optional
+            Relative weight of the spike compared to the slab. Must be
+            between 0 and 1. Defaults to 0, i.e. the prior is just the slab.
+
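+        Examples
+        --------
+        A minimal, illustrative construction (the `Uniform` slab and the prior
+        name are arbitrary choices)::
+
+            from bilby.core.prior import SlabSpikePrior, Uniform
+
+            slab = Uniform(minimum=-10, maximum=10, name='amplitude')
+            prior = SlabSpikePrior(slab=slab, spike_location=0, spike_height=0.3)
+            samples = prior.sample(10000)  # roughly 30% of samples are exactly 0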
+        """
+        self.slab = slab
+        super().__init__(name=self.slab.name, latex_label=self.slab.latex_label, unit=self.slab.unit,
+                         minimum=self.slab.minimum, maximum=self.slab.maximum,
+                         check_range_nonzero=self.slab.check_range_nonzero, boundary=self.slab.boundary)
+        self.spike_location = spike_location
+        self.spike_height = spike_height
+        try:
+            self.inverse_cdf_below_spike = self._find_inverse_cdf_fraction_before_spike()
+        except Exception as e:
+            logger.warning("Disregard the following warning when running tests:\n {}".format(e))
+
+    @property
+    def spike_location(self):
+        return self._spike_loc
+
+    @spike_location.setter
+    def spike_location(self, spike_loc):
+        if spike_loc is None:
+            spike_loc = self.minimum
+        if not self.minimum <= spike_loc <= self.maximum:
+            raise ValueError("Spike location {} not within prior domain ".format(spike_loc))
+        self._spike_loc = spike_loc
+
+    @property
+    def spike_height(self):
+        return self._spike_height
+
+    @spike_height.setter
+    def spike_height(self, spike_height):
+        if 0 <= spike_height <= 1:
+            self._spike_height = spike_height
+        else:
+            raise ValueError("Spike height must be between 0 and 1, but is {}".format(spike_height))
+
+    @property
+    def slab_fraction(self):
+        """ Relative prior weight of the slab. """
+        return 1 - self.spike_height
+
+    def _find_inverse_cdf_fraction_before_spike(self):
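+        # Fraction of the total prior mass below the spike, i.e. the point on the
+        # unit interval at which the spike segment begins in `rescale`.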
+        return float(self.slab.cdf(self.spike_location)) * self.slab_fraction
+
+    def rescale(self, val):
+        """
+        'Rescale' a sample from the unit line element to the prior.
+
+        Parameters
+        ----------
+        val: Union[float, int, array_like]
+            A random number between 0 and 1
+
+        Returns
+        -------
+        array_like: Prior value associated with the input value.
+        """
+        val = np.atleast_1d(val)
+
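+        # Split the unit-interval samples into three regions: below the spike (mapped
+        # through the rescaled slab), on the spike (mapped exactly to the spike
+        # location), and above the spike (shifted down by the spike height before
+        # being mapped through the slab).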
+        lower_indices = np.where(val < self.inverse_cdf_below_spike)[0]
+        intermediate_indices = np.where(np.logical_and(
+            self.inverse_cdf_below_spike <= val,
+            val <= self.inverse_cdf_below_spike + self.spike_height))[0]
+        higher_indices = np.where(val > self.inverse_cdf_below_spike + self.spike_height)[0]
+
+        res = np.zeros(len(val))
+        res[lower_indices] = self._contracted_rescale(val[lower_indices])
+        res[intermediate_indices] = self.spike_location
+        res[higher_indices] = self._contracted_rescale(val[higher_indices] - self.spike_height)
+        return res
+
+    def _contracted_rescale(self, val):
+        """
+        Contracted version of `rescale` that maps values onto the pure slab
+        part of the prior, i.e. excluding the spike.
+
+        Parameters
+        ----------
+        val: Union[float, int, array_like]
+            A random number between 0 and self.slab_fraction
+
+        Returns
+        -------
+        array_like: Prior value associated with the input value.
+        """
+        return self.slab.rescale(val / self.slab_fraction)
+
+    def prob(self, val):
+        """Return the prior probability of val.
+        Returns np.inf for the spike location
+
+        Parameters
+        ----------
+        val: Union[float, int, array_like]
+
+        Returns
+        -------
+        array_like: Prior probability of val
+        """
+        res = self.slab.prob(val) * self.slab_fraction
+        res = np.atleast_1d(res)
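+        # The spike carries finite probability mass at a single point, so the density there is infinite.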
+        res[np.where(val == self.spike_location)] = np.inf
+        return res
+
+    def ln_prob(self, val):
+        """Return the Log prior probability of val.
+        Returns np.inf for the spike location
+
+        Parameters
+        ----------
+        val: Union[float, int, array_like]
+
+        Returns
+        -------
+        array_like: Log prior probability of val
+        """
+        res = self.slab.ln_prob(val) + np.log(self.slab_fraction)
+        res = np.atleast_1d(res)
+        res[np.where(val == self.spike_location)] = np.inf
+        return res
+
+    def cdf(self, val):
+        """ Return the CDF of the prior.
+        This evaluates the slab CDF and adds a discrete step
+        at the spike location.
+
+        Parameters
+        ----------
+        val: Union[float, int, array_like]
+
+        Returns
+        -------
+        array_like: CDF value of val
+
+        """
+        res = self.slab.cdf(val) * self.slab_fraction
+        res = np.atleast_1d(res)
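+        # All values above the spike location pick up the spike weight as a discrete step.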
+        indices_above_spike = np.where(val > self.spike_location)[0]
+        res[indices_above_spike] += self.spike_height
+        return res
diff --git a/examples/core_examples/slabspike_example.py b/examples/core_examples/slabspike_example.py
new file mode 100644
index 000000000..d79851235
--- /dev/null
+++ b/examples/core_examples/slabspike_example.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python
+"""
+An example of how to use slab-and-spike priors in bilby.
+We generate data from the sum of two Gaussian distributions
+and try to fit it with up to three Gaussians.
+
+"""
+
+import bilby
+import numpy as np
+import matplotlib.pyplot as plt
+
+outdir = 'outdir'
+label = 'slabspike'
+bilby.utils.check_directory_exists_and_if_not_mkdir(outdir)
+
+
+# Here we define our model. We want to inject two Gaussians and recover with up to three.
+def gaussian(xs, amplitude, mu, sigma):
+    return amplitude / np.sqrt(2 * np.pi * sigma**2) * np.exp(-0.5 * (xs - mu)**2 / sigma**2)
+
+
+def triple_gaussian(xs, amplitude_0, amplitude_1, amplitude_2, mu_0, mu_1, mu_2, sigma_0, sigma_1, sigma_2, **kwargs):
+    return \
+        gaussian(xs, amplitude_0, mu_0, sigma_0) + \
+        gaussian(xs, amplitude_1, mu_1, sigma_1) + \
+        gaussian(xs, amplitude_2, mu_2, sigma_2)
+
+
+# Let's create our data set. We create 200 points on a grid.
+
+xs = np.linspace(-5, 5, 200)
+dx = xs[1] - xs[0]
+
+# Note that for our injection parameters we set the amplitude of the second component to 0.
+injection_params = dict(amplitude_0=-3, mu_0=-4, sigma_0=4,
+                        amplitude_1=0, mu_1=0, sigma_1=1,
+                        amplitude_2=4, mu_2=3, sigma_2=3)
+
+# We calculate the injected curve and add some Gaussian noise to the data points.
+sigma = 0.02
+p = bilby.core.prior.Gaussian(mu=0, sigma=sigma)
+ys = triple_gaussian(xs=xs, **injection_params) + p.sample(len(xs))
+
+plt.errorbar(xs, ys, yerr=sigma, fmt=".k", capsize=0, label='Injected data')
+plt.plot(xs, triple_gaussian(xs=xs, **injection_params), label='True signal')
+plt.legend()
+plt.savefig(f'{outdir}/{label}_injected_data')
+plt.clf()
+
+
+# Now we want to set up our priors.
+priors = bilby.core.prior.PriorDict()
+# For the slab-and-spike prior, we first need to define the 'slab' part, which is just a regular bilby prior.
+amplitude_slab_0 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_0', latex_label='$A_0$')
+amplitude_slab_1 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_1', latex_label='$A_1$')
+amplitude_slab_2 = bilby.core.prior.Uniform(minimum=-10, maximum=10, name='amplitude_2', latex_label='$A_2$')
+# We do the following to create the slab-and-spike priors. The spike height is somewhat arbitrary and can
+# be corrected in post-processing (see the sketch at the end of this script).
+priors['amplitude_0'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_0, spike_location=0, spike_height=0.1)
+priors['amplitude_1'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_1, spike_location=0, spike_height=0.1)
+priors['amplitude_2'] = bilby.core.prior.SlabSpikePrior(slab=amplitude_slab_2, spike_location=0, spike_height=0.1)
+# Our problem has a degeneracy in the ordering. In general, this problem is somewhat difficult to resolve properly.
+# See e.g. https://github.com/GregoryAshton/kookaburra/blob/master/src/priors.py#L72 for an implementation.
+# We resolve this by not letting the priors overlap in this case.
+priors['mu_0'] = bilby.core.prior.Uniform(minimum=-5, maximum=-2, name='mu_0', latex_label=r'$\mu_0$')
+priors['mu_1'] = bilby.core.prior.Uniform(minimum=-2, maximum=2, name='mu_1', latex_label=r'$\mu_1$')
+priors['mu_2'] = bilby.core.prior.Uniform(minimum=2, maximum=5, name='mu_2', latex_label=r'$\mu_2$')
+priors['sigma_0'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_0', latex_label=r'$\sigma_0$')
+priors['sigma_1'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_1', latex_label=r'$\sigma_1$')
+priors['sigma_2'] = bilby.core.prior.LogUniform(minimum=0.01, maximum=10, name='sigma_2', latex_label=r'$\sigma_2$')
+
+# Setting up the likelihood and running the samplers works the same as elsewhere.
+likelihood = bilby.core.likelihood.GaussianLikelihood(x=xs, y=ys, func=triple_gaussian, sigma=sigma)
+result = bilby.run_sampler(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
+                           sampler='dynesty', nlive=400)
+
+result.plot_corner(truths=injection_params)
+
+
+# Let's also plot the maximum likelihood fit along with the data.
+max_like_params = result.posterior.loc[result.posterior['log_likelihood'].idxmax()]
+plt.errorbar(xs, ys, yerr=sigma, fmt=".k", capsize=0, label='Injected data')
+plt.plot(xs, triple_gaussian(xs=xs, **injection_params), label='True signal')
+plt.plot(xs, triple_gaussian(xs=xs, **max_like_params), label='Max likelihood fit')
+plt.legend()
+plt.savefig(f'{outdir}/{label}_max_likelihood_recovery')
+plt.clf()
+
+# Finally, we can check what fraction of amplitude samples are exactly on the spike.
+spike_samples_0 = len(np.where(result.posterior['amplitude_0'] == 0.0)[0]) / len(result.posterior)
+spike_samples_1 = len(np.where(result.posterior['amplitude_1'] == 0.0)[0]) / len(result.posterior)
+spike_samples_2 = len(np.where(result.posterior['amplitude_2'] == 0.0)[0]) / len(result.posterior)
+print(f"{spike_samples_0 * 100:.2f}% of amplitude_0 samples are exactly 0.0")
+print(f"{spike_samples_1 * 100:.2f}% of amplitude_1 samples are exactly 0.0")
+print(f"{spike_samples_2 * 100:.2f}% of amplitude_2 samples are exactly 0.0")
diff --git a/test/core/prior/slabspike_test.py b/test/core/prior/slabspike_test.py
new file mode 100644
index 000000000..d2cdcc55a
--- /dev/null
+++ b/test/core/prior/slabspike_test.py
@@ -0,0 +1,202 @@
+import numpy as np
+import unittest
+
+import bilby
+from bilby.core.prior.slabspike import SlabSpikePrior
+from bilby.core.prior.analytical import Uniform, PowerLaw, LogUniform, TruncatedGaussian, \
+    Beta, Gaussian, Cosine, Sine, HalfGaussian, LogNormal, Exponential, StudentT, Logistic, \
+    Cauchy, Gamma, ChiSquared
+
+
+class TestSlabSpikePrior(unittest.TestCase):
+
+    def setUp(self):
+        self.minimum = 0
+        self.maximum = 1
+        self.spike_loc = 0.5
+        self.spike_height = 0.3
+        self.slab = bilby.core.prior.Prior(minimum=self.minimum, maximum=self.maximum)
+        self.prior = SlabSpikePrior(
+            slab=self.slab, spike_location=self.spike_loc, spike_height=self.spike_height)
+
+    def tearDown(self):
+        del self.minimum
+        del self.maximum
+        del self.spike_loc
+        del self.spike_height
+        del self.prior
+        del self.slab
+
+    def test_slab_fraction(self):
+        expected = 1 - self.spike_height
+        self.assertEqual(expected, self.prior.slab_fraction)
+
+    def test_spike_loc(self):
+        self.assertEqual(self.spike_loc, self.prior.spike_location)
+
+    def test_set_spike_loc_none(self):
+        self.prior.spike_location = None
+        self.assertEqual(self.prior.minimum, self.prior.spike_location)
+
+    def test_set_spike_loc_outside_domain(self):
+        with self.assertRaises(ValueError):
+            self.prior.spike_location = 1.5
+
+    def test_set_spike_loc_maximum(self):
+        self.prior.spike_location = self.maximum
+        self.assertEqual(self.maximum, self.prior.spike_location)
+
+    def test_class_name(self):
+        expected = "SlabSpikePrior"
+        self.assertEqual(expected, self.prior.__class__.__name__)
+        self.assertEqual(expected, self.prior.__class__.__qualname__)
+
+    def test_set_spike_height_outside_domain(self):
+        with self.assertRaises(ValueError):
+            self.prior.spike_height = 1.5
+
+    def test_set_spike_height_domain_edge(self):
+        self.prior.spike_height = 0
+        self.prior.spike_height = 1
+
+
+class TestSlabSpikeClasses(unittest.TestCase):
+
+    def setUp(self):
+        self.minimum = 0.4
+        self.maximum = 2.4
+        self.spike_loc = 1.5
+        self.spike_height = 0.3
+
+        self.slabs = [
+            Uniform(minimum=self.minimum, maximum=self.maximum),
+            PowerLaw(minimum=self.minimum, maximum=self.maximum, alpha=2),
+            LogUniform(minimum=self.minimum, maximum=self.maximum),
+            TruncatedGaussian(minimum=self.minimum, maximum=self.maximum, mu=0, sigma=1),
+            Beta(minimum=self.minimum, maximum=self.maximum, alpha=1, beta=1),
+            Gaussian(mu=0, sigma=1),
+            Cosine(),
+            Sine(),
+            HalfGaussian(sigma=1),
+            LogNormal(mu=1, sigma=2),
+            Exponential(mu=2),
+            StudentT(df=2),
+            Logistic(mu=2, scale=1),
+            Cauchy(alpha=1, beta=2),
+            Gamma(k=1, theta=1.),
+            ChiSquared(nu=2)]
+        self.slab_spikes = [SlabSpikePrior(slab, spike_height=self.spike_height, spike_location=self.spike_loc)
+                            for slab in self.slabs]
+        self.test_nodes_finite_support = np.linspace(self.minimum, self.maximum, 1000)
+        self.test_nodes_infinite_support = np.linspace(-10, 10, 1000)
+        self.test_nodes = [self.test_nodes_infinite_support
+                           if np.isinf(slab.minimum) or np.isinf(slab.maximum)
+                           else self.test_nodes_finite_support for slab in self.slabs]
+
+    def tearDown(self):
+        del self.minimum
+        del self.maximum
+        del self.spike_loc
+        del self.spike_height
+        del self.slabs
+        del self.test_nodes_finite_support
+        del self.test_nodes_infinite_support
+
+    def test_prob_on_slab(self):
+        for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes):
+            expected = slab.prob(test_nodes) * slab_spike.slab_fraction
+            actual = slab_spike.prob(test_nodes)
+            self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
+
+    def test_prob_on_spike(self):
+        for slab_spike in self.slab_spikes:
+            self.assertEqual(np.inf, slab_spike.prob(self.spike_loc))
+
+    def test_ln_prob_on_slab(self):
+        for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes):
+            expected = slab.ln_prob(test_nodes) + np.log(slab_spike.slab_fraction)
+            actual = slab_spike.ln_prob(test_nodes)
+            self.assertTrue(np.array_equal(expected, actual))
+
+    def test_ln_prob_on_spike(self):
+        for slab_spike in self.slab_spikes:
+            self.assertEqual(np.inf, slab_spike.ln_prob(self.spike_loc))
+
+    def test_inverse_cdf_below_spike_with_spike_at_minimum(self):
+        for slab in self.slabs:
+            slab_spike = SlabSpikePrior(slab=slab, spike_height=0.4, spike_location=slab.minimum)
+            self.assertEqual(0, slab_spike.inverse_cdf_below_spike)
+
+    def test_inverse_cdf_below_spike_with_spike_at_maximum(self):
+        for slab in self.slabs:
+            slab_spike = SlabSpikePrior(slab=slab, spike_height=0.4, spike_location=slab.maximum)
+            expected = 1 - slab_spike.spike_height
+            actual = slab_spike.inverse_cdf_below_spike
+            self.assertEqual(expected, actual)
+
+    def test_inverse_cdf_below_spike_arbitrary_position(self):
+        for slab, slab_spike in zip(self.slabs, self.slab_spikes):
+            expected = float(slab.cdf(self.spike_loc)) * slab_spike.slab_fraction
+            actual = slab_spike.inverse_cdf_below_spike
+            self.assertAlmostEqual(expected, actual)
+
+    def test_cdf_below_spike(self):
+        for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes):
+            test_nodes = test_nodes[np.where(test_nodes < self.spike_loc)]
+            expected = slab.cdf(test_nodes) * slab_spike.slab_fraction
+            actual = slab_spike.cdf(test_nodes)
+            self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
+
+    def test_cdf_at_spike(self):
+        for slab, slab_spike in zip(self.slabs, self.slab_spikes):
+            expected = slab.cdf(self.spike_loc) * slab_spike.slab_fraction
+            actual = slab_spike.cdf(self.spike_loc)
+            self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
+
+    def test_cdf_above_spike(self):
+        for slab, slab_spike, test_nodes in zip(self.slabs, self.slab_spikes, self.test_nodes):
+            test_nodes = test_nodes[np.where(test_nodes > self.spike_loc)]
+            expected = slab.cdf(test_nodes) * slab_spike.slab_fraction + self.spike_height
+            actual = slab_spike.cdf(test_nodes)
+            self.assertTrue(np.array_equal(expected, actual))
+
+    def test_cdf_at_minimum(self):
+        for slab_spike in self.slab_spikes:
+            expected = 0
+            actual = slab_spike.cdf(slab_spike.minimum)
+            self.assertEqual(expected, actual)
+
+    def test_cdf_at_maximum(self):
+        for slab_spike in self.slab_spikes:
+            expected = 1
+            actual = slab_spike.cdf(slab_spike.maximum)
+            self.assertEqual(expected, actual)
+
+    def test_rescale_no_spike(self):
+        for slab in self.slabs:
+            slab_spike = SlabSpikePrior(slab=slab, spike_height=0, spike_location=slab.minimum)
+            vals = np.linspace(0, 1, 1000)
+            expected = slab.rescale(vals)
+            actual = slab_spike.rescale(vals)
+            self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
+
+    def test_rescale_below_spike(self):
+        for slab, slab_spike in zip(self.slabs, self.slab_spikes):
+            vals = np.linspace(0, slab_spike.inverse_cdf_below_spike, 1000)
+            expected = slab.rescale(vals / slab_spike.slab_fraction)
+            actual = slab_spike.rescale(vals)
+            self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
+
+    def test_rescale_at_spike(self):
+        for slab, slab_spike in zip(self.slabs, self.slab_spikes):
+            vals = np.linspace(slab_spike.inverse_cdf_below_spike,
+                               slab_spike.inverse_cdf_below_spike + slab_spike.spike_height, 1000)
+            expected = np.ones(len(vals)) * slab.rescale(vals[0] / slab_spike.slab_fraction)
+            actual = slab_spike.rescale(vals)
+            self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
+
+    def test_rescale_above_spike(self):
+        for slab, slab_spike in zip(self.slabs, self.slab_spikes):
+            vals = np.linspace(slab_spike.inverse_cdf_below_spike + self.spike_height, 1, 1000)
+            expected = slab.rescale((vals - self.spike_height) / slab_spike.slab_fraction)
+            actual = slab_spike.rescale(vals)
+            self.assertTrue(np.allclose(expected, actual, rtol=1e-5))
-- 
GitLab