From a97d1cb9fd59e4b843eabd558844bd8831595c82 Mon Sep 17 00:00:00 2001 From: Moritz Huebner <email@moritz-huebner.de> Date: Thu, 29 Apr 2021 00:20:09 +1000 Subject: [PATCH] eliminated for loops in tests for better diagnostics --- test/core/result_test.py | 118 +++++++++++++++++++++++---------------- 1 file changed, 70 insertions(+), 48 deletions(-) diff --git a/test/core/result_test.py b/test/core/result_test.py index ebfdccdd7..b2a7c24c1 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -192,55 +192,77 @@ class TestResult(unittest.TestCase): with self.assertRaises(ValueError): _ = self.result.posterior - def test_save_and_load(self): - gzips = [True, False, False, False] - extensions = ['json', 'json', 'pkl', 'hdf5'] - for gzip, extension in zip(gzips, extensions): - self.result.save_to_file(extension=extension, gzip=gzip) - loaded_result = bilby.core.result.read_in_result( - outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip - ) - self.assertTrue( - np.array_equal( - self.result.posterior.sort_values(by=["x"]), - loaded_result.posterior.sort_values(by=["x"]), - ) - ) - self.assertTrue( - self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys - ) - self.assertTrue( - self.result.search_parameter_keys == loaded_result.search_parameter_keys - ) - self.assertEqual(self.result.meta_data, loaded_result.meta_data) - self.assertEqual( - self.result.injection_parameters, loaded_result.injection_parameters - ) - self.assertEqual(self.result.log_evidence, loaded_result.log_evidence) - self.assertEqual( - self.result.log_noise_evidence, loaded_result.log_noise_evidence + def test_save_and_load_json(self): + self._save_and_load_test(extension='json') + + def test_save_and_load_json_gzip(self): + self._save_and_load_test(extension='json', gzip=True) + + def test_save_and_load_pkl(self): + self._save_and_load_test(extension='pkl') + + def test_save_and_load_hdf5(self): + self._save_and_load_test(extension='hdf5') + + def _save_and_load_test(self, extension, gzip=False): + self.result.save_to_file(extension=extension, gzip=gzip) + loaded_result = bilby.core.result.read_in_result( + outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip + ) + self.assertTrue( + np.array_equal( + self.result.posterior.sort_values(by=["x"]), + loaded_result.posterior.sort_values(by=["x"]), ) - self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err) - self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor) - self.assertEqual(self.result.priors["x"], loaded_result.priors["x"]) - self.assertEqual(self.result.priors["y"], loaded_result.priors["y"]) - self.assertEqual(self.result.priors["c"], loaded_result.priors["c"]) - self.assertEqual(self.result.priors["d"], loaded_result.priors["d"]) - - def test_save_and_dont_overwrite(self): - extensions = ['json', 'pkl', 'hdf5'] - for extension in extensions: - print(extension) - self.result.save_to_file(overwrite=False, extension=extension) - self.result.save_to_file(overwrite=False, extension=extension) - self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old")) - - def test_save_and_overwrite_default(self): - extensions = ['json', 'pkl', 'hdf5'] - for extension in extensions: - self.result.save_to_file(overwrite=True, extension=extension) - self.result.save_to_file(overwrite=True, extension=extension) - self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old")) + ) + self.assertTrue( + self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys + ) + self.assertTrue( + self.result.search_parameter_keys == loaded_result.search_parameter_keys + ) + self.assertEqual(self.result.meta_data, loaded_result.meta_data) + self.assertEqual( + self.result.injection_parameters, loaded_result.injection_parameters + ) + self.assertEqual(self.result.log_evidence, loaded_result.log_evidence) + self.assertEqual( + self.result.log_noise_evidence, loaded_result.log_noise_evidence + ) + self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err) + self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor) + self.assertEqual(self.result.priors["x"], loaded_result.priors["x"]) + self.assertEqual(self.result.priors["y"], loaded_result.priors["y"]) + self.assertEqual(self.result.priors["c"], loaded_result.priors["c"]) + self.assertEqual(self.result.priors["d"], loaded_result.priors["d"]) + + def test_save_and_dont_overwrite_json(self): + self._save_and_dont_overwrite_test(extension='json') + + def test_save_and_dont_overwrite_pkl(self): + self._save_and_dont_overwrite_test(extension='pkl') + + def test_save_and_dont_overwrite_hdf5(self): + self._save_and_dont_overwrite_test(extension='hdf5') + + def _save_and_dont_overwrite_test(self, extension): + self.result.save_to_file(overwrite=False, extension=extension) + self.result.save_to_file(overwrite=False, extension=extension) + self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old")) + + def test_save_and_overwrite_json(self): + self._save_and_overwrite_test(extension='json') + + def test_save_and_overwrite_pkl(self): + self._save_and_overwrite_test(extension='pkl') + + def test_save_and_overwrite_hdf5(self): + self._save_and_overwrite_test(extension='hdf5') + + def _save_and_overwrite_test(self, extension): + self.result.save_to_file(overwrite=True, extension=extension) + self.result.save_to_file(overwrite=True, extension=extension) + self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old")) def test_save_samples(self): self.result.save_posterior_samples() -- GitLab