diff --git a/test/prior_test.py b/test/prior_test.py index c1a9c9fc1df745a2449d80d06c16b0e4abe77d31..5d78d919bcff173cb077387006c4fd644c76d309 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -7,6 +7,7 @@ import os import shutil from collections import OrderedDict + class TestPriorInstantiationWithoutOptionalPriors(unittest.TestCase): def setUp(self): @@ -301,10 +302,10 @@ class TestPriorDict(unittest.TestCase): self.first_prior = bilby.core.prior.Uniform(name='a', minimum=0, maximum=1, unit='kg') self.second_prior = bilby.core.prior.PowerLaw(name='b', alpha=3, minimum=1, maximum=2, unit='m/s') self.third_prior = bilby.core.prior.DeltaFunction(name='c', peak=42, unit='m') - self.prior_dict = dict(mass=self.first_prior, - speed=self.second_prior, - length=self.third_prior) - self.prior_set_from_dict = bilby.core.prior.PriorDict(dictionary=self.prior_dict) + self.priors = dict(mass=self.first_prior, + speed=self.second_prior, + length=self.third_prior) + self.prior_set_from_dict = bilby.core.prior.PriorDict(dictionary=self.priors) self.default_prior_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prior_files/binary_black_holes.prior') self.prior_set_from_file = bilby.core.prior.PriorDict(filename=self.default_prior_file) @@ -313,7 +314,7 @@ class TestPriorDict(unittest.TestCase): del self.first_prior del self.second_prior del self.third_prior - del self.prior_dict + del self.priors del self.prior_set_from_dict del self.default_prior_file del self.prior_set_from_file @@ -325,26 +326,7 @@ class TestPriorDict(unittest.TestCase): self.assertEqual(3, len(self.prior_set_from_dict)) def test_prior_set_has_expected_priors(self): - self.assertDictEqual(self.prior_dict, dict(self.prior_set_from_dict)) - - # Removed for now as it does not create the outdir in the Gitlab CI - # def test_write_to_file(self): - # outdir = 'prior_set_test_outdir' - # label = 'test_label' - # label_dot_prior = label + '.prior' - # current_dir_path = os.path.dirname(os.path.realpath(__file__)) - # outdir_path = os.path.join(current_dir_path, outdir) - # outfile_path = os.path.join(outdir_path, label_dot_prior) - # self.prior_set_from_dict.write_to_file(outdir=outdir, label=label) - # self.assertTrue(os.path.isdir(outdir_path)) - # self.assertTrue(os.listdir(outdir_path)[0] == label_dot_prior) - # with open(outfile_path, 'r') as f: - # expected_outfile = [ - # 'speed = PowerLaw(alpha=3, minimum=1, maximum=2, name=\'b\', latex_label=\'b\', unit=\'m/s\')\n', - # 'mass = Uniform(minimum=0, maximum=1, name=\'a\', latex_label=\'a\', unit=\'kg\')\n', - # 'length = DeltaFunction(peak=42, name=\'c\', latex_label=\'c\', unit=\'m\')\n'] - # self.assertListEqual(sorted(expected_outfile), sorted(f.readlines())) - # shutil.rmtree(outdir_path) + self.assertDictEqual(self.priors, dict(self.prior_set_from_dict)) def test_read_from_file(self): expected = dict( @@ -373,6 +355,21 @@ class TestPriorDict(unittest.TestCase): ) self.assertDictEqual(expected, self.prior_set_from_file) + def test_to_file(self): + expected = ["length = DeltaFunction(peak=42, name='c', latex_label='c', unit='m')\n", + "speed = PowerLaw(alpha=3, minimum=1, maximum=2, name='b', latex_label='b', unit='m/s')\n", + "mass = Uniform(minimum=0, maximum=1, name='a', latex_label='a', unit='kg')\n"] + self.prior_set_from_dict.to_file(outdir='prior_files', label='to_file_test') + with open('prior_files/to_file_test.prior') as f: + for i, line in enumerate(f.readlines()): + self.assertTrue(line in expected) + + def test_from_dict_with_string(self): + string_prior = "bilby.core.prior.PowerLaw(name='b', alpha=3, minimum=1, maximum=2, unit='m/s')" + self.priors['speed'] = string_prior + from_dict = bilby.core.prior.PriorDict(dictionary=self.priors) + self.assertDictEqual(self.prior_set_from_dict, from_dict) + def test_convert_floats_to_delta_functions(self): self.prior_set_from_dict['d'] = 5 self.prior_set_from_dict['e'] = 7.3