Skip to content
Snippets Groups Projects
Commit 569edca9 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'sample_from_modes' into 'master'

Randomly pick multiple samples from multivariate Gaussian from different modes

See merge request !822
parents 342443af 19d07499
No related branches found
No related tags found
1 merge request!822Randomly pick multiple samples from multivariate Gaussian from different modes
Pipeline #140030 passed with warnings
...@@ -570,7 +570,15 @@ class MultivariateGaussianDist(BaseJointPriorDist): ...@@ -570,7 +570,15 @@ class MultivariateGaussianDist(BaseJointPriorDist):
if self.nmodes == 1: if self.nmodes == 1:
mode = 0 mode = 0
else: else:
mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0] if size == 1:
mode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0]
else:
# pick modes
mode = [
np.argwhere(self.cumweights - r > 0)[0][0]
for r in np.random.rand(size)
]
samps = np.zeros((size, len(self))) samps = np.zeros((size, len(self)))
for i in range(size): for i in range(size):
inbound = False inbound = False
...@@ -578,7 +586,10 @@ class MultivariateGaussianDist(BaseJointPriorDist): ...@@ -578,7 +586,10 @@ class MultivariateGaussianDist(BaseJointPriorDist):
# sample the multivariate Gaussian keys # sample the multivariate Gaussian keys
vals = np.random.uniform(0, 1, len(self)) vals = np.random.uniform(0, 1, len(self))
samp = np.atleast_1d(self.rescale(vals, mode=mode)) if isinstance(mode, list):
samp = np.atleast_1d(self.rescale(vals, mode=mode[i]))
else:
samp = np.atleast_1d(self.rescale(vals, mode=mode))
samps[i, :] = samp samps[i, :] = samp
# check sample is in bounds (otherwise perform another draw) # check sample is in bounds (otherwise perform another draw)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment