From 0ebe9478a45502da9cd88dc901ca5c9cdab25510 Mon Sep 17 00:00:00 2001
From: Bruce Edelman <bruce.edelman@ligo.org>
Date: Mon, 30 Mar 2020 03:22:23 -0500
Subject: [PATCH] Revert "generalized the JointPrior object structure from
 Matthew Pitkin's MutlivariateGaussian prior formalism. TODO: add in the joint
 MapPrior for HEALPix priors"

This reverts commit 85cae594d8763d1093247a081e367989a9b6ae07.
---
 bilby/core/prior/dict.py      | 28 ++++++++++++++++++++++++++--
 bilby/core/sampler/kombine.py |  4 ++--
 test/prior_test.py            | 20 ++++++++++++++++++++
 test/sampler_test.py          |  2 +-
 4 files changed, 49 insertions(+), 5 deletions(-)

diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index ca2eed960..a375625b3 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -39,6 +39,7 @@ class PriorDict(dict):
             self.from_file(filename)
         elif dictionary is not None:
             raise ValueError("PriorDict input dictionary not understood")
+        self._cached_normalizations = {}
 
         self.convert_floats_to_delta_functions()
 
@@ -383,6 +384,27 @@ class PriorDict(dict):
                            if not isinstance(self[key], Constraint)}
             return all_samples
 
+    def normalize_constraint_factor(self, keys):
+        if keys in self._cached_normalizations.keys():
+            return self._cached_normalizations[keys]
+        else:
+            min_accept = 1000
+            sampling_chunk = 5000
+            samples = self.sample_subset(keys=keys, size=sampling_chunk)
+            keep = np.atleast_1d(self.evaluate_constraints(samples))
+            if len(keep) == 1:
+                return 1
+            all_samples = {key: np.array([]) for key in keys}
+            while np.count_nonzero(keep) < min_accept:
+                samples = self.sample_subset(keys=keys, size=sampling_chunk)
+                for key in samples:
+                    all_samples[key] = np.hstack(
+                        [all_samples[key], samples[key].flatten()])
+                keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
+            factor = len(keep) / np.count_nonzero(keep)
+            self._cached_normalizations[keys] = factor
+            return factor
+
     def prob(self, sample, **kwargs):
         """
 
@@ -401,6 +423,7 @@ class PriorDict(dict):
         prob = np.product([self[key].prob(sample[key])
                            for key in sample], **kwargs)
 
+        ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(prob == 0.):
             return prob
         else:
@@ -412,7 +435,7 @@ class PriorDict(dict):
             else:
                 constrained_prob = np.zeros_like(prob)
                 keep = np.array(self.evaluate_constraints(sample), dtype=bool)
-                constrained_prob[keep] = prob[keep]
+                constrained_prob[keep] = prob[keep] * ratio
                 return constrained_prob
 
     def ln_prob(self, sample, axis=None):
@@ -434,6 +457,7 @@ class PriorDict(dict):
         ln_prob = np.sum([self[key].ln_prob(sample[key])
                           for key in sample], axis=axis)
 
+        ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(np.isinf(ln_prob)):
             return ln_prob
         else:
@@ -445,7 +469,7 @@ class PriorDict(dict):
             else:
                 constrained_ln_prob = -np.inf * np.ones_like(ln_prob)
                 keep = np.array(self.evaluate_constraints(sample), dtype=bool)
-                constrained_ln_prob[keep] = ln_prob[keep]
+                constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio)
                 return constrained_ln_prob
 
     def rescale(self, keys, theta):
diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py
index cd3707051..48e85342a 100644
--- a/bilby/core/sampler/kombine.py
+++ b/bilby/core/sampler/kombine.py
@@ -160,8 +160,8 @@ class Kombine(Emcee):
         self.result.nburn = self.nburn
         if self.result.nburn > self.nsteps:
             raise SamplerError(
-                "The run has finished, but the chain is not burned in: "
-                "`nburn < nsteps`. Try increasing the number of steps.")
+                "The run has finished, but the chain is not burned in: `nburn < nsteps` ({} < {}). Try increasing the "
+                "number of steps.".format(self.result.nburn, self.nsteps))
         tmp_chain = self.sampler.chain[self.nburn:, :, :].copy()
         self.result.samples = tmp_chain.reshape((-1, self.ndim))
         blobs = np.array(self.sampler.blobs)
diff --git a/test/prior_test.py b/test/prior_test.py
index 832956928..a3c5a312e 100644
--- a/test/prior_test.py
+++ b/test/prior_test.py
@@ -907,6 +907,26 @@ class TestPriorDict(unittest.TestCase):
             self.assertFalse(self.prior_set_from_dict.test_redundancy(key=key))
 
 
+class TestConstraintPriorNormalisation(unittest.TestCase):
+    def setUp(self):
+        self.priors = dict(mass_1=bilby.core.prior.Uniform(name='mass_1', minimum=5, maximum=10, unit='$M_{\odot}$',
+                                                           boundary=None),
+                      mass_2=bilby.core.prior.Uniform(name='mass_2', minimum=5, maximum=10, unit='$M_{\odot}$',
+                                                      boundary=None),
+                      mass_ratio=bilby.core.prior.Constraint(name='mass_ratio', minimum=0, maximum=1))
+        self.priors = bilby.core.prior.PriorDict(self.priors)
+
+    def test_prob_integrate_to_one(self):
+        keys = ['mass_1', 'mass_2', 'mass_ratio']
+        n = 5000
+        samples = self.priors.sample_subset(keys=keys, size=n)
+        prob = self.priors.prob(samples, axis=0)
+        dm1 = self.priors['mass_1'].maximum - self.priors['mass_1'].minimum
+        dm2 = self.priors['mass_2'].maximum - self.priors['mass_2'].minimum
+        integral = np.sum(prob * (dm1 * dm2)) / len(samples['mass_1'])
+        self.assertAlmostEqual(1, integral, 5)
+
+
 class TestLoadPrior(unittest.TestCase):
     def test_load_prior_with_float(self):
         filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
diff --git a/test/sampler_test.py b/test/sampler_test.py
index 1bc4b5916..b1c7ccdac 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -563,7 +563,7 @@ class TestRunningSamplers(unittest.TestCase):
     def test_run_kombine(self):
         _ = bilby.run_sampler(
             likelihood=self.likelihood, priors=self.priors, sampler='kombine',
-            iterations=2500, nwalkers=100, save=False)
+            iterations=1000, nwalkers=100, save=False, autoburnin=True)
 
     def test_run_nestle(self):
         _ = bilby.run_sampler(
-- 
GitLab