diff --git a/bilby/core/prior/analytical.py b/bilby/core/prior/analytical.py index 3ec00673c1e63f9bd1798093d535549f4519aae8..e3f916674ba84a8bc4e379b5b242cae8b01402d9 100644 --- a/bilby/core/prior/analytical.py +++ b/bilby/core/prior/analytical.py @@ -1469,7 +1469,7 @@ class Categorical(Prior): ======= Union[float, array_like]: Rescaled probability """ - return np.round(val * self.maximum) + return np.floor(val * (1 + self.maximum)) def prob(self, val): """Return the prior probability of val. diff --git a/test/core/prior/analytical_test.py b/test/core/prior/analytical_test.py index 6a0bc397e14590755487e40608def8360e24bf3b..7a0a34f4ae801550dc125bdfa9c444c00eac556c 100644 --- a/test/core/prior/analytical_test.py +++ b/test/core/prior/analytical_test.py @@ -15,13 +15,19 @@ class TestCategoricalPrior(unittest.TestCase): self.assertTrue(in_prior) def test_array_sample(self): - categorical_prior = bilby.core.prior.Categorical(3) - N = 1000 + ncat = 4 + categorical_prior = bilby.core.prior.Categorical(ncat) + N = 100000 s = categorical_prior.sample(N) zeros = np.sum(s == 0) ones = np.sum(s == 1) twos = np.sum(s == 2) - self.assertEqual(zeros + ones + twos, N) + threes = np.sum(s == 3) + self.assertEqual(zeros + ones + twos + threes, N) + self.assertAlmostEqual(zeros / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) + self.assertAlmostEqual(ones / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) + self.assertAlmostEqual(twos / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) + self.assertAlmostEqual(threes / N, 1 / ncat, places=int(np.log10(np.sqrt(N)))) def test_single_probability(self): N = 3