Skip to content
Snippets Groups Projects
Commit c3e8c366 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'fermi_dirac' into 'master'

Add a Fermi-Dirac prior distribution

See merge request !409
parents 81d7c253 fdb49a73
No related branches found
No related tags found
No related merge requests found
......@@ -1852,3 +1852,110 @@ class FromFile(Interped):
logger.warning("Can't load {}.".format(self.id))
logger.warning("Format should be:")
logger.warning(r"x\tp(x)")
class FermiDirac(Prior):
def __init__(self, sigma, mu=None, r=None, name=None, latex_label=None,
unit=None):
"""A Fermi-Dirac type prior, with a fixed lower boundary at zero
(see, e.g. Section 2.3.5 of [1]_). The probability distribution
is defined by Equation 22 of [1]_.
Parameters
----------
sigma: float (required)
The range over which the attenuation of the distribution happens
mu: float
The point at which the distribution falls to 50% of its maximum
value
r: float
A value giving mu/sigma. This can be used instead of specifying
mu.
name: str
See superclass
latex_label: str
See superclass
unit: str
See superclass
References
----------
.. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1
<https:arxiv.org/abs/1705.08978v1>`_, 2017.
"""
Prior.__init__(self, name=name, latex_label=latex_label, unit=unit, minimum=0.)
self.sigma = sigma
if mu is None and r is None:
raise ValueError("For the Fermi-Dirac prior either a 'mu' value or 'r' "
"value must be given.")
if r is None and mu is not None:
self.mu = mu
self.r = self.mu / self.sigma
else:
self.r = r
self.mu = self.sigma * self.r
if self.r <= 0. or self.sigma <= 0.:
raise ValueError("For the Fermi-Dirac prior the values of sigma and r "
"must be positive.")
def rescale(self, val):
"""
'Rescale' a sample from the unit line element to the appropriate Fermi-Dirac prior.
This maps to the inverse CDF. This has been analytically solved for this case,
see Equation 24 of [1]_.
References
----------
.. [1] M. Pitkin, M. Isi, J. Veitch & G. Woan, `arXiv:1705.08978v1
<https:arxiv.org/abs/1705.08978v1>`_, 2017.
"""
Prior.test_valid_for_rescaling(val)
inv = (-np.exp(-1. * self.r) + (1. + np.exp(self.r))**-val +
np.exp(-1. * self.r) * (1. + np.exp(self.r))**-val)
# if val is 1 this will cause inv to be negative (due to numerical
# issues), so return np.inf
if isinstance(val, (float, int)):
if inv < 0:
return np.inf
else:
return -self.sigma * np.log(inv)
else:
idx = inv >= 0.
tmpinv = np.inf * np.ones(len(val))
tmpinv[idx] = -self.sigma * np.log(inv[idx])
return tmpinv
def prob(self, val):
"""Return the prior probability of val.
Parameters
----------
val: float
Returns
-------
float: Prior probability of val
"""
return np.exp(self.ln_prob(val))
def ln_prob(self, val):
norm = -np.log(self.sigma * np.log(1. + np.exp(self.r)))
if isinstance(val, (float, int)):
if val < self.minimum:
return -np.inf
else:
return norm - np.logaddexp((val / self.sigma) - self.r, 0.)
else:
lnp = -np.inf * np.ones(len(val))
idx = val >= self.minimum
lnp[idx] = norm - np.logaddexp((val[idx] / self.sigma) - self.r, 0.)
return lnp
......@@ -153,6 +153,7 @@ class TestPriorClasses(unittest.TestCase):
bilby.core.prior.Lorentzian(name='test', unit='unit', alpha=0, beta=1),
bilby.core.prior.Gamma(name='test', unit='unit', k=1, theta=1),
bilby.core.prior.ChiSquared(name='test', unit='unit', nu=2),
bilby.core.prior.FermiDirac(name='test', unit='unit', sigma=1., r=10.),
bilby.gw.prior.AlignedSpin(name='test', unit='unit'),
]
......@@ -227,6 +228,13 @@ class TestPriorClasses(unittest.TestCase):
with self.assertRaises(ValueError):
bilby.core.prior.Beta(name='test', unit='unit', alpha=2.0, beta=-2.0),
def test_fermidirac_fail(self):
with self.assertRaises(ValueError):
bilby.core.prior.FermiDirac(name='test', unit='unit', sigma=1.)
with self.assertRaises(ValueError):
bilby.core.prior.FermiDirac(name='test', unit='unit', sigma=1., mu=-1)
def test_probability_in_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
......@@ -269,6 +277,8 @@ class TestPriorClasses(unittest.TestCase):
domain = np.linspace(0., 1e2, 5000)
elif isinstance(prior, bilby.core.prior.Logistic):
domain = np.linspace(-1e2, 1e2, 1000)
elif isinstance(prior, bilby.core.prior.FermiDirac):
domain = np.linspace(0., 1e2, 1000)
else:
domain = np.linspace(prior.minimum, prior.maximum, 1000)
self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3)
......@@ -326,7 +336,7 @@ class TestPriorClasses(unittest.TestCase):
bilby.core.prior.HalfGaussian, bilby.core.prior.LogNormal,
bilby.core.prior.Exponential, bilby.core.prior.StudentT,
bilby.core.prior.Logistic, bilby.core.prior.Cauchy,
bilby.core.prior.Gamma)):
bilby.core.prior.Gamma, bilby.core.prior.FermiDirac)):
continue
prior.maximum = (prior.maximum + prior.minimum) / 2
self.assertTrue(max(prior.sample(10000)) < prior.maximum)
......@@ -338,7 +348,7 @@ class TestPriorClasses(unittest.TestCase):
bilby.core.prior.HalfGaussian, bilby.core.prior.LogNormal,
bilby.core.prior.Exponential, bilby.core.prior.StudentT,
bilby.core.prior.Logistic, bilby.core.prior.Cauchy,
bilby.core.prior.Gamma)):
bilby.core.prior.Gamma, bilby.core.prior.FermiDirac)):
continue
prior.minimum = (prior.maximum + prior.minimum) / 2
self.assertTrue(min(prior.sample(10000)) > prior.minimum)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment