diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 6d2e9b418c0e9c415440af6bfad0d2f3f1353abd..85ed0fdff2d46445d5fb1b73de7d4e782308e19c 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -267,12 +267,20 @@ class PriorDict(OrderedDict):
         """
         prob = np.product([self[key].prob(sample[key])
                            for key in sample], **kwargs)
-        if prob == 0:
-            return 0
-        elif self.evaluate_constraints(sample):
+
+        if np.all(prob == 0.):
             return prob
         else:
-            return 0
+            if isinstance(prob, float):
+                if self.evaluate_constraints(sample):
+                    return prob
+                else:
+                    return 0.
+            else:
+                constrained_prob = np.zeros_like(prob)
+                keep = np.array(self.evaluate_constraints(sample), dtype=bool)
+                constrained_prob[keep] = prob[keep]
+                return constrained_prob
 
     def ln_prob(self, sample, axis=None):
         """
@@ -292,12 +300,20 @@ class PriorDict(OrderedDict):
         """
         ln_prob = np.sum([self[key].ln_prob(sample[key])
                           for key in sample], axis=axis)
-        if np.isinf(ln_prob):
-            return ln_prob
-        elif self.evaluate_constraints(sample):
+
+        if np.all(np.isinf(ln_prob)):
             return ln_prob
         else:
-            return -np.inf
+            if isinstance(ln_prob, float):
+                if self.evaluate_constraints(sample):
+                    return ln_prob
+                else:
+                    return -np.inf
+            else:
+                constrained_ln_prob = -np.inf * np.ones_like(ln_prob)
+                keep = np.array(self.evaluate_constraints(sample), dtype=bool)
+                constrained_ln_prob[keep] = ln_prob[keep]
+                return constrained_ln_prob
 
     def rescale(self, keys, theta):
         """Rescale samples from unit cube to prior