From e54810282334ba6a9257c63636198f92ef2ac8da Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Mon, 5 Jul 2021 16:45:01 +0100 Subject: [PATCH] Fix error in Categorical prior and add tests --- bilby/core/prior/analytical.py | 2 +- test/core/prior/analytical_test.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 3ec00673..e3f91667 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1469,7 +1469,7 @@ class Categorical(Prior): ======= Union[float, array_like]: Rescaled probability """ - return np.round(val * self.maximum) + return np.floor(val * (1 + self.maximum)) def prob(self, val): """Return the prior probability of val. diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index 6a0bc397..7a0a34f4 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -15,13 +15,19 @@ class TestCategoricalPrior(unittest.TestCase): self.assertTrue(in_prior) def test_array_sample(self): - categorical_prior = bilby.core.prior.Categorical(3) - N = 1000 + ncat = 4 + categorical_prior = bilby.core.prior.Categorical(ncat) + N = 100000 s = categorical_prior.sample(N) zeros = np.sum(s == 0) ones = np.sum(s == 1) twos = np.sum(s == 2) - self.assertEqual(zeros + ones + twos, N) + threes = np.sum(s == 3) + self.assertEqual(zeros + ones + twos + threes, N) + self.assertAlmostEqual(zeros / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) + self.assertAlmostEqual(ones / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) + self.assertAlmostEqual(twos / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) + self.assertAlmostEqual(threes / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) def test_single_probability(self): N = 3 -- GitLab