diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index 67bdb813abc2aaad7518cd9b07721067d50455cd..7a9a8e5dcb88176b9c5622ab2082e20597974ef9 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -491,6 +491,21 @@ class PriorDict(dict):
                 constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio)
                 return constrained_ln_prob
 
+    def cdf(self, sample):
+        """Evaluate the cumulative distribution function at the provided points
+
+        Parameters
+        ----------
+        sample: dict, pandas.DataFrame
+            Dictionary of the samples of which to calculate the CDF
+
+        Returns
+        -------
+        dict, pandas.DataFrame: Dictionary containing the CDF values
+
+        """
+        return sample.__class__({key: self[key].cdf(sample) for key, sample in sample.items()})
+
     def rescale(self, keys, theta):
         """Rescale samples from unit cube to prior
 
@@ -654,9 +669,7 @@ class ConditionalPriorDict(PriorDict):
         float: Joint probability of all individual sample probabilities
 
         """
-        self._check_resolved()
-        for key, value in sample.items():
-            self[key].least_recently_sampled = value
+        self._prepare_evaluation(*zip(*sample.items()))
         res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
         return np.product(res, **kwargs)
 
@@ -675,12 +688,15 @@ class ConditionalPriorDict(PriorDict):
         float: Joint log probability of all the individual sample probabilities
 
         """
-        self._check_resolved()
-        for key, value in sample.items():
-            self[key].least_recently_sampled = value
+        self._prepare_evaluation(*zip(*sample.items()))
         res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
         return np.sum(res, axis=axis)
 
+    def cdf(self, sample):
+        self._prepare_evaluation(*zip(*sample.items()))
+        res = {key: self[key].cdf(sample[key], **self.get_required_variables(key)) for key in sample}
+        return sample.__class__(res)
+
     def rescale(self, keys, theta):
         """Rescale samples from unit cube to prior
 
