Skip to content
Snippets Groups Projects
Commit de549e36 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'fix-categorial-prior' into 'master'

Fix error in Categorical prior and add tests

See merge request lscsoft/bilby!990
parents 7c74e2d4 e5481028
No related branches found
No related tags found
No related merge requests found
......@@ -1469,7 +1469,7 @@ class Categorical(Prior):
=======
Union[float, array_like]: Rescaled probability
"""
return np.round(val * self.maximum)
return np.floor(val * (1 + self.maximum))
def prob(self, val):
"""Return the prior probability of val.
......
......@@ -15,13 +15,19 @@ class TestCategoricalPrior(unittest.TestCase):
self.assertTrue(in_prior)
def test_array_sample(self):
categorical_prior = bilby.core.prior.Categorical(3)
N = 1000
ncat = 4
categorical_prior = bilby.core.prior.Categorical(ncat)
N = 100000
s = categorical_prior.sample(N)
zeros = np.sum(s == 0)
ones = np.sum(s == 1)
twos = np.sum(s == 2)
self.assertEqual(zeros + ones + twos, N)
threes = np.sum(s == 3)
self.assertEqual(zeros + ones + twos + threes, N)
self.assertAlmostEqual(zeros / N, 1 / ncat, places=int(np.log10(np.sqrt(N))))
self.assertAlmostEqual(ones / N, 1 / ncat, places=int(np.log10(np.sqrt(N))))
self.assertAlmostEqual(twos / N, 1 / ncat, places=int(np.log10(np.sqrt(N))))
self.assertAlmostEqual(threes / N, 1 / ncat, places=int(np.log10(np.sqrt(N))))
def test_single_probability(self):
N = 3
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment