Skip to content
Snippets Groups Projects
Commit cf48d3c3 authored by Colm Talbot's avatar Colm Talbot
Browse files

Dirichlet Prior: Fix reading prior that has been written to file

parent 907dbf65
No related branches found
No related tags found
1 merge request!1128Dirichlet Prior: Fix reading prior that has been written to file
......@@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta):
self._required_variables = [
label + str(ii) for ii in range(order)
]
self.__class__.__name__ = 'Dirichlet'
self.__class__.__name__ = 'DirichletElement'
def dirichlet_condition(self, reference_parms, **kwargs):
remaining = 1 - sum(
......
import os
import shutil
import unittest
from unittest import mock
import numpy as np
import pandas as pd
import bilby
......@@ -408,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase):
print(res)
class TestDirichletPrior(unittest.TestCase):
def setUp(self):
self.priors = bilby.core.prior.DirichletPriorDict(5)
def tearDown(self):
if os.path.isdir("priors"):
shutil.rmtree("priors")
def test_samples_sum_to_less_than_one(self):
"""
Test that the samples sum to less than one as required for the
Dirichlet distribution.
"""
samples = pd.DataFrame(self.priors.sample(10000)).values
self.assertLess(max(np.sum(samples, axis=1)), 1)
def test_read_write_file(self):
self.priors.to_file(outdir="priors", label="test")
test = bilby.core.prior.PriorDict(filename="priors/test.prior")
self.assertEqual(self.priors, test)
def test_read_write_json(self):
self.priors.to_json(outdir="priors", label="test")
test = bilby.core.prior.PriorDict.from_json(filename="priors/test_prior.json")
self.assertEqual(self.priors, test)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment