Skip to content
Snippets Groups Projects
Commit cf48d3c3 authored by Colm Talbot's avatar Colm Talbot
Browse files

Dirichlet Prior: Fix reading prior that has been written to file

parent 907dbf65
No related branches found
No related tags found
1 merge request!1128Dirichlet Prior: Fix reading prior that has been written to file
...@@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta): ...@@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta):
self._required_variables = [ self._required_variables = [
label + str(ii) for ii in range(order) label + str(ii) for ii in range(order)
] ]
self.__class__.__name__ = 'Dirichlet' self.__class__.__name__ = 'DirichletElement'
def dirichlet_condition(self, reference_parms, **kwargs): def dirichlet_condition(self, reference_parms, **kwargs):
remaining = 1 - sum( remaining = 1 - sum(
......
import os
import shutil
import unittest import unittest
from unittest import mock from unittest import mock
import numpy as np import numpy as np
import pandas as pd
import bilby import bilby
...@@ -408,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase): ...@@ -408,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase):
print(res) print(res)
class TestDirichletPrior(unittest.TestCase):
def setUp(self):
self.priors = bilby.core.prior.DirichletPriorDict(5)
def tearDown(self):
if os.path.isdir("priors"):
shutil.rmtree("priors")
def test_samples_sum_to_less_than_one(self):
"""
Test that the samples sum to less than one as required for the
Dirichlet distribution.
"""
samples = pd.DataFrame(self.priors.sample(10000)).values
self.assertLess(max(np.sum(samples, axis=1)), 1)
def test_read_write_file(self):
self.priors.to_file(outdir="priors", label="test")
test = bilby.core.prior.PriorDict(filename="priors/test.prior")
self.assertEqual(self.priors, test)
def test_read_write_json(self):
self.priors.to_json(outdir="priors", label="test")
test = bilby.core.prior.PriorDict.from_json(filename="priors/test_prior.json")
self.assertEqual(self.priors, test)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment