diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index 85e8ac4ef75c8808fd20f1b175fd5214364deb1a..d25ca6487560bbb946b8930ba8c625e9d0f2d766 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -824,6 +824,8 @@ class ConditionalPriorDict(PriorDict):
         =======
         list: List of floats containing the rescaled sample
         """
+        from matplotlib.cbook import flatten
+
         keys = list(keys)
         theta = list(theta)
         self._check_resolved()
@@ -836,7 +838,7 @@ class ConditionalPriorDict(PriorDict):
                 theta[index], **self.get_required_variables(key)
             )
             self[key].least_recently_sampled = result[key]
-        return [result[key] for key in keys]
+        return list(flatten([result[key] for key in keys]))
 
     def _update_rescale_keys(self, keys):
         if not keys == self._least_recently_rescaled_keys:
diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py
index 742e7920b067877ce4d8aa3d2aa0f862db1dee4f..9dc35f40dd57b15a4b14c77f3395b8d1839ccac0 100644
--- a/bilby/core/prior/joint.py
+++ b/bilby/core/prior/joint.py
@@ -43,7 +43,7 @@ class BaseJointPriorDist(object):
                 if isinstance(bounds, (list, tuple, np.ndarray)):
                     if len(bound) != 2:
                         raise ValueError(
-                            "Bounds must contain an upper and " "lower value."
+                            "Bounds must contain an upper and lower value."
                         )
                     else:
                         if bound[1] <= bound[0]:
@@ -399,7 +399,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
                 if len(np.shape(sigmas)) == 1:
                     sigmas = [sigmas]
                 elif len(np.shape(sigmas)) == 0:
-                    raise ValueError("Must supply a list of standard " "deviations")
+                    raise ValueError("Must supply a list of standard deviations")
             if covs is not None:
                 if isinstance(covs, np.ndarray):
                     covs = [covs]
@@ -421,7 +421,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
                             "List of correlation coefficients the wrong shape"
                         )
                 elif not isinstance(corrcoefs, list):
-                    raise TypeError("Must pass a list of correlation " "coefficients")
+                    raise TypeError("Must pass a list of correlation coefficients")
             if weights is not None:
                 if isinstance(weights, (int, float)):
                     weights = [weights]
@@ -489,7 +489,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
 
             if len(self.corrcoefs[-1].shape) != 2:
                 raise ValueError(
-                    "Correlation coefficient matrix must be a 2d " "array."
+                    "Correlation coefficient matrix must be a 2d array."
                 )
 
             if (
@@ -497,16 +497,16 @@ class MultivariateGaussianDist(BaseJointPriorDist):
                 or self.corrcoefs[-1].shape[0] != self.num_vars
             ):
                 raise ValueError(
-                    "Correlation coefficient matrix shape is " "inconsistent"
+                    "Correlation coefficient matrix shape is inconsistent"
                 )
 
             # check matrix is symmetric
             if not np.allclose(self.corrcoefs[-1], self.corrcoefs[-1].T):
-                raise ValueError("Correlation coefficient matrix is not " "symmetric")
+                raise ValueError("Correlation coefficient matrix is not symmetric")
 
             # check diagonal is all ones
             if not np.all(np.diag(self.corrcoefs[-1]) == 1.0):
-                raise ValueError("Correlation coefficient matrix is not" "correct")
+                raise ValueError("Correlation coefficient matrix is not correct")
 
             try:
                 self.sigmas.append(list(sigmas))  # standard deviations
@@ -535,13 +535,13 @@ class MultivariateGaussianDist(BaseJointPriorDist):
             self.eigvectors.append(evecs)
         except Exception as e:
             raise RuntimeError(
-                "Problem getting eigenvalues and vectors: " "{}".format(e)
+                "Problem getting eigenvalues and vectors: {}".format(e)
             )
 
         # check eigenvalues are positive
         if np.any(self.eigvalues[-1] <= 0.0):
             raise ValueError(
-                "Correlation coefficient matrix is not positive " "definite"
+                "Correlation coefficient matrix is not positive definite"
             )
         self.sqeigvalues.append(np.sqrt(self.eigvalues[-1]))
 
diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py
index 3f7e8873169572ab0d6a096c6ddc8ad4f67078d2..4af73bdaa0c7f4f01f66b38089eb22e7fdbd73bb 100644
--- a/test/core/prior/conditional_test.py
+++ b/test/core/prior/conditional_test.py
@@ -320,6 +320,43 @@ class TestConditionalPriorDict(unittest.TestCase):
             expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
         self.assertListEqual(expected, res)
 
+    def test_rescale_with_joint_prior(self):
+        """
+        Add a joint prior into the conditional prior dictionary and check that
+        the returned list is flat.
+        """
+
+        # set multivariate Gaussian distribution
+        names = ["mvgvar_0", "mvgvar_1"]
+        mu = [[0.79, -0.83]]
+        cov = [[[0.03, 0.], [0., 0.04]]]
+        mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov)
+
+        priordict = bilby.core.prior.ConditionalPriorDict(
+            dict(
+                var_3=self.prior_3,
+                var_2=self.prior_2,
+                var_0=self.prior_0,
+                var_1=self.prior_1,
+                mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"),
+                mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"),
+            )
+        )
+
+        ref_variables = list(self.test_sample.values()) + [0.4, 0.1]
+        keys = list(self.test_sample.keys()) + names
+        res = priordict.rescale(keys=keys, theta=ref_variables)
+
+        self.assertIsInstance(res, list)
+        self.assertEqual(np.shape(res), (6,))
+        self.assertListEqual([isinstance(r, float) for r in res], 6 * [True])
+
+        # check conditional values are still as expected
+        expected = [self.test_sample["var_0"]]
+        for ii in range(1, 4):
+            expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
+        self.assertListEqual(expected, res[0:4])
+
     def test_cdf(self):
         """
         Test that the CDF method is the inverse of the rescale method.