diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index c9baa773f9d2b5789cf58261d544233a8bdd1b74..d2058d074414b14a90ebd2837c80e4287aa3cf35 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -570,7 +570,15 @@ class MultivariateGaussianDist(BaseJointPriorDist): if self.nmodes == 1: mode = 0 else: - mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0] + if size == 1: + mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0] + else: + # pick modes + mode = [ + np.argwhere(self.cumweights - r > 0)[0][0] + for r in np.random.rand(size) + ] + samps = np.zeros((size, len(self))) for i in range(size): inbound = False @@ -578,7 +586,10 @@ class MultivariateGaussianDist(BaseJointPriorDist): # sample the multivariate Gaussian keys vals = np.random.uniform(0, 1, len(self)) - samp = np.atleast_1d(self.rescale(vals, mode=mode)) + if isinstance(mode, list): + samp = np.atleast_1d(self.rescale(vals, mode=mode[i])) + else: + samp = np.atleast_1d(self.rescale(vals, mode=mode)) samps[i, :] = samp # check sample is in bounds (otherwise perform another draw)