From cf48d3c3e247f841202edda8989ca730d3451fe3 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Fri, 12 Aug 2022 13:21:33 +0000 Subject: [PATCH] Dirichlet Prior: Fix reading prior that has been written to file --- bilby/core/prior/conditional.py | 2 +- test/core/prior/conditional_test.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py index 107822828..c4dcc3682 100644 --- a/bilby/core/prior/conditional.py +++ b/bilby/core/prior/conditional.py @@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta): self._required_variables = [ label + str(ii) for ii in range(order) ] - self.__class__.__name__ = 'Dirichlet' + self.__class__.__name__ = 'DirichletElement' def dirichlet_condition(self, reference_parms, **kwargs): remaining = 1 - sum( diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py index 4af73bdaa..94a869936 100644 --- a/test/core/prior/conditional_test.py +++ b/test/core/prior/conditional_test.py @@ -1,7 +1,10 @@ +import os +import shutil import unittest from unittest import mock import numpy as np +import pandas as pd import bilby @@ -408,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase): print(res) +class TestDirichletPrior(unittest.TestCase): + + def setUp(self): + self.priors = bilby.core.prior.DirichletPriorDict(5) + + def tearDown(self): + if os.path.isdir("priors"): + shutil.rmtree("priors") + + def test_samples_sum_to_less_than_one(self): + """ + Test that the samples sum to less than one as required for the + Dirichlet distribution. + """ + samples = pd.DataFrame(self.priors.sample(10000)).values + self.assertLess(max(np.sum(samples, axis=1)), 1) + + def test_read_write_file(self): + self.priors.to_file(outdir="priors", label="test") + test = bilby.core.prior.PriorDict(filename="priors/test.prior") + self.assertEqual(self.priors, test) + + def test_read_write_json(self): + self.priors.to_json(outdir="priors", label="test") + test = bilby.core.prior.PriorDict.from_json(filename="priors/test_prior.json") + self.assertEqual(self.priors, test) + + if __name__ == "__main__": unittest.main() -- GitLab