diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 63f38004cfd5d3ec174b88008c7d97c54b30368a..c67d0c1cae95260edb3b129d6d079f8bdfd5e647 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -16,7 +16,7 @@ from scipy.special import erf, erfinv, xlogy, log1p,\ from matplotlib.cbook import flatten # Keep import bilby statement, it is necessary for some eval() statements -from .utils import BilbyJsonEncoder, decode_bilby_json +from .utils import BilbyJsonEncoder, decode_bilby_json, infer_parameters_from_function from .utils import ( check_directory_exists_and_if_not_mkdir, infer_args_from_method, logger @@ -173,7 +173,7 @@ class PriorDict(dict): else: module = __name__ cls = getattr(import_module(module), cls, cls) - if key.lower() == "conversion_function": + if key.lower() in ["conversion_function", "condition_func"]: setattr(self, key, cls) elif (cls.__name__ in ['MultivariateGaussianDist', 'MultivariateNormalDist']): @@ -331,11 +331,10 @@ class PriorDict(dict): self.convert_floats_to_delta_functions() samples = dict() for key in keys: - if isinstance(self[key], Prior): - if isinstance(self[key], Constraint): - continue - else: - samples[key] = self[key].sample(size=size) + if isinstance(self[key], Constraint): + continue + elif isinstance(self[key], Prior): + samples[key] = self[key].sample(size=size) else: logger.debug('{} not a known prior.'.format(key)) return samples @@ -488,6 +487,201 @@ class PriorSet(PriorDict): super(PriorSet, self).__init__(dictionary, filename) +class ConditionalPriorDict(PriorDict): + + def __init__(self, dictionary=None, filename=None, conversion_function=None): + """ + + Parameters + ---------- + dictionary: dict + See parent class + filename: str + See parent class + """ + self._conditional_keys = [] + self._unconditional_keys = [] + self._rescale_keys = [] + self._rescale_indexes = [] + self._least_recently_rescaled_keys = [] + super(ConditionalPriorDict, self).__init__( + dictionary=dictionary, filename=filename, + conversion_function=conversion_function + ) + self._resolved = False + self._resolve_conditions() + + def _resolve_conditions(self): + """ + Resolves how priors depend on each other and automatically + sorts them into the right order. + 1. All unconditional priors are put in front in arbitrary order + 2. We loop through all the unsorted conditional priors to find + which one can go next + 3. We repeat step 2 len(self) number of times to make sure that + all conditional priors will be sorted in order + 4. We set the `self._resolved` flag to True if all conditional + priors were added in the right order + """ + self._unconditional_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')] + conditional_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')] + self._conditional_keys = [] + for _ in range(len(self)): + for key in conditional_keys_unsorted[:]: + if self._check_conditions_resolved(key, self.sorted_keys): + self._conditional_keys.append(key) + conditional_keys_unsorted.remove(key) + + self._resolved = True + if len(conditional_keys_unsorted) != 0: + self._resolved = False + + def _check_conditions_resolved(self, key, sampled_keys): + """ Checks if all required variables have already been sampled so we can sample this key """ + conditions_resolved = True + for k in self[key].required_variables: + if k not in sampled_keys: + conditions_resolved = False + return conditions_resolved + + def sample_subset(self, keys=iter([]), size=None): + self.convert_floats_to_delta_functions() + subset_dict = ConditionalPriorDict({key: self[key] for key in keys}) + if not subset_dict._resolved: + raise IllegalConditionsException("The current set of priors contains unresolvable conditions.") + samples = dict() + for key in subset_dict.sorted_keys: + if isinstance(self[key], Constraint): + continue + elif isinstance(self[key], Prior): + try: + samples[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key)) + except ValueError: + # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw) + # If that is the case, we sample each sample individually. + required_variables = subset_dict.get_required_variables(key) + samples[key] = np.zeros(size) + for i in range(size): + rvars = {key: value[i] for key, value in required_variables.items()} + samples[key][i] = subset_dict[key].sample(**rvars) + else: + logger.debug('{} not a known prior.'.format(key)) + return samples + + def get_required_variables(self, key): + """ Returns the required variables to sample a given conditional key. + + Parameters + ---------- + key : str + Name of the key that we want to know the required variables for + + Returns + ---------- + dict: key/value pairs of the required variables + """ + return {k: self[k].least_recently_sampled for k in getattr(self[key], 'required_variables', [])} + + def prob(self, sample, **kwargs): + """ + + Parameters + ---------- + sample: dict + Dictionary of the samples of which we want to have the probability of + kwargs: + The keyword arguments are passed directly to `np.product` + + Returns + ------- + float: Joint probability of all individual sample probabilities + + """ + self._check_resolved() + for key, value in sample.items(): + self[key].least_recently_sampled = value + res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample] + return np.product(res, **kwargs) + + def ln_prob(self, sample, axis=None): + """ + + Parameters + ---------- + sample: dict + Dictionary of the samples of which we want to have the log probability of + axis: Union[None, int] + Axis along which the summation is performed + + Returns + ------- + float: Joint log probability of all the individual sample probabilities + + """ + self._check_resolved() + for key, value in sample.items(): + self[key].least_recently_sampled = value + res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample] + return np.sum(res, axis=axis) + + def rescale(self, keys, theta): + """Rescale samples from unit cube to prior + + Parameters + ---------- + keys: list + List of prior keys to be rescaled + theta: list + List of randomly drawn values on a unit cube associated with the prior keys + + Returns + ------- + list: List of floats containing the rescaled sample + """ + self._check_resolved() + self._update_rescale_keys(keys) + result = dict() + for key, index in zip(self._rescale_keys, self._rescale_indexes): + required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])} + result[key] = self[key].rescale(theta[index], **required_variables) + return [result[key] for key in keys] + + def _update_rescale_keys(self, keys): + if not keys == self._least_recently_rescaled_keys: + self._set_rescale_keys_and_indexes(keys) + self._least_recently_rescaled_keys = keys + + def _set_rescale_keys_and_indexes(self, keys): + unconditional_keys, unconditional_idxs, _ = np.intersect1d(keys, self.unconditional_keys, return_indices=True) + conditional_keys, conditional_idxs, _ = np.intersect1d(keys, self.conditional_keys, return_indices=True) + self._rescale_keys = np.append(unconditional_keys, conditional_keys) + self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs) + + def _check_resolved(self): + if not self._resolved: + raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") + + @property + def conditional_keys(self): + return self._conditional_keys + + @property + def unconditional_keys(self): + return self._unconditional_keys + + @property + def sorted_keys(self): + return self.unconditional_keys + self.conditional_keys + + def __setitem__(self, key, value): + super(ConditionalPriorDict, self).__setitem__(key, value) + self._resolve_conditions() + + def __delitem__(self, key): + super(ConditionalPriorDict, self).__delitem__(key) + self._resolve_conditions() + + def create_default_prior(name, default_priors_file=None): """Make a default prior for a parameter with a known name. @@ -548,6 +742,7 @@ class Prior(object): self.unit = unit self.minimum = minimum self.maximum = maximum + self.least_recently_sampled = None self.boundary = boundary def __call__(self): @@ -588,7 +783,8 @@ class Prior(object): float: A random number between 0 and 1, rescaled to match the distribution of this Prior """ - return self.rescale(np.random.uniform(0, 1, size)) + self.least_recently_sampled = self.rescale(np.random.uniform(0, 1, size)) + return self.least_recently_sampled def rescale(self, val): """ @@ -692,7 +888,7 @@ class Prior(object): """ prior_name = self.__class__.__name__ - instantiation_dict = self._get_instantiation_dict() + instantiation_dict = self.get_instantiation_dict() args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) for key in instantiation_dict]) return "{}({})".format(prior_name, args) @@ -775,7 +971,7 @@ class Prior(object): def maximum(self, maximum): self._maximum = maximum - def _get_instantiation_dict(self): + def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) property_names = [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)] @@ -825,11 +1021,19 @@ class Prior(object): kwargs = cls._split_repr(string) for key in kwargs: val = kwargs[key] - if key not in subclass_args: + if key not in subclass_args and not hasattr(cls, "reference_params"): raise AttributeError('Unknown argument {} for class {}'.format( key, cls.__name__)) else: kwargs[key] = cls._parse_argument_string(val) + if key in ["condition_func", "conversion_function"] and isinstance(kwargs[key], str): + if "." in kwargs[key]: + module = '.'.join(kwargs[key].split('.')[:-1]) + name = kwargs[key].split('.')[-1] + else: + module = __name__ + name = kwargs[key] + kwargs[key] = getattr(import_module(module), name) return cls(**kwargs) @classmethod @@ -2982,7 +3186,7 @@ class MultivariateGaussianDist(object): return np.exp(self.ln_prob(samp)) - def _get_instantiation_dict(self): + def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) property_names = [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)] @@ -3013,7 +3217,7 @@ class MultivariateGaussianDist(object): """ dist_name = self.__class__.__name__ - instantiation_dict = self._get_instantiation_dict() + instantiation_dict = self.get_instantiation_dict() args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) for key in instantiation_dict]) return "{}({})".format(dist_name, args) @@ -3146,7 +3350,7 @@ class MultivariateGaussian(Prior): if len(self.mvg.sampled_parameters) == len(self.mvg): # reset samples self.mvg.reset_sampled() - + self.least_recently_sampled = sample return sample def prob(self, val): @@ -3237,5 +3441,245 @@ class MultivariateGaussian(Prior): class MultivariateNormal(MultivariateGaussian): - """ A synonym for the :class:`bilby.core.prior.MultivariateGaussian` - prior distribution.""" + """A synonym for the :class:`bilby.core.prior.MultivariateGaussian` + prior distribution.""" + + +def conditional_prior_factory(prior_class): + class ConditionalPrior(prior_class): + def __init__(self, condition_func, name=None, latex_label=None, unit=None, + boundary=None, **reference_params): + """ + + Parameters + ---------- + condition_func: func + Functional form of the condition for this prior. The first function argument + has to be a dictionary for the `reference_params` (see below). The following + arguments are the required variables that are required before we can draw this + prior. + It needs to return a dictionary with the modified values for the + `reference_params` that are being used in the next draw. + For example if we have a Uniform prior for `x` depending on a different variable `y` + `p(x|y)` with the boundaries linearly depending on y, then this + could have the following form: + + ``` + def condition_func(reference_params, y): + return dict(minimum=reference_params['minimum'] + y, maximum=reference_params['maximum'] + y) + ``` + name: str, optional + See superclass + latex_label: str, optional + See superclass + unit: str, optional + See superclass + boundary: str, optional + See superclass + reference_params: + Initial values for attributes such as `minimum`, `maximum`. + This differs on the `prior_class`, for example for the Gaussian + prior this is `mu` and `sigma`. + """ + if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__): + super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, + unit=unit, boundary=boundary, **reference_params) + else: + super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label, + unit=unit, **reference_params) + + self._required_variables = None + self.condition_func = condition_func + self._reference_params = reference_params + self.__class__.__name__ = 'Conditional{}'.format(prior_class.__name__) + + def sample(self, size=None, **required_variables): + """Draw a sample from the prior + + Parameters + ---------- + size: int or tuple of ints, optional + See superclass + required_variables: + Any required variables that this prior depends on + + Returns + ------- + float: See superclass + + """ + self.least_recently_sampled = self.rescale(np.random.uniform(0, 1, size), **required_variables) + return self.least_recently_sampled + + def rescale(self, val, **required_variables): + """ + 'Rescale' a sample from the unit line element to the prior. + + Parameters + ---------- + val: Union[float, int, array_like] + See superclass + required_variables: + Any required variables that this prior depends on + + + """ + self.update_conditions(**required_variables) + return super(ConditionalPrior, self).rescale(val) + + def prob(self, val, **required_variables): + """Return the prior probability of val. + + Parameters + ---------- + val: Union[float, int, array_like] + See superclass + required_variables: + Any required variables that this prior depends on + + + Returns + ------- + float: Prior probability of val + """ + self.update_conditions(**required_variables) + return super(ConditionalPrior, self).prob(val) + + def ln_prob(self, val, **required_variables): + self.update_conditions(**required_variables) + return super(ConditionalPrior, self).ln_prob(val) + + def update_conditions(self, **required_variables): + """ + This method updates the conditional parameters (depending on the parent class + this could be e.g. `minimum`, `maximum`, `mu`, `sigma`, etc.) of this prior + class depending on the required variables it depends on. + + If no variables are given, the most recently used conditional parameters are kept + + Parameters + ---------- + required_variables: + Any required variables that this prior depends on. If none are given, + self.reference_params will be used. + + """ + if sorted(list(required_variables)) == sorted(self.required_variables): + parameters = self.condition_func(self.reference_params, **required_variables) + for key, value in parameters.items(): + setattr(self, key, value) + elif len(required_variables) == 0: + return + else: + raise IllegalRequiredVariablesException("Expected kwargs for {}. Got kwargs for {} instead." + .format(self.required_variables, + list(required_variables.keys()))) + + @property + def reference_params(self): + """ + Initial values for attributes such as `minimum`, `maximum`. + This depends on the `prior_class`, for example for the Gaussian + prior this is `mu` and `sigma`. This is read-only. + """ + return self._reference_params + + @property + def condition_func(self): + return self._condition_func + + @condition_func.setter + def condition_func(self, condition_func): + if condition_func is None: + self._condition_func = lambda reference_params: reference_params + else: + self._condition_func = condition_func + self._required_variables = infer_parameters_from_function(self.condition_func) + + @property + def required_variables(self): + """ The required variables to pass into the condition function. """ + return self._required_variables + + def get_instantiation_dict(self): + instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict() + for key, value in self.reference_params.items(): + instantiation_dict[key] = value + return instantiation_dict + + def reset_to_reference_parameters(self): + """ + Reset the object attributes to match the original reference parameters + """ + for key, value in self.reference_params.items(): + setattr(self, key, value) + + def __repr__(self): + """Overrides the special method __repr__. + + Returns a representation of this instance that resembles how it is instantiated. + Works correctly for all child classes + + Returns + ------- + str: A string representation of this instance + + """ + prior_name = self.__class__.__name__ + instantiation_dict = self.get_instantiation_dict() + instantiation_dict["condition_func"] = ".".join([ + instantiation_dict["condition_func"].__module__, + instantiation_dict["condition_func"].__name__ + ]) + args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) + for key in instantiation_dict]) + return "{}({})".format(prior_name, args) + + return ConditionalPrior + + +ConditionalBasePrior = conditional_prior_factory(Prior) # Only for testing purposes +ConditionalUniform = conditional_prior_factory(Uniform) +ConditionalDeltaFunction = conditional_prior_factory(DeltaFunction) +ConditionalPowerLaw = conditional_prior_factory(PowerLaw) +ConditionalGaussian = conditional_prior_factory(Gaussian) +ConditionalLogUniform = conditional_prior_factory(LogUniform) +ConditionalSymmetricLogUniform = conditional_prior_factory(SymmetricLogUniform) +ConditionalCosine = conditional_prior_factory(Cosine) +ConditionalSine = conditional_prior_factory(Sine) +ConditionalTruncatedGaussian = conditional_prior_factory(TruncatedGaussian) +ConditionalHalfGaussian = conditional_prior_factory(HalfGaussian) +ConditionalLogNormal = conditional_prior_factory(LogNormal) +ConditionalExponential = conditional_prior_factory(Exponential) +ConditionalStudentT = conditional_prior_factory(StudentT) +ConditionalBeta = conditional_prior_factory(Beta) +ConditionalLogistic = conditional_prior_factory(Logistic) +ConditionalCauchy = conditional_prior_factory(Cauchy) +ConditionalGamma = conditional_prior_factory(Gamma) +ConditionalChiSquared = conditional_prior_factory(ChiSquared) +ConditionalFermiDirac = conditional_prior_factory(FermiDirac) +ConditionalInterped = conditional_prior_factory(Interped) + + +class PriorException(Exception): + """ General base class for all prior exceptions """ + + +class ConditionalPriorException(PriorException): + """ General base class for all conditional prior exceptions """ + + +class IllegalRequiredVariablesException(ConditionalPriorException): + """ Exception class for exceptions relating to handling the required variables. """ + + +class PriorDictException(Exception): + """ General base class for all prior dict exceptions """ + + +class ConditionalPriorDictException(PriorDictException): + """ General base class for all conditional prior dict exceptions """ + + +class IllegalConditionsException(ConditionalPriorDictException): + """ Exception class to handle prior dicts that contain unresolvable conditions. """ diff --git a/bilby/core/result.py b/bilby/core/result.py index 7101cf99676ab0c4ecebfbfb3649cbb03f7212a3..630830dfc3830de9f6e338e88f469523bfa3f988 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -20,7 +20,7 @@ from . import utils from .utils import (logger, infer_parameters_from_function, check_directory_exists_and_if_not_mkdir,) from .utils import BilbyJsonEncoder, decode_bilby_json -from .prior import Prior, PriorDict, DeltaFunction +from .prior import Prior, PriorDict, DeltaFunction, ConditionalPriorDict def result_file_name(outdir, label, extension='json', gzip=False): @@ -299,7 +299,10 @@ class Result(object): @priors.setter def priors(self, priors): if isinstance(priors, dict): - self._priors = PriorDict(priors) + if isinstance(priors, ConditionalPriorDict): + self._priors = priors + else: + self._priors = PriorDict(priors) if self.parameter_labels is None: self.parameter_labels = [self.priors[k].latex_label for k in self.search_parameter_keys] @@ -307,7 +310,6 @@ class Result(object): self.parameter_labels_with_unit = [ self.priors[k].latex_label_with_unit for k in self.search_parameter_keys] - elif priors is None: self._priors = priors self.parameter_labels = self.search_parameter_keys diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index ab6d62bd4206caf12db8fbeda6b344d4ada02e0f..4b5cc048f484f81ce90cb00c62ac23b1cadef917 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -5,7 +5,7 @@ import numpy as np from pandas import DataFrame from ..utils import logger, command_line_args, Counter -from ..prior import Prior, PriorDict, DeltaFunction, Constraint +from ..prior import Prior, PriorDict, ConditionalPriorDict, DeltaFunction, Constraint from ..result import Result, read_in_result @@ -251,13 +251,19 @@ class Sampler(object): AttributeError prior can't be sampled. """ - for key in self.priors: - if isinstance(self.priors[key], Constraint): - continue + if isinstance(self.priors, ConditionalPriorDict): try: - self.likelihood.parameters[key] = self.priors[key].sample() + self.likelihood.parameters = self.priors.sample() except AttributeError as e: - logger.warning('Cannot sample from {}, {}'.format(key, e)) + logger.warning('Cannot sample from prior, {}'.format(e)) + else: + for key in self.priors: + if isinstance(self.priors[key], Constraint): + continue + try: + self.likelihood.parameters[key] = self.priors[key].sample() + except AttributeError as e: + logger.warning('Cannot sample from {}, {}'.format(key, e)) def _verify_parameters(self): """ Evaluate a set of parameters drawn from the prior @@ -276,9 +282,13 @@ class Sampler(object): "Your sampling set contains redundant parameters.") self._check_if_priors_can_be_sampled() - try: + if isinstance(self.priors, ConditionalPriorDict): + theta = self.priors.sample() + theta = [theta[key] for key in self._search_parameter_keys] + else: theta = [self.priors[key].sample() for key in self._search_parameter_keys] + try: self.log_likelihood(theta) except TypeError as e: raise TypeError( @@ -298,8 +308,12 @@ class Sampler(object): t1 = datetime.datetime.now() for _ in range(n_evaluations): - theta = [self.priors[key].sample() - for key in self._search_parameter_keys] + if isinstance(self.priors, ConditionalPriorDict): + theta = self.priors.sample() + theta = [theta[key] for key in self._search_parameter_keys] + else: + theta = [self.priors[key].sample() + for key in self._search_parameter_keys] self.log_likelihood(theta) total_time = (datetime.datetime.now() - t1).total_seconds() self._log_likelihood_eval_time = total_time / n_evaluations diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py index be46742558613695d8e2a6ce0fd8a6315b44f989..1ed64e89c23b1588c88083043e3b74d64989ad0f 100644 --- a/bilby/core/sampler/cpnest.py +++ b/bilby/core/sampler/cpnest.py @@ -61,9 +61,7 @@ class Cpnest(NestedSampler): def __init__(self, names, priors): self.names = names self.priors = priors - self.bounds = [ - [self.priors[key].minimum, self.priors[key].maximum] - for key in self.names] + self._update_bounds() @staticmethod def log_likelihood(x, **kwargs): @@ -75,10 +73,17 @@ class Cpnest(NestedSampler): theta = [x[n] for n in self.search_parameter_keys] return self.log_prior(theta) + def _update_bounds(self): + self.bounds = [ + [self.priors[key].minimum, self.priors[key].maximum] + for key in self.names] + def new_point(self): """Draw a point from the prior""" + prior_samples = self.priors.sample() + self._update_bounds() point = LivePoint( - self.names, [self.priors[name].sample() + self.names, [prior_samples[name] for name in self.names]) return point diff --git a/bilby/core/utils.py b/bilby/core/utils.py index 66dc063684170387a856c43c3b179b4b2c14b680..c5f9a9709b8894b3896b9526d0f453ab243de987 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -979,7 +979,7 @@ class BilbyJsonEncoder(json.JSONEncoder): if isinstance(obj, (MultivariateGaussianDist, Prior)): return {'__prior__': True, '__module__': obj.__module__, '__name__': obj.__class__.__name__, - 'kwargs': dict(obj._get_instantiation_dict())} + 'kwargs': dict(obj.get_instantiation_dict())} try: from astropy import cosmology as cosmo, units if isinstance(obj, cosmo.FLRW): diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 4a3a902aade3af8b33397ccac84a1fababd7b6c2..5837d623674c48178d3fe74406d86664b89eefff 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -4,8 +4,8 @@ import copy import numpy as np from scipy.interpolate import InterpolatedUnivariateSpline -from ..core.prior import (PriorDict, Uniform, Prior, DeltaFunction, Gaussian, - Interped, Constraint) +from ..core.prior import (ConditionalPriorDict, PriorDict, Uniform, Prior, DeltaFunction, Gaussian, + Interped, Constraint, conditional_prior_factory) from ..core.utils import infer_args_from_method, logger from .conversion import ( convert_to_lal_binary_black_hole_parameters, @@ -307,7 +307,7 @@ class AlignedSpin(Interped): boundary=boundary) -class CBCPriorDict(PriorDict): +class CBCPriorDict(ConditionalPriorDict): @property def minimum_chirp_mass(self): mass_1 = None @@ -753,3 +753,12 @@ class CalibrationPriorDict(PriorDict): latex_label=latex_label) return prior + + +def secondary_mass_condition_function(reference_params, mass_1): + return dict(minimum=reference_params['minimum'], maximum=mass_1) + + +ConditionalCosmological = conditional_prior_factory(Cosmological) +ConditionalUniformComovingVolume = conditional_prior_factory(UniformComovingVolume) +ConditionalUniformSourceFrame = conditional_prior_factory(UniformSourceFrame) diff --git a/examples/core_examples/conditional_prior.py b/examples/core_examples/conditional_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..2475e479d160f1fe0088f9142b6fdec31919a9e1 --- /dev/null +++ b/examples/core_examples/conditional_prior.py @@ -0,0 +1,44 @@ +import bilby +import numpy as np + +# This tutorial demonstrates how we can sample a prior in the shape of a ball +# Note that this will not end up sampling uniformly in that space, constraint priors are more suitable for that. +# This implementation will draw a value for the x-coordinate from p(x), and given that draw a value for the +# y-coordinate from p(y|x), and given that draw a value for the z-coordinate from p(z|x,y). +# Only the x-coordinate will end up being uniform for this + + +class ZeroLikelihood(bilby.core.likelihood.Likelihood): + """ Flat likelihood. This always returns 0. + This way our posterior distribution is exactly the prior distribution.""" + def log_likelihood(self): + return 0 + + +def condition_func_y(reference_params, x): + """ Condition function for our p(y|x) prior.""" + radius = 0.5 * (reference_params['maximum'] - reference_params['minimum']) + y_max = np.sqrt(radius**2 - x**2) + return dict(minimum=-y_max, maximum=y_max) + + +def condition_func_z(reference_params, x, y): + """ Condition function for our p(z|x, y) prior.""" + radius = 0.5 * (reference_params['maximum'] - reference_params['minimum']) + z_max = np.sqrt(radius**2 - x**2 - y**2) + return dict(minimum=-z_max, maximum=z_max) + + +# Set up the conditional priors and the flat likelihood +priors = bilby.core.prior.ConditionalPriorDict() +priors['x'] = bilby.core.prior.Uniform(minimum=-1, maximum=1, latex_label="$x$") +priors['y'] = bilby.core.prior.ConditionalUniform(condition_func=condition_func_y, minimum=-1, + maximum=1, latex_label="$y$") +priors['z'] = bilby.core.prior.ConditionalUniform(condition_func=condition_func_z, minimum=-1, + maximum=1, latex_label="$z$") +likelihood = ZeroLikelihood(parameters=dict(x=0, y=0, z=0)) + +# Sample the prior distribution +res = bilby.run_sampler(likelihood=likelihood, priors=priors, sampler='dynesty', npoints=5000, walks=100, + label='conditional_prior', outdir='outdir', resume=False, clean=True) +res.plot_corner() diff --git a/test/prior_test.py b/test/prior_test.py index 8f58d41e57fcea98a6d19ec9f5cc386e756fdd98..b6f554d44bbd2e91fd4d47ba686ecf3e73dfff95 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, division import bilby import unittest from mock import Mock +import mock import numpy as np import os from collections import OrderedDict @@ -162,6 +163,9 @@ class TestPriorClasses(unittest.TestCase): covs=np.array([[2., 0.5], [0.5, 2.]]), weights=1.) + def condition_func(reference_params, test_param): + return reference_params.copy() + self.priors = [ bilby.core.prior.DeltaFunction(name='test', unit='unit', peak=1), bilby.core.prior.Gaussian(name='test', unit='unit', mu=0, sigma=1), @@ -196,7 +200,28 @@ class TestPriorClasses(unittest.TestCase): bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testa', unit='unit'), bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testb', unit='unit'), bilby.core.prior.MultivariateNormal(mvg=mvn, name='testa', unit='unit'), - bilby.core.prior.MultivariateNormal(mvg=mvn, name='testb', unit='unit') + bilby.core.prior.MultivariateNormal(mvg=mvn, name='testb', unit='unit'), + bilby.core.prior.ConditionalDeltaFunction(condition_func=condition_func, name='test', unit='unit', peak=1), + bilby.core.prior.ConditionalGaussian(condition_func=condition_func, name='test', unit='unit', mu=0, sigma=1), + bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, name='test', unit='unit', alpha=0, minimum=0, maximum=1), + bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, name='test', unit='unit', alpha=-1, minimum=0.5, maximum=1), + bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, name='test', unit='unit', alpha=2, minimum=1, maximum=1e2), + bilby.core.prior.ConditionalUniform(condition_func=condition_func, name='test', unit='unit', minimum=0, maximum=1), + bilby.core.prior.ConditionalLogUniform(condition_func=condition_func, name='test', unit='unit', minimum=5e0, maximum=1e2), + bilby.gw.prior.ConditionalUniformComovingVolume(condition_func=condition_func, name='redshift', minimum=0.1, maximum=1.0), + bilby.gw.prior.ConditionalUniformSourceFrame(condition_func=condition_func, name='redshift', minimum=0.1, maximum=1.0), + bilby.core.prior.ConditionalSine(condition_func=condition_func, name='test', unit='unit'), + bilby.core.prior.ConditionalCosine(condition_func=condition_func, name='test', unit='unit'), + bilby.core.prior.ConditionalTruncatedGaussian(condition_func=condition_func, name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1), + bilby.core.prior.ConditionalHalfGaussian(condition_func=condition_func, name='test', unit='unit', sigma=1), + bilby.core.prior.ConditionalLogNormal(condition_func=condition_func, name='test', unit='unit', mu=0, sigma=1), + bilby.core.prior.ConditionalExponential(condition_func=condition_func, name='test', unit='unit', mu=1), + bilby.core.prior.ConditionalStudentT(condition_func=condition_func, name='test', unit='unit', df=3, mu=0, scale=1), + bilby.core.prior.ConditionalBeta(condition_func=condition_func, name='test', unit='unit', alpha=2.0, beta=2.0), + bilby.core.prior.ConditionalLogistic(condition_func=condition_func, name='test', unit='unit', mu=0, scale=1), + bilby.core.prior.ConditionalCauchy(condition_func=condition_func, name='test', unit='unit', alpha=0, beta=1), + bilby.core.prior.ConditionalGamma(condition_func=condition_func, name='test', unit='unit', k=1, theta=1), + bilby.core.prior.ConditionalChiSquared(condition_func=condition_func, name='test', unit='unit', nu=2) ] def tearDown(self): @@ -240,6 +265,11 @@ class TestPriorClasses(unittest.TestCase): for prior in self.priors: self.assertRaises(ValueError, lambda: prior.rescale(-1)) + def test_least_recently_sampled(self): + for prior in self.priors: + least_recently_sampled_expected = prior.sample() + self.assertEqual(least_recently_sampled_expected, prior.least_recently_sampled) + def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" for prior in self.priors: @@ -266,6 +296,11 @@ class TestPriorClasses(unittest.TestCase): outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000) self.assertTrue(all(prior.prob(outside_domain) == 0)) + def test_least_recently_sampled(self): + for prior in self.priors: + lrs = prior.sample() + self.assertEqual(lrs, prior.least_recently_sampled) + def test_prob_and_ln_prob(self): for prior in self.priors: sample = prior.sample() @@ -603,6 +638,8 @@ class TestPriorClasses(unittest.TestCase): ) elif isinstance(prior, bilby.gw.prior.UniformComovingVolume): repr_prior_string = 'bilby.gw.prior.' + repr(prior) + elif 'Conditional' in prior.__class__.__name__: + continue # This feature does not exist because we cannot recreate the condition function else: repr_prior_string = 'bilby.core.prior.' + repr(prior) repr_prior = eval(repr_prior_string, None, dict(inf=np.inf)) @@ -906,6 +943,246 @@ class TestCreateDefaultPrior(unittest.TestCase): self.assertIsNone(bilby.core.prior.create_default_prior(name='name', default_priors_file=prior_file)) +class TestConditionalPrior(unittest.TestCase): + + def setUp(self): + self.condition_func_call_counter = 0 + + def condition_func(reference_parameters, test_variable_1, test_variable_2): + self.condition_func_call_counter += 1 + return {key: value + 1 for key, value in reference_parameters.items()} + self.condition_func = condition_func + self.minimum = 0 + self.maximum = 5 + self.test_variable_1 = 0 + self.test_variable_2 = 1 + self.prior = bilby.core.prior.ConditionalBasePrior(condition_func=condition_func, + minimum=self.minimum, + maximum=self.maximum) + + def tearDown(self): + del self.condition_func + del self.condition_func_call_counter + del self.minimum + del self.maximum + del self.test_variable_1 + del self.test_variable_2 + del self.prior + + def test_reference_params(self): + self.assertDictEqual(dict(minimum=self.minimum, maximum=self.maximum), self.prior.reference_params) + + def test_required_variables(self): + self.assertListEqual(['test_variable_1', 'test_variable_2'], sorted(self.prior.required_variables)) + + def test_required_variables_no_condition_func(self): + self.prior = bilby.core.prior.ConditionalBasePrior(condition_func=None, + minimum=self.minimum, + maximum=self.maximum) + self.assertListEqual([], self.prior.required_variables) + + def test_get_instantiation_dict(self): + expected = dict(minimum=0, maximum=5, name=None, latex_label=None, unit=None, + boundary=None, condition_func=self.condition_func) + actual = self.prior.get_instantiation_dict() + for key, value in expected.items(): + if key == 'condition_func': + continue + self.assertEqual(value, actual[key]) + + def test_update_conditions_correct_variables(self): + self.prior.update_conditions(test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2) + self.assertEqual(1, self.condition_func_call_counter) + self.assertEqual(self.minimum + 1, self.prior.minimum) + self.assertEqual(self.maximum + 1, self.prior.maximum) + + def test_update_conditions_no_variables(self): + self.prior.update_conditions(test_variable_1=self.test_variable_1, test_variable_2=self.test_variable_2) + self.prior.update_conditions() + self.assertEqual(1, self.condition_func_call_counter) + self.assertEqual(self.minimum + 1, self.prior.minimum) + self.assertEqual(self.maximum + 1, self.prior.maximum) + + def test_update_conditions_illegal_variables(self): + with self.assertRaises(bilby.core.prior.IllegalRequiredVariablesException): + self.prior.update_conditions(test_parameter_1=self.test_variable_1) + + def test_sample_calls_update_conditions(self): + with mock.patch.object(self.prior, 'update_conditions') as m: + self.prior.sample(1, + test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2) + m.assert_called_with(test_parameter_1=self.test_variable_1, test_parameter_2=self.test_variable_2) + + def test_rescale_calls_update_conditions(self): + with mock.patch.object(self.prior, 'update_conditions') as m: + self.prior.rescale(1, test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2) + m.assert_called_with(test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2) + + def test_rescale_prob_update_conditions(self): + with mock.patch.object(self.prior, 'update_conditions') as m: + self.prior.prob(1, test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2) + m.assert_called_with(test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2) + + def test_rescale_ln_prob_update_conditions(self): + with mock.patch.object(self.prior, 'update_conditions') as m: + self.prior.ln_prob(1, test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2) + calls = [mock.call(test_parameter_1=self.test_variable_1, + test_parameter_2=self.test_variable_2), + mock.call()] + m.assert_has_calls(calls) + + def test_reset_to_reference_parameters(self): + self.prior.minimum = 10 + self.prior.maximum = 20 + self.prior.reset_to_reference_parameters() + self.assertEqual(self.prior.reference_params['minimum'], self.prior.minimum) + self.assertEqual(self.prior.reference_params['maximum'], self.prior.maximum) + + def test_cond_prior_instantiation_no_boundary_prior(self): + prior = bilby.core.prior.ConditionalFermiDirac(condition_func=None, sigma=1, mu=1) + self.assertIsNone(prior.boundary) + + +class TestConditionalPriorDict(unittest.TestCase): + + def setUp(self): + def condition_func_1(reference_parameters, var_0): + return reference_parameters + + def condition_func_2(reference_parameters, var_0, var_1): + return reference_parameters + + def condition_func_3(reference_parameters, var_1, var_2): + return reference_parameters + + self.minimum = 0 + self.maximum = 1 + self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=self.maximum) + self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1, + minimum=self.minimum, maximum=self.maximum) + self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2, + minimum=self.minimum, maximum=self.maximum) + self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3, + minimum=self.minimum, maximum=self.maximum) + self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2, + var_0=self.prior_0, var_1=self.prior_1)) + self.conditional_priors_manually_set_items = bilby.core.prior.ConditionalPriorDict() + self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4) + for key, value in dict(var_0=self.prior_0, var_1=self.prior_1, var_2=self.prior_2, var_3=self.prior_3).items(): + self.conditional_priors_manually_set_items[key] = value + + def tearDown(self): + del self.minimum + del self.maximum + del self.prior_0 + del self.prior_1 + del self.prior_2 + del self.prior_3 + del self.conditional_priors + del self.conditional_priors_manually_set_items + del self.test_sample + + def test_conditions_resolved_upon_instantiation(self): + self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'], self.conditional_priors.sorted_keys) + + def test_conditions_resolved_setting_items(self): + self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'], + self.conditional_priors_manually_set_items.sorted_keys) + + def test_unconditional_keys_upon_instantiation(self): + self.assertListEqual(['var_0'], self.conditional_priors.unconditional_keys) + + def test_unconditional_keys_setting_items(self): + self.assertListEqual(['var_0'], self.conditional_priors_manually_set_items.unconditional_keys) + + def test_conditional_keys_upon_instantiation(self): + self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors.conditional_keys) + + def test_conditional_keys_setting_items(self): + self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors_manually_set_items.conditional_keys) + + def test_prob(self): + self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample)) + + def test_prob_illegal_conditions(self): + del self.conditional_priors['var_0'] + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors.prob(sample=self.test_sample) + + def test_ln_prob(self): + self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample)) + + def test_ln_prob_illegal_conditions(self): + del self.conditional_priors['var_0'] + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors.ln_prob(sample=self.test_sample) + + def test_sample_subset_all_keys(self): + with mock.patch("numpy.random.uniform") as m: + m.return_value = 0.5 + self.assertDictEqual(dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5), + self.conditional_priors.sample_subset(keys=['var_0', 'var_1', 'var_2', 'var_3'])) + + def test_sample_illegal_subset(self): + with mock.patch("numpy.random.uniform") as m: + m.return_value = 0.5 + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors.sample_subset(keys=['var_1']) + + def test_sample_multiple(self): + def condition_func(reference_params, a): + return dict(minimum=reference_params['minimum'], + maximum=reference_params['maximum'], + alpha=reference_params['alpha'] * a) + priors = bilby.core.prior.ConditionalPriorDict() + priors['a'] = bilby.core.prior.Uniform(minimum=0, maximum=1) + priors['b'] = bilby.core.prior.ConditionalPowerLaw(condition_func=condition_func, minimum=1, maximum=2, + alpha=-2) + print(priors.sample(2)) + + def test_rescale(self): + + def condition_func_1_rescale(reference_parameters, var_0): + if var_0 == 0.5: + return dict(minimum=reference_parameters['minimum'], maximum=1) + return reference_parameters + + def condition_func_2_rescale(reference_parameters, var_0, var_1): + if var_0 == 0.5 and var_1 == 0.5: + return dict(minimum=reference_parameters['minimum'], maximum=1) + return reference_parameters + + def condition_func_3_rescale(reference_parameters, var_1, var_2): + if var_1 == 0.5 and var_2 == 0.5: + return dict(minimum=reference_parameters['minimum'], maximum=1) + return reference_parameters + + self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1) + self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1_rescale, + minimum=self.minimum, maximum=2) + self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2_rescale, + minimum=self.minimum, maximum=2) + self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3_rescale, + minimum=self.minimum, maximum=2) + self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2, + var_0=self.prior_0, var_1=self.prior_1)) + ref_variables = [0.5, 0.5, 0.5, 0.5] + res = self.conditional_priors.rescale(keys=list(self.test_sample.keys()), + theta=ref_variables) + self.assertListEqual(ref_variables, res) + + def test_rescale_illegal_conditions(self): + del self.conditional_priors['var_0'] + with self.assertRaises(bilby.core.prior.IllegalConditionsException): + self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values())) + + class TestJsonIO(unittest.TestCase): def setUp(self):