diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 107822828c52ed32111a9462c1d8f4325b143719..c4dcc36827fd844853095dcbbcafd69f2585c40a 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta): self._required_variables = [ label + str(ii) for ii in range(order) ] - self.__class__.__name__ = 'Dirichlet' + self.__class__.__name__ = 'DirichletElement' def dirichlet_condition(self, reference_parms, **kwargs): remaining = 1 - sum( diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 4af73bdaa0c7f4f01f66b38089eb22e7fdbd73bb..94a869936fc6cd71cc3647e51a5a372e08ed4544 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -1,7 +1,10 @@ +import os +import shutil import unittest from unittest import mock import numpy as np +import pandas as pd import bilby @@ -408,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase): print(res) +class TestDirichletPrior(unittest.TestCase): + + def setUp(self): + self.priors = bilby.core.prior.DirichletPriorDict(5) + + def tearDown(self): + if os.path.isdir("priors"): + shutil.rmtree("priors") + + def test_samples_sum_to_less_than_one(self): + """ + Test that the samples sum to less than one as required for the + Dirichlet distribution. + """ + samples = pd.DataFrame(self.priors.sample(10000)).values + self.assertLess(max(np.sum(samples, axis=1)), 1) + + def test_read_write_file(self): + self.priors.to_file(outdir="priors", label="test") + test = bilby.core.prior.PriorDict(filename="priors/test.prior") + self.assertEqual(self.priors, test) + + def test_read_write_json(self): + self.priors.to_json(outdir="priors", label="test") + test = bilby.core.prior.PriorDict.from_json(filename="priors/test_prior.json") + self.assertEqual(self.priors, test) + + if __name__ == "__main__": unittest.main()