From cf48d3c3e247f841202edda8989ca730d3451fe3 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Fri, 12 Aug 2022 13:21:33 +0000
Subject: [PATCH] Dirichlet Prior: Fix reading prior that has been written to
 file

---
 bilby/core/prior/conditional.py     |  2 +-
 test/core/prior/conditional_test.py | 31 +++++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py
index 107822828..c4dcc3682 100644
--- a/bilby/core/prior/conditional.py
+++ b/bilby/core/prior/conditional.py
@@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta):
         self._required_variables = [
             label + str(ii) for ii in range(order)
         ]
-        self.__class__.__name__ = 'Dirichlet'
+        self.__class__.__name__ = 'DirichletElement'
 
     def dirichlet_condition(self, reference_parms, **kwargs):
         remaining = 1 - sum(
diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py
index 4af73bdaa..94a869936 100644
--- a/test/core/prior/conditional_test.py
+++ b/test/core/prior/conditional_test.py
@@ -1,7 +1,10 @@
+import os
+import shutil
 import unittest
 from unittest import mock
 
 import numpy as np
+import pandas as pd
 
 import bilby
 
@@ -408,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase):
         print(res)
 
 
+class TestDirichletPrior(unittest.TestCase):
+
+    def setUp(self):
+        self.priors = bilby.core.prior.DirichletPriorDict(5)
+
+    def tearDown(self):
+        if os.path.isdir("priors"):
+            shutil.rmtree("priors")
+
+    def test_samples_sum_to_less_than_one(self):
+        """
+        Test that the samples sum to less than one as required for the
+        Dirichlet distribution.
+        """
+        samples = pd.DataFrame(self.priors.sample(10000)).values
+        self.assertLess(max(np.sum(samples, axis=1)), 1)
+
+    def test_read_write_file(self):
+        self.priors.to_file(outdir="priors", label="test")
+        test = bilby.core.prior.PriorDict(filename="priors/test.prior")
+        self.assertEqual(self.priors, test)
+
+    def test_read_write_json(self):
+        self.priors.to_json(outdir="priors", label="test")
+        test = bilby.core.prior.PriorDict.from_json(filename="priors/test_prior.json")
+        self.assertEqual(self.priors, test)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab