From de6ab27a14dc6e30d9e00e1e30371a3c37e5877f Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Tue, 20 Apr 2021 09:48:46 +0000 Subject: [PATCH] Improvements to the core prior class Cherry picked results from !905: - Add ignore errstate to avoid warning messages - Add helper properties for the width, non_fixed_keys, fixed_keys, and constraint_keys - Add constraints to the ConditionalPrior --- bilby/core/prior/analytical.py | 7 +++++-- bilby/core/prior/base.py | 7 ++++++- bilby/core/prior/dict.py | 33 +++++++++++++++++++++++++++++++-- 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 63ce9683b..1c3da1edc 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -148,8 +148,11 @@ class PowerLaw(Prior): normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) - self.minimum ** (1 + self.alpha)) - return (self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising)) + np.log( - 1. * self.is_in_prior_range(val)) + with np.errstate(divide='ignore', invalid='ignore'): + ln_in_range = np.log(1. * self.is_in_prior_range(val)) + ln_p = self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising) + + return ln_p + ln_in_range def cdf(self, val): if self.alpha == -1: diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index ab3f2e128..023f1609f 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -185,7 +185,8 @@ class Prior(object): np.nan """ - return np.log(self.prob(val)) + with np.errstate(divide='ignore'): + return np.log(self.prob(val)) def is_in_prior_range(self, val): """Returns True if val is in the prior boundaries, zero otherwise @@ -313,6 +314,10 @@ class Prior(object): def maximum(self, maximum): self._maximum = maximum + @property + def width(self): + return self.maximum - self.minimum + def get_instantiation_dict(self): subclass_args = infer_args_from_method(self.__init__) dict_with_properties = get_dict_with_properties(self) diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index c38c047d7..5e50d3771 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -367,6 +367,28 @@ class PriorDict(dict): logger.debug('{} not a known prior.'.format(key)) return samples + @property + def non_fixed_keys(self): + keys = self.keys() + keys = [k for k in keys if isinstance(self[k], Prior)] + keys = [k for k in keys if self[k].is_fixed is False] + keys = [k for k in keys if k not in self.constraint_keys] + return keys + + @property + def fixed_keys(self): + return [ + k for k, p in self.items() + if (p.is_fixed and k not in self.constraint_keys) + ] + + @property + def constraint_keys(self): + return [ + k for k, p in self.items() + if isinstance(p, Constraint) + ] + def sample_subset_constrained(self, keys=iter([]), size=None): if size is None or size == 1: while True: @@ -432,6 +454,9 @@ class PriorDict(dict): prob = np.product([self[key].prob(sample[key]) for key in sample], **kwargs) + return self.check_prob(sample, prob) + + def check_prob(self, sample, prob): ratio = self.normalize_constraint_factor(tuple(sample.keys())) if np.all(prob == 0.): return prob @@ -465,7 +490,9 @@ class PriorDict(dict): """ ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) + return self.check_ln_prob(sample, ln_prob) + def check_ln_prob(self, sample, ln_prob): ratio = self.normalize_constraint_factor(tuple(sample.keys())) if np.all(np.isinf(ln_prob)): return ln_prob @@ -648,7 +675,8 @@ class ConditionalPriorDict(PriorDict): for key, value in sample.items(): self[key].least_recently_sampled = value res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample] - return np.product(res, **kwargs) + prob = np.product(res, **kwargs) + return self.check_prob(sample, prob) def ln_prob(self, sample, axis=None): """ @@ -669,7 +697,8 @@ class ConditionalPriorDict(PriorDict): for key, value in sample.items(): self[key].least_recently_sampled = value res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample] - return np.sum(res, axis=axis) + ln_prob = np.sum(res, axis=axis) + return self.check_ln_prob(sample, ln_prob) def rescale(self, keys, theta): """Rescale samples from unit cube to prior -- GitLab