diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py
index c9baa773f9d2b5789cf58261d544233a8bdd1b74..d2058d074414b14a90ebd2837c80e4287aa3cf35 100644
--- a/bilby/core/prior/joint.py
+++ b/bilby/core/prior/joint.py
@@ -570,7 +570,15 @@ class MultivariateGaussianDist(BaseJointPriorDist):
             if self.nmodes == 1:
                 mode = 0
             else:
-                mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0]
+                if size == 1:
+                    mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0]
+                else:
+                    # pick modes
+                    mode = [
+                        np.argwhere(self.cumweights - r > 0)[0][0]
+                        for r in np.random.rand(size)
+                    ]
+
         samps = np.zeros((size, len(self)))
         for i in range(size):
             inbound = False
@@ -578,7 +586,10 @@ class MultivariateGaussianDist(BaseJointPriorDist):
                 # sample the multivariate Gaussian keys
                 vals = np.random.uniform(0, 1, len(self))
 
-                samp = np.atleast_1d(self.rescale(vals, mode=mode))
+                if isinstance(mode, list):
+                    samp = np.atleast_1d(self.rescale(vals, mode=mode[i]))
+                else:
+                    samp = np.atleast_1d(self.rescale(vals, mode=mode))
                 samps[i, :] = samp
 
                 # check sample is in bounds (otherwise perform another draw)