From e54810282334ba6a9257c63636198f92ef2ac8da Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 5 Jul 2021 16:45:01 +0100
Subject: [PATCH] Fix error in Categorical prior and add tests

---
 bilby/core/prior/analytical.py     |  2 +-
 test/core/prior/analytical_test.py | 12 +++++++++---
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py
index 3ec00673..e3f91667 100644
--- a/bilby/core/prior/analytical.py
+++ b/bilby/core/prior/analytical.py
@@ -1469,7 +1469,7 @@ class Categorical(Prior):
         =======
         Union[float, array_like]: Rescaled probability
         """
-        return np.round(val * self.maximum)
+        return np.floor(val * (1 + self.maximum))
 
     def prob(self, val):
         """Return the prior probability of val.
diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py
index 6a0bc397..7a0a34f4 100644
--- a/test/core/prior/analytical_test.py
+++ b/test/core/prior/analytical_test.py
@@ -15,13 +15,19 @@ class TestCategoricalPrior(unittest.TestCase):
         self.assertTrue(in_prior)
 
     def test_array_sample(self):
-        categorical_prior = bilby.core.prior.Categorical(3)
-        N = 1000
+        ncat = 4
+        categorical_prior = bilby.core.prior.Categorical(ncat)
+        N = 100000
         s = categorical_prior.sample(N)
         zeros = np.sum(s == 0)
         ones = np.sum(s == 1)
         twos = np.sum(s == 2)
-        self.assertEqual(zeros + ones + twos, N)
+        threes = np.sum(s == 3)
+        self.assertEqual(zeros + ones + twos + threes, N)
+        self.assertAlmostEqual(zeros / N, 1 / ncat, places=int(np.log10(np.sqrt(N))))
+        self.assertAlmostEqual(ones / N, 1 / ncat, places=int(np.log10(np.sqrt(N))))
+        self.assertAlmostEqual(twos / N, 1 / ncat, places=int(np.log10(np.sqrt(N))))
+        self.assertAlmostEqual(threes / N, 1 / ncat, places=int(np.log10(np.sqrt(N))))
 
     def test_single_probability(self):
         N = 3
-- 
GitLab