Skip to content
Snippets Groups Projects
Commit a97d1cb9 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

eliminated for loops in tests for better diagnostics

parent 15754aaf
No related branches found
No related tags found
1 merge request!932Fix pickle
Pipeline #222260 passed
......@@ -192,55 +192,77 @@ class TestResult(unittest.TestCase):
with self.assertRaises(ValueError):
_ = self.result.posterior
def test_save_and_load(self):
gzips = [True, False, False, False]
extensions = ['json', 'json', 'pkl', 'hdf5']
for gzip, extension in zip(gzips, extensions):
self.result.save_to_file(extension=extension, gzip=gzip)
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
)
self.assertTrue(
np.array_equal(
self.result.posterior.sort_values(by=["x"]),
loaded_result.posterior.sort_values(by=["x"]),
)
)
self.assertTrue(
self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
)
self.assertTrue(
self.result.search_parameter_keys == loaded_result.search_parameter_keys
)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(
self.result.injection_parameters, loaded_result.injection_parameters
)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(
self.result.log_noise_evidence, loaded_result.log_noise_evidence
def test_save_and_load_json(self):
self._save_and_load_test(extension='json')
def test_save_and_load_json_gzip(self):
self._save_and_load_test(extension='json', gzip=True)
def test_save_and_load_pkl(self):
self._save_and_load_test(extension='pkl')
def test_save_and_load_hdf5(self):
self._save_and_load_test(extension='hdf5')
def _save_and_load_test(self, extension, gzip=False):
self.result.save_to_file(extension=extension, gzip=gzip)
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
)
self.assertTrue(
np.array_equal(
self.result.posterior.sort_values(by=["x"]),
loaded_result.posterior.sort_values(by=["x"]),
)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
def test_save_and_dont_overwrite(self):
extensions = ['json', 'pkl', 'hdf5']
for extension in extensions:
print(extension)
self.result.save_to_file(overwrite=False, extension=extension)
self.result.save_to_file(overwrite=False, extension=extension)
self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
def test_save_and_overwrite_default(self):
extensions = ['json', 'pkl', 'hdf5']
for extension in extensions:
self.result.save_to_file(overwrite=True, extension=extension)
self.result.save_to_file(overwrite=True, extension=extension)
self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
)
self.assertTrue(
self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
)
self.assertTrue(
self.result.search_parameter_keys == loaded_result.search_parameter_keys
)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(
self.result.injection_parameters, loaded_result.injection_parameters
)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(
self.result.log_noise_evidence, loaded_result.log_noise_evidence
)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
def test_save_and_dont_overwrite_json(self):
self._save_and_dont_overwrite_test(extension='json')
def test_save_and_dont_overwrite_pkl(self):
self._save_and_dont_overwrite_test(extension='pkl')
def test_save_and_dont_overwrite_hdf5(self):
self._save_and_dont_overwrite_test(extension='hdf5')
def _save_and_dont_overwrite_test(self, extension):
self.result.save_to_file(overwrite=False, extension=extension)
self.result.save_to_file(overwrite=False, extension=extension)
self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
def test_save_and_overwrite_json(self):
self._save_and_overwrite_test(extension='json')
def test_save_and_overwrite_pkl(self):
self._save_and_overwrite_test(extension='pkl')
def test_save_and_overwrite_hdf5(self):
self._save_and_overwrite_test(extension='hdf5')
def _save_and_overwrite_test(self, extension):
self.result.save_to_file(overwrite=True, extension=extension)
self.result.save_to_file(overwrite=True, extension=extension)
self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
def test_save_samples(self):
self.result.save_posterior_samples()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment