From a97d1cb9fd59e4b843eabd558844bd8831595c82 Mon Sep 17 00:00:00 2001
From: Moritz Huebner <email@moritz-huebner.de>
Date: Thu, 29 Apr 2021 00:20:09 +1000
Subject: [PATCH] eliminated for loops in tests for better diagnostics

---
 test/core/result_test.py | 118 +++++++++++++++++++++++----------------
 1 file changed, 70 insertions(+), 48 deletions(-)

diff --git a/test/core/result_test.py b/test/core/result_test.py
index ebfdccdd7..b2a7c24c1 100644
--- a/test/core/result_test.py
+++ b/test/core/result_test.py
@@ -192,55 +192,77 @@ class TestResult(unittest.TestCase):
         with self.assertRaises(ValueError):
             _ = self.result.posterior
 
-    def test_save_and_load(self):
-        gzips = [True, False, False, False]
-        extensions = ['json', 'json', 'pkl', 'hdf5']
-        for gzip, extension in zip(gzips, extensions):
-            self.result.save_to_file(extension=extension, gzip=gzip)
-            loaded_result = bilby.core.result.read_in_result(
-                outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
-            )
-            self.assertTrue(
-                np.array_equal(
-                    self.result.posterior.sort_values(by=["x"]),
-                    loaded_result.posterior.sort_values(by=["x"]),
-                )
-            )
-            self.assertTrue(
-                self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
-            )
-            self.assertTrue(
-                self.result.search_parameter_keys == loaded_result.search_parameter_keys
-            )
-            self.assertEqual(self.result.meta_data, loaded_result.meta_data)
-            self.assertEqual(
-                self.result.injection_parameters, loaded_result.injection_parameters
-            )
-            self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
-            self.assertEqual(
-                self.result.log_noise_evidence, loaded_result.log_noise_evidence
+    def test_save_and_load_json(self):
+        self._save_and_load_test(extension='json')
+
+    def test_save_and_load_json_gzip(self):
+        self._save_and_load_test(extension='json', gzip=True)
+
+    def test_save_and_load_pkl(self):
+        self._save_and_load_test(extension='pkl')
+
+    def test_save_and_load_hdf5(self):
+        self._save_and_load_test(extension='hdf5')
+
+    def _save_and_load_test(self, extension, gzip=False):
+        self.result.save_to_file(extension=extension, gzip=gzip)
+        loaded_result = bilby.core.result.read_in_result(
+            outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
+        )
+        self.assertTrue(
+            np.array_equal(
+                self.result.posterior.sort_values(by=["x"]),
+                loaded_result.posterior.sort_values(by=["x"]),
             )
-            self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
-            self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
-            self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
-            self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
-            self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
-            self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
-
-    def test_save_and_dont_overwrite(self):
-        extensions = ['json', 'pkl', 'hdf5']
-        for extension in extensions:
-            print(extension)
-            self.result.save_to_file(overwrite=False, extension=extension)
-            self.result.save_to_file(overwrite=False, extension=extension)
-            self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
-
-    def test_save_and_overwrite_default(self):
-        extensions = ['json', 'pkl', 'hdf5']
-        for extension in extensions:
-            self.result.save_to_file(overwrite=True, extension=extension)
-            self.result.save_to_file(overwrite=True, extension=extension)
-            self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
+        )
+        self.assertTrue(
+            self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
+        )
+        self.assertTrue(
+            self.result.search_parameter_keys == loaded_result.search_parameter_keys
+        )
+        self.assertEqual(self.result.meta_data, loaded_result.meta_data)
+        self.assertEqual(
+            self.result.injection_parameters, loaded_result.injection_parameters
+        )
+        self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
+        self.assertEqual(
+            self.result.log_noise_evidence, loaded_result.log_noise_evidence
+        )
+        self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
+        self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
+        self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
+        self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
+        self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
+        self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
+
+    def test_save_and_dont_overwrite_json(self):
+        self._save_and_dont_overwrite_test(extension='json')
+
+    def test_save_and_dont_overwrite_pkl(self):
+        self._save_and_dont_overwrite_test(extension='pkl')
+
+    def test_save_and_dont_overwrite_hdf5(self):
+        self._save_and_dont_overwrite_test(extension='hdf5')
+
+    def _save_and_dont_overwrite_test(self, extension):
+        self.result.save_to_file(overwrite=False, extension=extension)
+        self.result.save_to_file(overwrite=False, extension=extension)
+        self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
+
+    def test_save_and_overwrite_json(self):
+        self._save_and_overwrite_test(extension='json')
+
+    def test_save_and_overwrite_pkl(self):
+        self._save_and_overwrite_test(extension='pkl')
+
+    def test_save_and_overwrite_hdf5(self):
+        self._save_and_overwrite_test(extension='hdf5')
+
+    def _save_and_overwrite_test(self, extension):
+        self.result.save_to_file(overwrite=True, extension=extension)
+        self.result.save_to_file(overwrite=True, extension=extension)
+        self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
 
     def test_save_samples(self):
         self.result.save_posterior_samples()
-- 
GitLab