diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py
index 107822828c52ed32111a9462c1d8f4325b143719..c4dcc36827fd844853095dcbbcafd69f2585c40a 100644
--- a/bilby/core/prior/conditional.py
+++ b/bilby/core/prior/conditional.py
@@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta):
         self._required_variables = [
             label + str(ii) for ii in range(order)
         ]
-        self.__class__.__name__ = 'Dirichlet'
+        self.__class__.__name__ = 'DirichletElement'
 
     def dirichlet_condition(self, reference_parms, **kwargs):
         remaining = 1 - sum(
diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py
index 4af73bdaa0c7f4f01f66b38089eb22e7fdbd73bb..94a869936fc6cd71cc3647e51a5a372e08ed4544 100644
--- a/test/core/prior/conditional_test.py
+++ b/test/core/prior/conditional_test.py
@@ -1,7 +1,10 @@
+import os
+import shutil
 import unittest
 from unittest import mock
 
 import numpy as np
+import pandas as pd
 
 import bilby
 
@@ -408,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase):
         print(res)
 
 
+class TestDirichletPrior(unittest.TestCase):
+
+    def setUp(self):
+        self.priors = bilby.core.prior.DirichletPriorDict(5)
+
+    def tearDown(self):
+        if os.path.isdir("priors"):
+            shutil.rmtree("priors")
+
+    def test_samples_sum_to_less_than_one(self):
+        """
+        Test that the samples sum to less than one as required for the
+        Dirichlet distribution.
+        """
+        samples = pd.DataFrame(self.priors.sample(10000)).values
+        self.assertLess(max(np.sum(samples, axis=1)), 1)
+
+    def test_read_write_file(self):
+        self.priors.to_file(outdir="priors", label="test")
+        test = bilby.core.prior.PriorDict(filename="priors/test.prior")
+        self.assertEqual(self.priors, test)
+
+    def test_read_write_json(self):
+        self.priors.to_json(outdir="priors", label="test")
+        test = bilby.core.prior.PriorDict.from_json(filename="priors/test_prior.json")
+        self.assertEqual(self.priors, test)
+
+
 if __name__ == "__main__":
     unittest.main()