From a3e2f1f64d0f8b6a1cad4b45f12d6d2680cdb46f Mon Sep 17 00:00:00 2001 From: Matthew Pitkin <matthew.pitkin@ligo.org> Date: Thu, 28 Feb 2019 09:26:29 +0000 Subject: [PATCH] Some fixes to the multivariate Gaussian prior --- bilby/core/prior.py | 51 +++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 18f58e388..276fd0e4b 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -1932,13 +1932,6 @@ class MultivariateGaussian(object): # a list of sampled parameters self.reset_sampled() - def has_sampled(self): - if (len(self.sampled_parameters) == len(self) or - len(self.current_sample) == 0): - return False - else: - return True - def reset_sampled(self): self.sampled_parameters = [] self.current_sample = {} @@ -2057,8 +2050,8 @@ class MultivariateGaussian(object): else: self.weights.append(weight) - # set the relative weights - self.relweights = np.cumsum(self.weights) / np.sum(self.weights) + # set the cumulative relative weights + self.cumweights = np.cumsum(self.weights) / np.sum(self.weights) # add the mode self.nmodes += 1 @@ -2085,20 +2078,16 @@ class MultivariateGaussian(object): distribution. """ - # pick a mode - wmode = np.random.rand() - for imode, wbin in enumerate(self.relweights): - if wmode < self.relweights[imode]: - break + # pick a mode (with a probability given by their weights) + imode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0] # draw points from unit variance, uncorrelated Gaussian samp = erfinv(2. * value - 1) * 2. ** 0.5 # rotate and scale to the multivariate normal shape - samp = self.mus[imode] + self.sigmas[imode] * (np.einsum('j,kj->k', - np.einsum('j,j->j', samp, - self.sqeigvalues[imode]), - self.eigvectors[imode])) + samp = self.mus[imode] + self.sigmas[imode] * np.einsum('j,kj->k', + samp * self.sqeigvalues[imode], + self.eigvectors[imode]) return samp @@ -2118,7 +2107,7 @@ class MultivariateGaussian(object): # check sample is in bounds (otherwise perform another draw) outbound = False for name, val in zip(self.names, samp): - if self.bounds[name][0] < val or samp > self.bounds[name][1]: + if val < self.bounds[name][0] or val > self.bounds[name][1]: outbound = True break @@ -2165,7 +2154,7 @@ class MultivariateGaussian(object): class MultivariateGaussianPrior(Prior): - def __init__(self, mvg, name, latex_label=None, unit=None): + def __init__(self, mvg, name=None, latex_label=None, unit=None): """ A prior class for a multivariate Gaussian (mixture model) prior. @@ -2191,7 +2180,7 @@ class MultivariateGaussianPrior(Prior): "Gaussian") Prior.__init__(self, name=name, latex_label=latex_label, unit=unit, - minimun=mvg.bounds[name][0], + minimum=mvg.bounds[name][0], maximum=mvg.bounds[name][1]) self.mvg = mvg @@ -2215,15 +2204,23 @@ class MultivariateGaussianPrior(Prior): Draw a sample from the prior. """ - if self.mvg.has_sampled(): - sample = self.mvg.current_sample[self.name] - self.mvg.reset_sampled() - else: + if self.name in self.mvg.sampled_parameters: + logger.warning("You have already drawn a sample from parameter " + "'{}'. The same sample will be " + "returned".format(self.name)) + + if len(self.mvg.current_sample) == 0: # generate a sample self.mvg.sample() - sample = self.mvg.current_sample[self.name] + + sample = self.mvg.current_sample[self.name] - self.mvg.sampled_parameters.append(self.name) + if self.name not in self.mvg.sampled_parameters: + self.mvg.sampled_parameters.append(self.name) + + if len(self.mvg.sampled_parameters) == len(self.mvg): + # reset samples + self.mvg.reset_sampled() return sample -- GitLab