From de6ab27a14dc6e30d9e00e1e30371a3c37e5877f Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Tue, 20 Apr 2021 09:48:46 +0000
Subject: [PATCH] Improvements to the core prior class

Cherry picked results from !905:
- Add ignore errstate to avoid warning messages
- Add helper properties for the width, non_fixed_keys, fixed_keys, and constraint_keys
- Add constraints to the ConditionalPrior
---
 bilby/core/prior/analytical.py |  7 +++++--
 bilby/core/prior/base.py       |  7 ++++++-
 bilby/core/prior/dict.py       | 33 +++++++++++++++++++++++++++++++--
 3 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py
index 63ce9683b..1c3da1edc 100644
--- a/bilby/core/prior/analytical.py
+++ b/bilby/core/prior/analytical.py
@@ -148,8 +148,11 @@ class PowerLaw(Prior):
             normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) -
                                               self.minimum ** (1 + self.alpha))
 
-        return (self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising)) + np.log(
-            1. * self.is_in_prior_range(val))
+        with np.errstate(divide='ignore', invalid='ignore'):
+            ln_in_range = np.log(1. * self.is_in_prior_range(val))
+            ln_p = self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising)
+
+        return ln_p + ln_in_range
 
     def cdf(self, val):
         if self.alpha == -1:
diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py
index ab3f2e128..023f1609f 100644
--- a/bilby/core/prior/base.py
+++ b/bilby/core/prior/base.py
@@ -185,7 +185,8 @@ class Prior(object):
         np.nan
 
         """
-        return np.log(self.prob(val))
+        with np.errstate(divide='ignore'):
+            return np.log(self.prob(val))
 
     def is_in_prior_range(self, val):
         """Returns True if val is in the prior boundaries, zero otherwise
@@ -313,6 +314,10 @@ class Prior(object):
     def maximum(self, maximum):
         self._maximum = maximum
 
+    @property
+    def width(self):
+        return self.maximum - self.minimum
+
     def get_instantiation_dict(self):
         subclass_args = infer_args_from_method(self.__init__)
         dict_with_properties = get_dict_with_properties(self)
diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index c38c047d7..5e50d3771 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -367,6 +367,28 @@ class PriorDict(dict):
                 logger.debug('{} not a known prior.'.format(key))
         return samples
 
+    @property
+    def non_fixed_keys(self):
+        keys = self.keys()
+        keys = [k for k in keys if isinstance(self[k], Prior)]
+        keys = [k for k in keys if self[k].is_fixed is False]
+        keys = [k for k in keys if k not in self.constraint_keys]
+        return keys
+
+    @property
+    def fixed_keys(self):
+        return [
+            k for k, p in self.items()
+            if (p.is_fixed and k not in self.constraint_keys)
+        ]
+
+    @property
+    def constraint_keys(self):
+        return [
+            k for k, p in self.items()
+            if isinstance(p, Constraint)
+        ]
+
     def sample_subset_constrained(self, keys=iter([]), size=None):
         if size is None or size == 1:
             while True:
@@ -432,6 +454,9 @@ class PriorDict(dict):
         prob = np.product([self[key].prob(sample[key])
                            for key in sample], **kwargs)
 
+        return self.check_prob(sample, prob)
+
+    def check_prob(self, sample, prob):
         ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(prob == 0.):
             return prob
@@ -465,7 +490,9 @@ class PriorDict(dict):
         """
         ln_prob = np.sum([self[key].ln_prob(sample[key])
                           for key in sample], axis=axis)
+        return self.check_ln_prob(sample, ln_prob)
 
+    def check_ln_prob(self, sample, ln_prob):
         ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(np.isinf(ln_prob)):
             return ln_prob
@@ -648,7 +675,8 @@ class ConditionalPriorDict(PriorDict):
         for key, value in sample.items():
             self[key].least_recently_sampled = value
         res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
-        return np.product(res, **kwargs)
+        prob = np.product(res, **kwargs)
+        return self.check_prob(sample, prob)
 
     def ln_prob(self, sample, axis=None):
         """
@@ -669,7 +697,8 @@ class ConditionalPriorDict(PriorDict):
         for key, value in sample.items():
             self[key].least_recently_sampled = value
         res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
-        return np.sum(res, axis=axis)
+        ln_prob = np.sum(res, axis=axis)
+        return self.check_ln_prob(sample, ln_prob)
 
     def rescale(self, keys, theta):
         """Rescale samples from unit cube to prior
-- 
GitLab