@@ -695,12 +711,14 @@ class ConditionalPriorDict(PriorDict):
         =======
         list: List of floats containing the rescaled sample
         """
+        keys = list(keys)
+        theta = list(theta)
         self._check_resolved()
         self._update_rescale_keys(keys)
         result = dict()
         for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes):
-            required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])}
-            result[key] = self[key].rescale(theta[index], **required_variables)
+            result[key] = self[key].rescale(theta[index], **self.get_required_variables(key))
+            self[key].least_recently_sampled = result[key]
         return [result[key] for key in keys]
 
     def _update_rescale_keys(self, keys):
@@ -708,6 +726,11 @@ class ConditionalPriorDict(PriorDict):
             self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters]
             self._least_recently_rescaled_keys = keys
 
+    def _prepare_evaluation(self, keys, theta):
+        self._check_resolved()
+        for key, value in zip(keys, theta):
+            self[key].least_recently_sampled = value
+
     def _check_resolved(self):
         if not self._resolved:
             raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py
index a76f10a4fa379d34b4fe35499a76fb4005506741..3e2c5d051cedfe93ee401424020ee5958e98d7a9 100644
--- a/test/core/prior/conditional_test.py
+++ b/test/core/prior/conditional_test.py
@@ -1,6 +1,7 @@
 import unittest
 
 import mock
+import numpy as np
 
 import bilby
 
@@ -170,13 +171,13 @@ class TestConditionalPrior(unittest.TestCase):
 class TestConditionalPriorDict(unittest.TestCase):
     def setUp(self):
         def condition_func_1(reference_parameters, var_0):
-            return reference_parameters
+            return dict(minimum=reference_parameters["minimum"], maximum=var_0)
 
         def condition_func_2(reference_parameters, var_0, var_1):
-            return reference_parameters
+            return dict(minimum=reference_parameters["minimum"], maximum=var_1)
 
         def condition_func_3(reference_parameters, var_1, var_2):
-            return reference_parameters
+            return dict(minimum=reference_parameters["minimum"], maximum=var_2)
 
         self.minimum = 0
         self.maximum = 1
@@ -203,7 +204,8 @@ class TestConditionalPriorDict(unittest.TestCase):
         self.conditional_priors_manually_set_items = (
             bilby.core.prior.ConditionalPriorDict()
         )
-        self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4)
+        self.test_sample = dict(var_0=0.7, var_1=0.6, var_2=0.5, var_3=0.4)
+        self.test_value = 1 / np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)])
         for key, value in dict(
             var_0=self.prior_0,
             var_1=self.prior_1,
@@ -254,7 +256,7 @@ class TestConditionalPriorDict(unittest.TestCase):
         )
 
     def test_prob(self):
-        self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample))
+        self.assertEqual(self.test_value, self.conditional_priors.prob(sample=self.test_sample))
 
     def test_prob_illegal_conditions(self):
         del self.conditional_priors["var_0"]
@@ -262,7 +264,7 @@ class TestConditionalPriorDict(unittest.TestCase):
             self.conditional_priors.prob(sample=self.test_sample)
 
     def test_ln_prob(self):
-        self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample))
+        self.assertEqual(np.log(self.test_value), self.conditional_priors.ln_prob(sample=self.test_sample))
 
     def test_ln_prob_illegal_conditions(self):
         del self.conditional_priors["var_0"]
@@ -273,7 +275,7 @@ class TestConditionalPriorDict(unittest.TestCase):
         with mock.patch("numpy.random.uniform") as m:
             m.return_value = 0.5
             self.assertDictEqual(
-                dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5),
+                dict(var_0=0.5, var_1=0.5 ** 2, var_2=0.5 ** 3, var_3=0.5 ** 4),
                 self.conditional_priors.sample_subset(
                     keys=["var_0", "var_1", "var_2", "var_3"]
                 ),
@@ -301,31 +303,6 @@ class TestConditionalPriorDict(unittest.TestCase):
         print(priors.sample(2))
 
     def test_rescale(self):
-        def condition_func_1_rescale(reference_parameters, var_0):
-            if var_0 == 0.5:
-                return dict(minimum=reference_parameters["minimum"], maximum=1)
-            return reference_parameters
-
-        def condition_func_2_rescale(reference_parameters, var_0, var_1):
-            if var_0 == 0.5 and var_1 == 0.5:
-                return dict(minimum=reference_parameters["minimum"], maximum=1)
-            return reference_parameters
-
-        def condition_func_3_rescale(reference_parameters, var_1, var_2):
-            if var_1 == 0.5 and var_2 == 0.5:
-                return dict(minimum=reference_parameters["minimum"], maximum=1)
-            return reference_parameters
-
-        self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1)
-        self.prior_1 = bilby.core.prior.ConditionalUniform(
-            condition_func=condition_func_1_rescale, minimum=self.minimum, maximum=2
-        )
-        self.prior_2 = bilby.core.prior.ConditionalUniform(
-            condition_func=condition_func_2_rescale, minimum=self.minimum, maximum=2
-        )
-        self.prior_3 = bilby.core.prior.ConditionalUniform(
-            condition_func=condition_func_3_rescale, minimum=self.minimum, maximum=2
-        )
         self.conditional_priors = bilby.core.prior.ConditionalPriorDict(
             dict(
                 var_3=self.prior_3,
@@ -334,11 +311,28 @@ class TestConditionalPriorDict(unittest.TestCase):
                 var_1=self.prior_1,
             )
         )
-        ref_variables = [0.5, 0.5, 0.5, 0.5]
+        ref_variables = self.test_sample.values()
         res = self.conditional_priors.rescale(
-            keys=list(self.test_sample.keys()), theta=ref_variables
+            keys=self.test_sample.keys(), theta=ref_variables
+        )
+        expected = [self.test_sample["var_0"]]
+        for ii in range(1, 4):
+            expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
+        self.assertListEqual(expected, res)
+
+    def test_cdf(self):
+        """
+        Test that the CDF method is the inverse of the rescale method.
+
+        Note that the format of inputs/outputs is different between the two methods.
+        """
+        sample = self.conditional_priors.sample()
+        self.assertEqual(
+            self.conditional_priors.rescale(
+                sample.keys(),
+                self.conditional_priors.cdf(sample=sample).values()
+            ), list(sample.values())
         )
-        self.assertListEqual(ref_variables, res)
 
     def test_rescale_illegal_conditions(self):
         del self.conditional_priors["var_0"]
diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py
index 546bcd4652cb592755630a4d356cda1517f1203f..3b8487fa1f49ea9c25459c738cc16e58252a5f76 100644
--- a/test/core/prior/dict_test.py
+++ b/test/core/prior/dict_test.py
@@ -298,6 +298,20 @@ class TestPriorDict(unittest.TestCase):
             ),
         )
 
+    def test_cdf(self):
+        """
+        Test that the CDF method is the inverse of the rescale method.
+
+        Note that the format of inputs/outputs is different between the two methods.
+        """
+        sample = self.prior_set_from_dict.sample()
+        self.assertEqual(
+            self.prior_set_from_dict.rescale(
+                sample.keys(),
+                self.prior_set_from_dict.cdf(sample=sample).values()
+            ), list(sample.values())
+        )
+
     def test_redundancy(self):
         for key in self.prior_set_from_dict.keys():
             self.assertFalse(self.prior_set_from_dict.test_redundancy(key=key))