Skip to content
Snippets Groups Projects
Commit 3da2d6df authored by Colm Talbot's avatar Colm Talbot
Browse files

Fix prior changing

parent cd317e38
No related branches found
No related tags found
No related merge requests found
......@@ -799,7 +799,8 @@ class Cosine(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
Prior.test_valid_for_rescaling(val)
return np.arcsin(-1 + val * 2)
norm = 1 / (np.sin(self.maximum) - np.sin(self.minimum))
return np.arcsin(val / norm + np.sin(self.minimum))
def prob(self, val):
"""Return the prior probability of val. Defined over [-pi/2, pi/2].
......@@ -844,7 +845,8 @@ class Sine(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
Prior.test_valid_for_rescaling(val)
return np.arccos(1 - val * 2)
norm = 1 / (np.cos(self.minimum) - np.cos(self.maximum))
return np.arccos(np.cos(self.minimum) - val / norm)
def prob(self, val):
"""Return the prior probability of val. Defined over [0, pi].
......@@ -1279,16 +1281,16 @@ class Beta(Prior):
See superclass
"""
Prior.__init__(self, minimum=minimum, maximum=maximum, name=name,
latex_label=latex_label, unit=unit)
if alpha <= 0. or beta <= 0.:
raise ValueError("alpha and beta must both be positive values")
self.alpha = alpha
self.beta = beta
self._loc = minimum
self._scale = maximum - minimum
self._alpha = alpha
self._beta = beta
self._minimum = minimum
self._maximum = maximum
Prior.__init__(self, minimum=minimum, maximum=maximum, name=name,
latex_label=latex_label, unit=unit)
self._set_dist()
def rescale(self, val):
"""
......@@ -1299,8 +1301,7 @@ class Beta(Prior):
Prior.test_valid_for_rescaling(val)
# use scipy distribution percentage point function (ppf)
return scipy.stats.beta.ppf(
val, self.alpha, self.beta, loc=self._loc, scale=self._scale)
return self._dist.ppf(val)
def prob(self, val):
"""Return the prior probability of val.
......@@ -1314,8 +1315,7 @@ class Beta(Prior):
float: Prior probability of val
"""
spdf = scipy.stats.beta.pdf(
val, self.alpha, self.beta, loc=self._loc, scale=self._scale)
spdf = self._dist.pdf(val)
if np.all(np.isfinite(spdf)):
return spdf
......@@ -1328,8 +1328,7 @@ class Beta(Prior):
return 0.
def ln_prob(self, val):
spdf = scipy.stats.beta.logpdf(
val, self.alpha, self.beta, loc=self._loc, scale=self._scale)
spdf = self._dist.logpdf(val)
if np.all(np.isfinite(spdf)):
return spdf
......@@ -1340,6 +1339,48 @@ class Beta(Prior):
else:
return -np.inf
def _set_dist(self):
"""Try/except to stop it falling over at instantiation"""
self._dist = scipy.stats.beta(
a=self.alpha, b=self.beta, loc=self.minimum,
scale=(self.maximum - self.minimum))
@property
def maximum(self):
return self._maximum
@maximum.setter
def maximum(self, maximum):
self._maximum = maximum
self._set_dist()
@property
def minimum(self):
return self._minimum
@minimum.setter
def minimum(self, minimum):
self._minimum = minimum
self._set_dist()
@property
def alpha(self):
return self._alpha
@alpha.setter
def alpha(self, alpha):
self._alpha = alpha
self._set_dist()
@property
def beta(self):
return self._beta
@beta.setter
def beta(self, beta):
self._beta = beta
self._set_dist()
class Logistic(Prior):
def __init__(self, mu, scale, name=None, latex_label=None, unit=None):
......@@ -1605,7 +1646,7 @@ class Interped(Prior):
Prior.__init__(self, name=name, latex_label=latex_label, unit=unit,
minimum=np.nanmax(np.array((min(xx), minimum))),
maximum=np.nanmin(np.array((max(xx), maximum))))
self.__update_instance()
self._update_instance()
def __eq__(self, other):
if self.__class__ != other.__class__:
......@@ -1656,7 +1697,7 @@ class Interped(Prior):
def minimum(self, minimum):
self._minimum = minimum
if '_maximum' in self.__dict__ and self._maximum < np.inf:
self.__update_instance()
self._update_instance()
@property
def maximum(self):
......@@ -1675,14 +1716,14 @@ class Interped(Prior):
def maximum(self, maximum):
self._maximum = maximum
if '_minimum' in self.__dict__ and self._minimum < np.inf:
self.__update_instance()
self._update_instance()
def __update_instance(self):
def _update_instance(self):
self.xx = np.linspace(self.minimum, self.maximum, len(self.xx))
self.yy = self.__all_interpolated(self.xx)
self.__initialize_attributes()
self._initialize_attributes()
def __initialize_attributes(self):
def _initialize_attributes(self):
if np.trapz(self.yy, self.xx) != 1:
logger.debug('Supplied PDF for {} is not normalised, normalising.'.format(self.name))
self.yy /= np.trapz(self.yy, self.xx)
......
from . import (calibration, conversion, cosmology, detector, likelihood, prior,
result, source, utils, waveform_generator)
from .waveform_generator import WaveformGenerator
from .likelihood import GravitationalWaveTransient
......
......@@ -84,8 +84,10 @@ class Cosmological(Interped):
self._minimum['redshift'] = cosmo.z_at_value(
cosmology.comoving_distance, minimum * self.unit)
self._minimum['luminosity_distance'] = self._minimum['redshift']
if getattr(self._maximum, self.name, np.inf) < np.inf:
self.__update_instance()
try:
self._update_instance()
except (AttributeError, KeyError):
pass
@property
def maximum(self):
......@@ -108,8 +110,10 @@ class Cosmological(Interped):
self._maximum['redshift'] = cosmo.z_at_value(
cosmology.comoving_distance, maximum * self.unit)
self._maximum['luminosity_distance'] = self._maximum['redshift']
if getattr(self._minimum, self.name, np.inf) < np.inf:
self.__update_instance()
try:
self._update_instance()
except (AttributeError, KeyError):
pass
def get_corresponding_prior(self, name=None, unit=None):
subclass_args = infer_args_from_method(self.__init__)
......
from __future__ import absolute_import
from __future__ import absolute_import, division
import bilby
import unittest
from mock import Mock
import numpy as np
import os
import copy
from collections import OrderedDict
......@@ -311,6 +310,8 @@ class TestPriorClasses(unittest.TestCase):
for prior in self.priors:
if isinstance(prior, bilby.core.prior.Interped):
continue # we cannot test this because of the numpy arrays
if isinstance(prior, bilby.core.prior.Beta):
continue # We cannot test this as it has a frozen scipy dist
elif isinstance(prior, bilby.gw.prior.UniformComovingVolume):
repr_prior_string = 'bilby.gw.prior.' + repr(prior)
else:
......@@ -318,6 +319,30 @@ class TestPriorClasses(unittest.TestCase):
repr_prior = eval(repr_prior_string)
self.assertEqual(prior, repr_prior)
def test_set_maximum_setting(self):
for prior in self.priors:
if isinstance(prior, (
bilby.core.prior.DeltaFunction, bilby.core.prior.Gaussian,
bilby.core.prior.HalfGaussian, bilby.core.prior.LogNormal,
bilby.core.prior.Exponential, bilby.core.prior.StudentT,
bilby.core.prior.Logistic, bilby.core.prior.Cauchy,
bilby.core.prior.Gamma)):
continue
prior.maximum = (prior.maximum + prior.minimum) / 2
self.assertTrue(max(prior.sample(10000)) < prior.maximum)
def test_set_minimum_setting(self):
for prior in self.priors:
if isinstance(prior, (
bilby.core.prior.DeltaFunction, bilby.core.prior.Gaussian,
bilby.core.prior.HalfGaussian, bilby.core.prior.LogNormal,
bilby.core.prior.Exponential, bilby.core.prior.StudentT,
bilby.core.prior.Logistic, bilby.core.prior.Cauchy,
bilby.core.prior.Gamma)):
continue
prior.minimum = (prior.maximum + prior.minimum) / 2
self.assertTrue(min(prior.sample(10000)) > prior.minimum)
class TestPriorDict(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment