From a3e2f1f64d0f8b6a1cad4b45f12d6d2680cdb46f Mon Sep 17 00:00:00 2001
From: Matthew Pitkin <matthew.pitkin@ligo.org>
Date: Thu, 28 Feb 2019 09:26:29 +0000
Subject: [PATCH] Some fixes to the multivariate Gaussian prior

---
 bilby/core/prior.py | 51 +++++++++++++++++++++------------------------
 1 file changed, 24 insertions(+), 27 deletions(-)

diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 18f58e388..276fd0e4b 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -1932,13 +1932,6 @@ class MultivariateGaussian(object):
         # a list of sampled parameters
         self.reset_sampled()
 
-    def has_sampled(self):
-        if (len(self.sampled_parameters) == len(self) or
-                len(self.current_sample) == 0):
-            return False
-        else:
-            return True
-
     def reset_sampled(self):
         self.sampled_parameters = []
         self.current_sample = {}
@@ -2057,8 +2050,8 @@ class MultivariateGaussian(object):
         else:
             self.weights.append(weight)
 
-        # set the relative weights
-        self.relweights = np.cumsum(self.weights) / np.sum(self.weights)
+        # set the cumulative relative weights
+        self.cumweights = np.cumsum(self.weights) / np.sum(self.weights)
 
         # add the mode
         self.nmodes += 1
@@ -2085,20 +2078,16 @@ class MultivariateGaussian(object):
             distribution.
         """
 
-        # pick a mode
-        wmode = np.random.rand()
-        for imode, wbin in enumerate(self.relweights):
-            if wmode < self.relweights[imode]:
-                break
+        # pick a mode (with a probability given by their weights)
+        imode = np.argwhere(self.cumweights - np.random.rand() > 0)[0][0]
 
         # draw points from unit variance, uncorrelated Gaussian
         samp = erfinv(2. * value - 1) * 2. ** 0.5
 
         # rotate and scale to the multivariate normal shape
-        samp = self.mus[imode] + self.sigmas[imode] * (np.einsum('j,kj->k',
-                                                       np.einsum('j,j->j', samp,
-                                                                 self.sqeigvalues[imode]),
-                                                       self.eigvectors[imode]))
+        samp = self.mus[imode] + self.sigmas[imode] * np.einsum('j,kj->k',
+                                                      samp * self.sqeigvalues[imode],
+                                                      self.eigvectors[imode])
 
         return samp
 
@@ -2118,7 +2107,7 @@ class MultivariateGaussian(object):
             # check sample is in bounds (otherwise perform another draw)
             outbound = False
             for name, val in zip(self.names, samp):
-                if self.bounds[name][0] < val or samp > self.bounds[name][1]:
+                if val < self.bounds[name][0] or val > self.bounds[name][1]:
                     outbound = True
                     break
 
@@ -2165,7 +2154,7 @@ class MultivariateGaussian(object):
 
 class MultivariateGaussianPrior(Prior):
 
-    def __init__(self, mvg, name, latex_label=None, unit=None):
+    def __init__(self, mvg, name=None, latex_label=None, unit=None):
         """
         A prior class for a multivariate Gaussian (mixture model) prior.
 
@@ -2191,7 +2180,7 @@ class MultivariateGaussianPrior(Prior):
                              "Gaussian")
 
         Prior.__init__(self, name=name, latex_label=latex_label, unit=unit,
-                       minimun=mvg.bounds[name][0],
+                       minimum=mvg.bounds[name][0],
                        maximum=mvg.bounds[name][1])
         self.mvg = mvg
 
@@ -2215,15 +2204,23 @@ class MultivariateGaussianPrior(Prior):
         Draw a sample from the prior.
         """
 
-        if self.mvg.has_sampled():
-            sample = self.mvg.current_sample[self.name]
-            self.mvg.reset_sampled()
-        else:
+        if self.name in self.mvg.sampled_parameters:
+            logger.warning("You have already drawn a sample from parameter "
+                           "'{}'. The same sample will be "
+                           "returned".format(self.name))
+
+        if len(self.mvg.current_sample) == 0:
             # generate a sample
             self.mvg.sample()
-            sample = self.mvg.current_sample[self.name]
+        
+        sample = self.mvg.current_sample[self.name]
 
-        self.mvg.sampled_parameters.append(self.name)
+        if self.name not in self.mvg.sampled_parameters:
+            self.mvg.sampled_parameters.append(self.name)
+
+        if len(self.mvg.sampled_parameters) == len(self.mvg):
+            # reset samples
+            self.mvg.reset_sampled()
 
         return sample
 
-- 
GitLab