Skip to content
Snippets Groups Projects
Commit 182376ad authored by Moritz Huebner's avatar Moritz Huebner Committed by Gregory Ashton
Browse files

Resolve "Restructure prior module"

parent 0ca9c1d8
No related branches found
No related tags found
No related merge requests found
......@@ -24,6 +24,7 @@ stages:
- python -m pip install .
- python -c "import bilby"
- python -c "import bilby.core"
- python -c "import bilby.core.prior"
- python -c "import bilby.core.sampler"
- python -c "import bilby.gw"
- python -c "import bilby.gw.detector"
......
from .analytical import *
from .base import *
from .conditional import *
from .dict import *
from .interpolated import *
from .joint import *
This diff is collapsed.
from importlib import import_module
import json
import os
import re
import numpy as np
import scipy.stats
from scipy.integrate import cumtrapz
from scipy.interpolate import interp1d
from bilby.core.utils import infer_args_from_method, BilbyJsonEncoder, decode_bilby_json, logger
class Prior(object):
_default_latex_labels = {}
def __init__(self, name=None, latex_label=None, unit=None, minimum=-np.inf,
maximum=np.inf, boundary=None):
""" Implements a Prior object
Parameters
----------
name: str, optional
Name associated with prior.
latex_label: str, optional
Latex label associated with prior, used for plotting.
unit: str, optional
If given, a Latex string describing the units of the parameter.
minimum: float, optional
Minimum of the domain, default=-np.inf
maximum: float, optional
Maximum of the domain, default=np.inf
boundary: str, optional
The boundary condition of the prior, can be 'periodic', 'reflective'
Currently implemented in cpnest, dynesty and pymultinest.
"""
self.name = name
self.latex_label = latex_label
self.unit = unit
self.minimum = minimum
self.maximum = maximum
self.least_recently_sampled = None
self.boundary = boundary
self._is_fixed = False
def __call__(self):
"""Overrides the __call__ special method. Calls the sample method.
Returns
-------
float: The return value of the sample method.
"""
return self.sample()
def __eq__(self, other):
if self.__class__ != other.__class__:
return False
if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()):
return False
for key in self.__dict__:
if type(self.__dict__[key]) is np.ndarray:
if not np.array_equal(self.__dict__[key], other.__dict__[key]):
return False
elif isinstance(self.__dict__[key], type(scipy.stats.beta(1., 1.))):
continue
else:
if not self.__dict__[key] == other.__dict__[key]:
return False
return True
def sample(self, size=None):
"""Draw a sample from the prior
Parameters
----------
size: int or tuple of ints, optional
See numpy.random.uniform docs
Returns
-------
float: A random number between 0 and 1, rescaled to match the distribution of this Prior
"""
self.least_recently_sampled = self.rescale(np.random.uniform(0, 1, size))
return self.least_recently_sampled
def rescale(self, val):
"""
'Rescale' a sample from the unit line element to the prior.
This should be overwritten by each subclass.
Parameters
----------
val: Union[float, int, array_like]
A random number between 0 and 1
Returns
-------
None
"""
return None
def prob(self, val):
"""Return the prior probability of val, this should be overwritten
Parameters
----------
val: Union[float, int, array_like]
Returns
-------
np.nan
"""
return np.nan
def cdf(self, val):
""" Generic method to calculate CDF, can be overwritten in subclass """
if np.any(np.isinf([self.minimum, self.maximum])):
raise ValueError(
"Unable to use the generic CDF calculation for priors with"
"infinite support")
x = np.linspace(self.minimum, self.maximum, 1000)
pdf = self.prob(x)
cdf = cumtrapz(pdf, x, initial=0)
interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False,
fill_value=(0, 1))
return interp(val)
def ln_prob(self, val):
"""Return the prior ln probability of val, this should be overwritten
Parameters
----------
val: Union[float, int, array_like]
Returns
-------
np.nan
"""
return np.log(self.prob(val))
def is_in_prior_range(self, val):
"""Returns True if val is in the prior boundaries, zero otherwise
Parameters
----------
val: Union[float, int, array_like]
Returns
-------
np.nan
"""
return (val >= self.minimum) & (val <= self.maximum)
@staticmethod
def test_valid_for_rescaling(val):
"""Test if 0 < val < 1
Parameters
----------
val: Union[float, int, array_like]
Raises
-------
ValueError: If val is not between 0 and 1
"""
valarray = np.atleast_1d(val)
tests = (valarray < 0) + (valarray > 1)
if np.any(tests):
raise ValueError("Number to be rescaled should be in [0, 1]")
def __repr__(self):
"""Overrides the special method __repr__.
Returns a representation of this instance that resembles how it is instantiated.
Works correctly for all child classes
Returns
-------
str: A string representation of this instance
"""
prior_name = self.__class__.__name__
instantiation_dict = self.get_instantiation_dict()
args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
for key in instantiation_dict])
return "{}({})".format(prior_name, args)
@property
def _repr_dict(self):
"""
Get a dictionary containing the arguments needed to reproduce this object.
"""
property_names = {p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)}
subclass_args = infer_args_from_method(self.__init__)
dict_with_properties = self.__dict__.copy()
for key in property_names.intersection(subclass_args):
dict_with_properties[key] = getattr(self, key)
return {key: dict_with_properties[key] for key in subclass_args}
@property
def is_fixed(self):
"""
Returns True if the prior is fixed and should not be used in the sampler. Does this by checking if this instance
is an instance of DeltaFunction.
Returns
-------
bool: Whether it's fixed or not!
"""
return self._is_fixed
@property
def latex_label(self):
"""Latex label that can be used for plots.
Draws from a set of default labels if no label is given
Returns
-------
str: A latex representation for this prior
"""
return self.__latex_label
@latex_label.setter
def latex_label(self, latex_label=None):
if latex_label is None:
self.__latex_label = self.__default_latex_label
else:
self.__latex_label = latex_label
@property
def unit(self):
return self.__unit
@unit.setter
def unit(self, unit):
self.__unit = unit
@property
def latex_label_with_unit(self):
""" If a unit is specified, returns a string of the latex label and unit """
if self.unit is not None:
return "{} [{}]".format(self.latex_label, self.unit)
else:
return self.latex_label
@property
def minimum(self):
return self._minimum
@minimum.setter
def minimum(self, minimum):
self._minimum = minimum
@property
def maximum(self):
return self._maximum
@maximum.setter
def maximum(self, maximum):
self._maximum = maximum
def get_instantiation_dict(self):
subclass_args = infer_args_from_method(self.__init__)
property_names = [p for p in dir(self.__class__)
if isinstance(getattr(self.__class__, p), property)]
dict_with_properties = self.__dict__.copy()
for key in property_names:
dict_with_properties[key] = getattr(self, key)
instantiation_dict = dict()
for key in subclass_args:
instantiation_dict[key] = dict_with_properties[key]
return instantiation_dict
@property
def boundary(self):
return self._boundary
@boundary.setter
def boundary(self, boundary):
if boundary not in ['periodic', 'reflective', None]:
raise ValueError('{} is not a valid setting for prior boundaries'.format(boundary))
self._boundary = boundary
@property
def __default_latex_label(self):
if self.name in self._default_latex_labels.keys():
label = self._default_latex_labels[self.name]
else:
label = self.name
return label
def to_json(self):
return json.dumps(self, cls=BilbyJsonEncoder)
@classmethod
def from_json(cls, dct):
return decode_bilby_json(dct)
@classmethod
def from_repr(cls, string):
"""Generate the prior from it's __repr__"""
return cls._from_repr(string)
@classmethod
def _from_repr(cls, string):
subclass_args = infer_args_from_method(cls.__init__)
string = string.replace(' ', '')
kwargs = cls._split_repr(string)
for key in kwargs:
val = kwargs[key]
if key not in subclass_args and not hasattr(cls, "reference_params"):
raise AttributeError('Unknown argument {} for class {}'.format(
key, cls.__name__))
else:
kwargs[key] = cls._parse_argument_string(val)
if key in ["condition_func", "conversion_function"] and isinstance(kwargs[key], str):
if "." in kwargs[key]:
module = '.'.join(kwargs[key].split('.')[:-1])
name = kwargs[key].split('.')[-1]
else:
module = __name__
name = kwargs[key]
kwargs[key] = getattr(import_module(module), name)
return cls(**kwargs)
@classmethod
def _split_repr(cls, string):
subclass_args = infer_args_from_method(cls.__init__)
args = string.split(',')
remove = list()
for ii, key in enumerate(args):
if '(' in key:
jj = ii
while ')' not in args[jj]:
jj += 1
args[ii] = ','.join([args[ii], args[jj]]).strip()
remove.append(jj)
remove.reverse()
for ii in remove:
del args[ii]
kwargs = dict()
for ii, arg in enumerate(args):
if '=' not in arg:
logger.debug(
'Reading priors with non-keyword arguments is dangerous!')
key = subclass_args[ii]
val = arg
else:
split_arg = arg.split('=')
key = split_arg[0]
val = '='.join(split_arg[1:])
kwargs[key] = val
return kwargs
@classmethod
def _parse_argument_string(cls, val):
"""
Parse a string into the appropriate type for prior reading.
Four tests are applied in the following order:
- If the string is 'None':
`None` is returned.
- Else If the string is a raw string, e.g., r'foo':
A stripped version of the string is returned, e.g., foo.
- Else If the string contains ', e.g., 'foo':
A stripped version of the string is returned, e.g., foo.
- Else If the string contains an open parenthesis, (:
The string is interpreted as a call to instantiate another prior
class, Bilby will attempt to recursively construct that prior,
e.g., Uniform(minimum=0, maximum=1), my.custom.PriorClass(**kwargs).
- Else:
Try to evaluate the string using `eval`. Only built-in functions
and numpy methods can be used, e.g., np.pi / 2, 1.57.
Parameters
----------
val: str
The string version of the agument
Returns
-------
val: object
The parsed version of the argument.
Raises
------
TypeError:
If val cannot be parsed as described above.
"""
if val == 'None':
val = None
elif re.sub(r'\'.*\'', '', val) in ['r', 'u']:
val = val[2:-1]
elif "'" in val:
val = val.strip("'")
elif '(' in val:
other_cls = val.split('(')[0]
vals = '('.join(val.split('(')[1:])[:-1]
if "." in other_cls:
module = '.'.join(other_cls.split('.')[:-1])
other_cls = other_cls.split('.')[-1]
else:
module = __name__.replace('.' + os.path.basename(__file__).replace('.py', ''), '')
other_cls = getattr(import_module(module), other_cls)
val = other_cls.from_repr(vals)
else:
try:
val = eval(val, dict(), dict(np=np))
except NameError:
raise TypeError(
"Cannot evaluate prior, "
"failed to parse argument {}".format(val)
)
return val
class Constraint(Prior):
def __init__(self, minimum, maximum, name=None, latex_label=None,
unit=None):
super(Constraint, self).__init__(minimum=minimum, maximum=maximum, name=name,
latex_label=latex_label, unit=unit)
self._is_fixed = True
def prob(self, val):
return (val > self.minimum) & (val < self.maximum)
def ln_prob(self, val):
return np.log((val > self.minimum) & (val < self.maximum))
class PriorException(Exception):
""" General base class for all prior exceptions """
import numpy as np
from .base import Prior, PriorException
from bilby.core.prior.interpolated import Interped
from bilby.core.prior.analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \
SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \
LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac
from bilby.core.utils import infer_args_from_method, infer_parameters_from_function
def conditional_prior_factory(prior_class):
class ConditionalPrior(prior_class):
def __init__(self, condition_func, name=None, latex_label=None, unit=None,
boundary=None, **reference_params):
"""
Parameters
----------
condition_func: func
Functional form of the condition for this prior. The first function argument
has to be a dictionary for the `reference_params` (see below). The following
arguments are the required variables that are required before we can draw this
prior.
It needs to return a dictionary with the modified values for the
`reference_params` that are being used in the next draw.
For example if we have a Uniform prior for `x` depending on a different variable `y`
`p(x|y)` with the boundaries linearly depending on y, then this
could have the following form:
```
def condition_func(reference_params, y):
return dict(minimum=reference_params['minimum'] + y, maximum=reference_params['maximum'] + y)
```
name: str, optional
See superclass
latex_label: str, optional
See superclass
unit: str, optional
See superclass
boundary: str, optional
See superclass
reference_params:
Initial values for attributes such as `minimum`, `maximum`.
This differs on the `prior_class`, for example for the Gaussian
prior this is `mu` and `sigma`.
"""
if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__):
super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
unit=unit, boundary=boundary, **reference_params)
else:
super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
unit=unit, **reference_params)
self._required_variables = None
self.condition_func = condition_func
self._reference_params = reference_params
self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__)
self.__class__.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__)
def sample(self, size=None, **required_variables):
"""Draw a sample from the prior
Parameters
----------
size: int or tuple of ints, optional
See superclass
required_variables:
Any required variables that this prior depends on
Returns
-------
float: See superclass
"""
self.least_recently_sampled = self.rescale(np.random.uniform(0, 1, size), **required_variables)
return self.least_recently_sampled
def rescale(self, val, **required_variables):
"""
'Rescale' a sample from the unit line element to the prior.
Parameters
----------
val: Union[float, int, array_like]
See superclass
required_variables:
Any required variables that this prior depends on
"""
self.update_conditions(**required_variables)
return super(ConditionalPrior, self).rescale(val)
def prob(self, val, **required_variables):
"""Return the prior probability of val.
Parameters
----------
val: Union[float, int, array_like]
See superclass
required_variables:
Any required variables that this prior depends on
Returns
-------
float: Prior probability of val
"""
self.update_conditions(**required_variables)
return super(ConditionalPrior, self).prob(val)
def ln_prob(self, val, **required_variables):
self.update_conditions(**required_variables)
return super(ConditionalPrior, self).ln_prob(val)
def update_conditions(self, **required_variables):
"""
This method updates the conditional parameters (depending on the parent class
this could be e.g. `minimum`, `maximum`, `mu`, `sigma`, etc.) of this prior
class depending on the required variables it depends on.
If no variables are given, the most recently used conditional parameters are kept
Parameters
----------
required_variables:
Any required variables that this prior depends on. If none are given,
self.reference_params will be used.
"""
if sorted(list(required_variables)) == sorted(self.required_variables):
parameters = self.condition_func(self.reference_params, **required_variables)
for key, value in parameters.items():
setattr(self, key, value)
elif len(required_variables) == 0:
return
else:
raise IllegalRequiredVariablesException("Expected kwargs for {}. Got kwargs for {} instead."
.format(self.required_variables,
list(required_variables.keys())))
@property
def reference_params(self):
"""
Initial values for attributes such as `minimum`, `maximum`.
This depends on the `prior_class`, for example for the Gaussian
prior this is `mu` and `sigma`. This is read-only.
"""
return self._reference_params
@property
def condition_func(self):
return self._condition_func
@condition_func.setter
def condition_func(self, condition_func):
if condition_func is None:
self._condition_func = lambda reference_params: reference_params
else:
self._condition_func = condition_func
self._required_variables = infer_parameters_from_function(self.condition_func)
@property
def required_variables(self):
""" The required variables to pass into the condition function. """
return self._required_variables
def get_instantiation_dict(self):
instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict()
for key, value in self.reference_params.items():
instantiation_dict[key] = value
return instantiation_dict
def reset_to_reference_parameters(self):
"""
Reset the object attributes to match the original reference parameters
"""
for key, value in self.reference_params.items():
setattr(self, key, value)
def __repr__(self):
"""Overrides the special method __repr__.
Returns a representation of this instance that resembles how it is instantiated.
Works correctly for all child classes
Returns
-------
str: A string representation of this instance
"""
prior_name = self.__class__.__name__
instantiation_dict = self.get_instantiation_dict()
instantiation_dict["condition_func"] = ".".join([
instantiation_dict["condition_func"].__module__,
instantiation_dict["condition_func"].__name__
])
args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
for key in instantiation_dict])
return "{}({})".format(prior_name, args)
return ConditionalPrior
ConditionalBasePrior = conditional_prior_factory(Prior) # Only for testing purposes
ConditionalUniform = conditional_prior_factory(Uniform)
ConditionalDeltaFunction = conditional_prior_factory(DeltaFunction)
ConditionalPowerLaw = conditional_prior_factory(PowerLaw)
ConditionalGaussian = conditional_prior_factory(Gaussian)
ConditionalLogUniform = conditional_prior_factory(LogUniform)
ConditionalSymmetricLogUniform = conditional_prior_factory(SymmetricLogUniform)
ConditionalCosine = conditional_prior_factory(Cosine)
ConditionalSine = conditional_prior_factory(Sine)
ConditionalTruncatedGaussian = conditional_prior_factory(TruncatedGaussian)
ConditionalHalfGaussian = conditional_prior_factory(HalfGaussian)
ConditionalLogNormal = conditional_prior_factory(LogNormal)
ConditionalExponential = conditional_prior_factory(Exponential)
ConditionalStudentT = conditional_prior_factory(StudentT)
ConditionalBeta = conditional_prior_factory(Beta)
ConditionalLogistic = conditional_prior_factory(Logistic)
ConditionalCauchy = conditional_prior_factory(Cauchy)
ConditionalGamma = conditional_prior_factory(Gamma)
ConditionalChiSquared = conditional_prior_factory(ChiSquared)
ConditionalFermiDirac = conditional_prior_factory(FermiDirac)
ConditionalInterped = conditional_prior_factory(Interped)
class ConditionalPriorException(PriorException):
""" General base class for all conditional prior exceptions """
class IllegalRequiredVariablesException(ConditionalPriorException):
""" Exception class for exceptions relating to handling the required variables. """
This diff is collapsed.
import numpy as np
from scipy.integrate import cumtrapz
from scipy.interpolate import interp1d
from .base import Prior
from bilby.core.utils import logger
class Interped(Prior):
def __init__(self, xx, yy, minimum=np.nan, maximum=np.nan, name=None,
latex_label=None, unit=None, boundary=None):
"""Creates an interpolated prior function from arrays of xx and yy=p(xx)
Parameters
----------
xx: array_like
x values for the to be interpolated prior function
yy: array_like
p(xx) values for the to be interpolated prior function
minimum: float
See superclass
maximum: float
See superclass
name: str
See superclass
latex_label: str
See superclass
unit: str
See superclass
boundary: str
See superclass
Attributes
----------
probability_density: scipy.interpolate.interp1d
Interpolated prior probability distribution
cumulative_distribution: scipy.interpolate.interp1d
Interpolated cumulative prior probability distribution
inverse_cumulative_distribution: scipy.interpolate.interp1d
Inverted cumulative prior probability distribution
YY: array_like
Cumulative prior probability distribution
"""
self.xx = xx
self._yy = yy
self.YY = None
self.probability_density = None
self.cumulative_distribution = None
self.inverse_cumulative_distribution = None
self.__all_interpolated = interp1d(x=xx, y=yy, bounds_error=False, fill_value=0)
minimum = float(np.nanmax(np.array((min(xx), minimum))))
maximum = float(np.nanmin(np.array((max(xx), maximum))))
super(Interped, self).__init__(name=name, latex_label=latex_label, unit=unit,
minimum=minimum, maximum=maximum, boundary=boundary)
self._update_instance()
def __eq__(self, other):
if self.__class__ != other.__class__:
return False
if np.array_equal(self.xx, other.xx) and np.array_equal(self.yy, other.yy):
return True
return False
def prob(self, val):
"""Return the prior probability of val.
Parameters
----------
val: Union[float, int, array_like]
Returns
-------
Union[float, array_like]: Prior probability of val
"""
return self.probability_density(val)
def cdf(self, val):
return self.cumulative_distribution(val)
def rescale(self, val):
"""
'Rescale' a sample from the unit line element to the prior.
This maps to the inverse CDF. This is done using interpolation.
"""
self.test_valid_for_rescaling(val)
rescaled = self.inverse_cumulative_distribution(val)
if rescaled.shape == ():
rescaled = float(rescaled)
return rescaled
@property
def minimum(self):
"""Return minimum of the prior distribution.
Updates the prior distribution if minimum is set to a different value.
Returns
-------
float: Minimum of the prior distribution
"""
return self._minimum
@minimum.setter
def minimum(self, minimum):
self._minimum = minimum
if '_maximum' in self.__dict__ and self._maximum < np.inf:
self._update_instance()
@property
def maximum(self):
"""Return maximum of the prior distribution.
Updates the prior distribution if maximum is set to a different value.
Returns
-------
float: Maximum of the prior distribution
"""
return self._maximum
@maximum.setter
def maximum(self, maximum):
self._maximum = maximum
if '_minimum' in self.__dict__ and self._minimum < np.inf:
self._update_instance()
@property
def yy(self):
"""Return p(xx) values of the interpolated prior function.
Updates the prior distribution if it is changed
Returns
-------
array_like: p(xx) values
"""
return self._yy
@yy.setter
def yy(self, yy):
self._yy = yy
self.__all_interpolated = interp1d(x=self.xx, y=self._yy, bounds_error=False, fill_value=0)
self._update_instance()
def _update_instance(self):
self.xx = np.linspace(self.minimum, self.maximum, len(self.xx))
self._yy = self.__all_interpolated(self.xx)
self._initialize_attributes()
def _initialize_attributes(self):
if np.trapz(self._yy, self.xx) != 1:
logger.debug('Supplied PDF for {} is not normalised, normalising.'.format(self.name))
self._yy /= np.trapz(self._yy, self.xx)
self.YY = cumtrapz(self._yy, self.xx, initial=0)
# Need last element of cumulative distribution to be exactly one.
self.YY[-1] = 1
self.probability_density = interp1d(x=self.xx, y=self._yy, bounds_error=False, fill_value=0)
self.cumulative_distribution = interp1d(x=self.xx, y=self.YY, bounds_error=False, fill_value=(0, 1))
self.inverse_cumulative_distribution = interp1d(x=self.YY, y=self.xx, bounds_error=True)
class FromFile(Interped):
def __init__(self, file_name, minimum=None, maximum=None, name=None,
latex_label=None, unit=None, boundary=None):
"""Creates an interpolated prior function from arrays of xx and yy=p(xx) extracted from a file
Parameters
----------
file_name: str
Name of the file containing the xx and yy arrays
minimum: float
See superclass
maximum: float
See superclass
name: str
See superclass
latex_label: str
See superclass
unit: str
See superclass
boundary: str
See superclass
"""
try:
self.id = file_name
xx, yy = np.genfromtxt(self.id).T
super(FromFile, self).__init__(xx=xx, yy=yy, minimum=minimum,
maximum=maximum, name=name, latex_label=latex_label,
unit=unit, boundary=boundary)
except IOError:
logger.warning("Can't load {}.".format(self.id))
logger.warning("Format should be:")
logger.warning(r"x\tp(x)")
This diff is collapsed.
......@@ -70,7 +70,7 @@ setup(name='bilby',
author_email='paul.lasky@monash.edu',
license="MIT",
version=VERSION,
packages=['bilby', 'bilby.core', 'bilby.core.sampler',
packages=['bilby', 'bilby.core', 'bilby.core.prior', 'bilby.core.sampler',
'bilby.gw', 'bilby.gw.detector', 'bilby.gw.sampler',
'bilby.hyper', 'cli_bilby'],
package_dir={'bilby': 'bilby', 'cli_bilby': 'cli_bilby'},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment