diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 6d2e9b418c0e9c415440af6bfad0d2f3f1353abd..85ed0fdff2d46445d5fb1b73de7d4e782308e19c 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -267,12 +267,20 @@ class PriorDict(OrderedDict): """ prob = np.product([self[key].prob(sample[key]) for key in sample], **kwargs) - if prob == 0: - return 0 - elif self.evaluate_constraints(sample): + + if np.all(prob == 0.): return prob else: - return 0 + if isinstance(prob, float): + if self.evaluate_constraints(sample): + return prob + else: + return 0. + else: + constrained_prob = np.zeros_like(prob) + keep = np.array(self.evaluate_constraints(sample), dtype=bool) + constrained_prob[keep] = prob[keep] + return constrained_prob def ln_prob(self, sample, axis=None): """ @@ -292,12 +300,20 @@ class PriorDict(OrderedDict): """ ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis) - if np.isinf(ln_prob): - return ln_prob - elif self.evaluate_constraints(sample): + + if np.all(np.isinf(ln_prob)): return ln_prob else: - return -np.inf + if isinstance(ln_prob, float): + if self.evaluate_constraints(sample): + return ln_prob + else: + return -np.inf + else: + constrained_ln_prob = -np.inf * np.ones_like(ln_prob) + keep = np.array(self.evaluate_constraints(sample), dtype=bool) + constrained_ln_prob[keep] = ln_prob[keep] + return constrained_ln_prob def rescale(self, keys, theta): """Rescale samples from unit cube to prior