diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 9861bcdac4e3a35fd09b741eca501a56330252d9..7c0ed966287263afa72730d1f458a7fdd6e6847e 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -494,66 +494,61 @@ class ConditionalPriorDict(PriorDict):
         filename: str
             See parent class
         """
-        self.conditioned_keys = []
-        self.unconditioned_keys = []
+        self._conditioned_keys = []
+        self._unconditioned_keys = []
+        self._sorted_keys = []
         super(ConditionalPriorDict, self).__init__(dictionary=dictionary, filename=filename)
         self._resolve_conditions()
 
     def _resolve_conditions(self):
-        self.convert_floats_to_delta_functions()
-        self.conditioned_keys = []
-        self.unconditioned_keys = []
-        for key in self.keys():
-            if hasattr(self[key], 'condition_func'):
-                self.conditioned_keys.append(key)
-            else:
-                self.unconditioned_keys.append(key)
-
-    def sample_subset(self, keys=iter([]), size=None):
-        self.convert_floats_to_delta_functions()
-        conditional_keys = []
-        unconditional_keys = []
-        sampled_keys = []
-        for key in keys:
-            if key in self.unconditioned_keys:
-                unconditional_keys.append(key)
-            elif key in self.conditioned_keys:
-                conditional_keys.append(key)
-            else:
-                raise KeyError('Invalid key in keys argument')
-
-        samples = dict()
-        for key in unconditional_keys:
-            samples[key] = self[key].sample(size=size)
-            sampled_keys.append(key)
-
-        for i in range(0, 1000):
-            for key in conditional_keys:
-                if self._check_conditions_resolved(key, sampled_keys):
-                    cvars = self._get_conditional_variables(key)
-                    samples[key] = self[key].sample(size=size, **cvars)
-                    conditional_keys.remove(key)
-                    sampled_keys.append(key)
-            if not conditional_keys:
-                break
-            if i == 999:
-                raise Exception('This set contains unresolvable conditions')
-
-        return samples
-
-    def _get_conditional_variables(self, key):
-        conditional_variables = dict()
-        for k in self[key].required_variables:
-            conditional_variables[k] = self[k].least_recently_sampled
-        return conditional_variables
+        """ Resolves how variables depend on each other and automatically sorts them into the right order """
+        conditioned_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')]
+        self._unconditioned_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')]
+        self._conditioned_keys = []
+        checked_keys = self._unconditioned_keys.copy()
+        for i in range(len(self)):
+            for key in conditioned_keys_unsorted:
+                if self._check_conditions_resolved(key, checked_keys):
+                    checked_keys.append(key)
+                    self._conditioned_keys.append(key)
+                    conditioned_keys_unsorted.remove(key)
+
+        self._sorted_keys = self._unconditioned_keys.copy()
+        self._sorted_keys.extend(self.conditional_keys)
+
+        if len(conditioned_keys_unsorted) != 0:
+            raise IllegalConditionsException('This set contains unresolvable conditions')
 
     def _check_conditions_resolved(self, key, sampled_keys):
+        """ Checks if all required variables have already been sampled so we can sample this key """
         conditions_resolved = True
         for k in self[key].required_variables:
             if k not in sampled_keys:
                 conditions_resolved = False
         return conditions_resolved
 
+    @property
+    def conditional_keys(self):
+        return self._conditioned_keys
+
+    @property
+    def unconditional_keys(self):
+        return self._unconditioned_keys
+
+    def sample_subset(self, keys=iter([]), size=None):
+        self.convert_floats_to_delta_functions()
+        subset_unconditional_keys = [key for key in keys if key in self.unconditional_keys]
+        subset_conditional_keys = [key for key in keys if key in self.conditional_keys]
+        samples = {key: self[key].sample(size=size) for key in subset_unconditional_keys}
+        for key in subset_conditional_keys:
+            required_variables = self._get_required_variables(key)
+            samples[key] = self[key].sample(size=size, **required_variables)
+        return samples
+
+    def _get_required_variables(self, key):
+        """ Returns the required variables to sample a given key """
+        return {k: self[k].least_recently_sampled for k in self[key].required_variables}
+
     def prob(self, sample, **kwargs):
         """
 
@@ -571,7 +566,7 @@ class ConditionalPriorDict(PriorDict):
         """
         ls = []
         for key in sample:
-            if key in self.conditioned_keys:
+            if key in self.conditional_keys:
                 conditional_variables = dict([(k, sample[k]) for k in self[key].required_variables])
                 ls.append(self[key].prob(sample[key], **conditional_variables))
             else:
@@ -593,7 +588,7 @@ class ConditionalPriorDict(PriorDict):
         """
         ls = []
         for key in sample:
-            if key in self.conditioned_keys:
+            if key in self.conditional_keys:
                 conditional_variables = dict([(k, sample[k]) for k in self[key].required_variables])
                 ls.append(self[key].ln_prob(sample[key], **conditional_variables))
             else:
@@ -3484,3 +3479,15 @@ class ConditionalPriorException(PriorException):
 class IllegalRequiredVariablesException(ConditionalPriorException):
     """ Exception class for exceptions relating to handling the required variables. """
 
+
+class PriorDictException(Exception):
+    """ General base class for all prior dict exceptions """
+
+
+class ConditionalPriorDictException(PriorDictException):
+    """ General base class for all conditional prior dict exceptions """
+
+
+class IllegalConditionsException(ConditionalPriorDictException):
+    """ Exception class to handle prior dicts that contain unresolvable conditions. """
+
diff --git a/examples/gw_examples/injection_examples/conditional_prior.py b/examples/gw_examples/injection_examples/conditional_prior.py
index d7a855111232bd522c0db314758e1d6f92cdf582..05a2e55e22736761dd644f5cca2c634c28c1be10 100644
--- a/examples/gw_examples/injection_examples/conditional_prior.py
+++ b/examples/gw_examples/injection_examples/conditional_prior.py
@@ -33,7 +33,7 @@ plt.clf()
 
 
 plt.hist(res['mass_2'], bins='fd', alpha=0.6, density=True, label='Sampled')
-plt.xlabel('$q$')
+plt.xlabel('$m_2$')
 plt.ylabel('$p(m_2 | m_1)$')
 # plt.loglog()
 plt.legend()
@@ -86,8 +86,8 @@ likelihood = bilby.gw.GravitationalWaveTransient(
 
 # Run sampler.  In this case we're going to use the `dynesty` sampler
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=10,
+    injection_parameters=injection_parameters, outdir=outdir, label=label, clean=True, resume=False)
 
 # Make a corner plot.
 result.plot_corner()