Skip to content
Snippets Groups Projects
Commit 5d060938 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Add a categorical prior

parent 8abb56b4
No related branches found
No related tags found
1 merge request!982Add a categorical prior
......@@ -1423,3 +1423,97 @@ class FermiDirac(Prior):
idx = val >= self.minimum
lnp[idx] = norm - np.logaddexp((val[idx] / self.sigma) - self.r, 0.)
return lnp
class Categorical(Prior):
def __init__(self, ncategories, name=None, latex_label=None,
unit=None, boundary="periodic"):
""" An equal-weighted Categorical prior
Parameters:
-----------
ncategories: int
The number of available categories. The prior mass support is then
integers [0, ncategories - 1].
name: str
See superclass
latex_label: str
See superclass
unit: str
See superclass
"""
minimum = 0
# Small delta added to help with MCMC walking
maximum = ncategories - 1 + 1e-15
super(Categorical, self).__init__(
name=name, latex_label=latex_label, minimum=minimum,
maximum=maximum, unit=unit, boundary=boundary)
self.ncategories = ncategories
self.categories = np.arange(self.minimum, self.maximum)
self.p = 1 / self.ncategories
self.lnp = -np.log(self.ncategories)
def rescale(self, val):
"""
'Rescale' a sample from the unit line element to the categorical prior.
This maps to the inverse CDF. This has been analytically solved for this case.
Parameters
==========
val: Union[float, int, array_like]
Uniform probability
Returns
=======
Union[float, array_like]: Rescaled probability
"""
return np.round(val * self.maximum)
def prob(self, val):
"""Return the prior probability of val.
Parameters
==========
val: Union[float, int, array_like]
Returns
=======
float: Prior probability of val
"""
if isinstance(val, (float, int)):
if val in self.categories:
return self.p
else:
return 0
else:
val = np.atleast_1d(val)
probs = np.zeros_like(val, dtype=np.float64)
idxs = np.isin(val, self.categories)
probs[idxs] = self.p
return probs
def ln_prob(self, val):
"""Return the logarithmic prior probability of val
Parameters
==========
val: Union[float, int, array_like]
Returns
=======
float:
"""
if isinstance(val, (float, int)):
if val in self.categories:
return self.lnp
else:
return -np.inf
else:
val = np.atleast_1d(val)
probs = -np.inf * np.ones_like(val, dtype=np.float64)
idxs = np.isin(val, self.categories)
probs[idxs] = self.lnp
return probs
import unittest
import numpy as np
import bilby
class TestCategoricalPrior(unittest.TestCase):
def test_single_sample(self):
categorical_prior = bilby.core.prior.Categorical(3)
in_prior = True
for _ in range(1000):
s = categorical_prior.sample()
if s not in [0, 1, 2]:
in_prior = False
self.assertTrue(in_prior)
def test_array_sample(self):
categorical_prior = bilby.core.prior.Categorical(3)
N = 1000
s = categorical_prior.sample(N)
zeros = np.sum(s == 0)
ones = np.sum(s == 1)
twos = np.sum(s == 2)
self.assertEqual(zeros + ones + twos, N)
def test_single_probability(self):
N = 3
categorical_prior = bilby.core.prior.Categorical(N)
self.assertEqual(categorical_prior.prob(0), 1 / N)
self.assertEqual(categorical_prior.prob(1), 1 / N)
self.assertEqual(categorical_prior.prob(2), 1 / N)
self.assertEqual(categorical_prior.prob(0.5), 0)
def test_array_probability(self):
N = 3
categorical_prior = bilby.core.prior.Categorical(N)
self.assertTrue(
np.all(
categorical_prior.prob([0, 1, 1, 2, 3])
== np.array([1 / N, 1 / N, 1 / N, 1 / N, 0])
)
)
def test_single_lnprobability(self):
N = 3
categorical_prior = bilby.core.prior.Categorical(N)
self.assertEqual(categorical_prior.ln_prob(0), -np.log(N))
self.assertEqual(categorical_prior.ln_prob(1), -np.log(N))
self.assertEqual(categorical_prior.ln_prob(2), -np.log(N))
self.assertEqual(categorical_prior.ln_prob(0.5), -np.inf)
def test_array_lnprobability(self):
N = 3
categorical_prior = bilby.core.prior.Categorical(N)
self.assertTrue(
np.all(
categorical_prior.ln_prob([0, 1, 1, 2, 3])
== np.array([-np.log(N), -np.log(N), -np.log(N), -np.log(N), -np.inf])
)
)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment