diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index abb54e67b523390e81c7211d6435e2fe764ca565..3ec1a0593b44c6ea007febbd7c33ed11c021da49 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -365,23 +365,21 @@ class PriorDict(dict):
                     return sample
         else:
             needed = np.prod(size)
-            constraint_keys = list()
-            for ii, key in enumerate(keys[-1::-1]):
+            for key in keys.copy():
                 if isinstance(self[key], Constraint):
-                    constraint_keys.append(-ii - 1)
-            for ii in constraint_keys[-1::-1]:
-                del keys[ii]
+                    del keys[keys.index(key)]
             all_samples = {key: np.array([]) for key in keys}
             _first_key = list(all_samples.keys())[0]
             while len(all_samples[_first_key]) < needed:
                 samples = self.sample_subset(keys=keys, size=needed)
                 keep = np.array(self.evaluate_constraints(samples), dtype=bool)
-                for key in samples:
-                    all_samples[key] = np.hstack(
-                        [all_samples[key], samples[key][keep].flatten()])
-            all_samples = {key: np.reshape(all_samples[key][:needed], size)
-                           for key in all_samples
-                           if not isinstance(self[key], Constraint)}
+                for key in keys:
+                    all_samples[key] = np.hstack([
+                        all_samples[key], samples[key][keep].flatten()
+                    ])
+            all_samples = {
+                key: np.reshape(all_samples[key][:needed], size) for key in keys
+            }
             return all_samples
 
     def normalize_constraint_factor(self, keys):