From 18707791e5704bc010e105bb4a013c85d3b798b0 Mon Sep 17 00:00:00 2001
From: Moritz Huebner <email@moritz-huebner.de>
Date: Mon, 26 Apr 2021 14:03:00 +1000
Subject: [PATCH] Simplified tests and added pickle

---
 test/core/result_test.py | 209 +++++++++++----------------------------
 1 file changed, 59 insertions(+), 150 deletions(-)

diff --git a/test/core/result_test.py b/test/core/result_test.py
index 3f84f703f..ad46c8b81 100644
--- a/test/core/result_test.py
+++ b/test/core/result_test.py
@@ -108,6 +108,22 @@ class TestResult(unittest.TestCase):
             "{}/{}_result.hdf5".format(outdir, label),
         )
 
+    def test_result_file_name_pkl(self):
+        outdir = "outdir"
+        label = "label"
+        self.assertEqual(
+            bilby.core.result.result_file_name(outdir, label, extension="pkl"),
+            "{}/{}_result.pkl".format(outdir, label),
+        )
+
+    def test_result_file_name_pickle(self):
+        outdir = "outdir"
+        label = "label"
+        self.assertEqual(
+            bilby.core.result.result_file_name(outdir, label, extension="pickle"),
+            "{}/{}_result.pkl".format(outdir, label),
+        )
+
     def test_fail_save_and_load(self):
         with self.assertRaises(ValueError):
             bilby.core.result.read_in_result()
@@ -176,163 +192,56 @@ class TestResult(unittest.TestCase):
         with self.assertRaises(ValueError):
             _ = self.result.posterior
 
-    def test_save_and_load_hdf5(self):
-        self.result.save_to_file(extension="hdf5")
-        loaded_result = bilby.core.result.read_in_result(
-            outdir=self.result.outdir, label=self.result.label, extension="hdf5"
-        )
-        self.assertTrue(
-            pd.DataFrame.equals(self.result.posterior, loaded_result.posterior)
-        )
-        self.assertTrue(
-            self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
-        )
-        self.assertTrue(
-            self.result.search_parameter_keys == loaded_result.search_parameter_keys
-        )
-        self.assertEqual(self.result.meta_data, loaded_result.meta_data)
-        self.assertEqual(
-            self.result.injection_parameters, loaded_result.injection_parameters
-        )
-        self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
-        self.assertEqual(
-            self.result.log_noise_evidence, loaded_result.log_noise_evidence
-        )
-        self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
-        self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
-        self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
-        self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
-        self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
-        self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
-
-    def test_save_and_load_default(self):
-        self.result.save_to_file()
-        loaded_result = bilby.core.result.read_in_result(
-            outdir=self.result.outdir, label=self.result.label
-        )
-        self.assertTrue(
-            np.array_equal(
-                self.result.posterior.sort_values(by=["x"]),
-                loaded_result.posterior.sort_values(by=["x"]),
+    def test_save_and_load(self):
+        gzips = [True, False, False, False]
+        extensions = ['json', 'json', 'pkl', 'hdf5']
+        for gzip, extension in zip(gzips, extensions):
+            self.result.save_to_file(extension=extension, gzip=gzip)
+            loaded_result = bilby.core.result.read_in_result(
+                outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
             )
-        )
-        self.assertTrue(
-            self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
-        )
-        self.assertTrue(
-            self.result.search_parameter_keys == loaded_result.search_parameter_keys
-        )
-        self.assertEqual(self.result.meta_data, loaded_result.meta_data)
-        self.assertEqual(
-            self.result.injection_parameters, loaded_result.injection_parameters
-        )
-        self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
-        self.assertEqual(
-            self.result.log_noise_evidence, loaded_result.log_noise_evidence
-        )
-        self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
-        self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
-        self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
-        self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
-        self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
-        self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
-
-    def test_save_and_load_gzip(self):
-        self.result.save_to_file(gzip=True)
-        loaded_result = bilby.core.result.read_in_result(
-            outdir=self.result.outdir, label=self.result.label, gzip=True
-        )
-        self.assertTrue(
-            np.array_equal(
-                self.result.posterior.sort_values(by=["x"]),
-                loaded_result.posterior.sort_values(by=["x"]),
+            self.assertTrue(
+                np.array_equal(
+                    self.result.posterior.sort_values(by=["x"]),
+                    loaded_result.posterior.sort_values(by=["x"]),
+                )
             )
-        )
-        self.assertTrue(
-            self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
-        )
-        self.assertTrue(
-            self.result.search_parameter_keys == loaded_result.search_parameter_keys
-        )
-        self.assertEqual(self.result.meta_data, loaded_result.meta_data)
-        self.assertEqual(
-            self.result.injection_parameters, loaded_result.injection_parameters
-        )
-        self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
-        self.assertEqual(
-            self.result.log_noise_evidence, loaded_result.log_noise_evidence
-        )
-        self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
-        self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
-        self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
-        self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
-        self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
-        self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
-
-    def test_save_and_dont_overwrite_default(self):
-        shutil.rmtree(
-            "{}/{}_result.json.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=False)
-        self.result.save_to_file(overwrite=False)
-        self.assertTrue(
-            os.path.isfile(
-                "{}/{}_result.json.old".format(self.result.outdir, self.result.label)
+            self.assertTrue(
+                self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
             )
-        )
-
-    def test_save_and_dont_overwrite_hdf5(self):
-        shutil.rmtree(
-            "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=False, extension="hdf5")
-        self.result.save_to_file(overwrite=False, extension="hdf5")
-        self.assertTrue(
-            os.path.isfile(
-                "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label)
+            self.assertTrue(
+                self.result.search_parameter_keys == loaded_result.search_parameter_keys
             )
-        )
-
-    def test_save_and_overwrite_hdf5(self):
-        shutil.rmtree(
-            "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=True, extension="hdf5")
-        self.result.save_to_file(overwrite=True, extension="hdf5")
-        self.assertFalse(
-            os.path.isfile(
-                "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label)
+            self.assertEqual(self.result.meta_data, loaded_result.meta_data)
+            self.assertEqual(
+                self.result.injection_parameters, loaded_result.injection_parameters
             )
-        )
-
-    def test_save_and_overwrite_default(self):
-        shutil.rmtree(
-            "{}/{}_result.json.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=True, extension="hdf5")
-        self.result.save_to_file(overwrite=True, extension="hdf5")
-        self.assertFalse(
-            os.path.isfile(
-                "{}/{}_result.h5.old".format(self.result.outdir, self.result.label)
+            self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
+            self.assertEqual(
+                self.result.log_noise_evidence, loaded_result.log_noise_evidence
             )
-        )
+            self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
+            self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
+            self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
+            self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
+            self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
+            self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
+
+    def test_save_and_dont_overwrite(self):
+        extensions = ['json', 'pkl', 'hdf5']
+        for extension in extensions:
+            shutil.rmtree(f"{self.result.outdir}/{self.result.label}_result.{extension}.old", ignore_errors=True)
+            self.result.save_to_file(overwrite=False)
+            self.result.save_to_file(overwrite=False)
+            self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
 
-    def test_save_and_overwrite_default_2(self):
-        shutil.rmtree(
-            "{}/{}_result.json.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=True)
-        self.result.save_to_file(overwrite=True)
-        self.assertFalse(
-            os.path.isfile(
-                "{}/{}_result.json.old".format(self.result.outdir, self.result.label)
-            )
-        )
+    def test_save_and_overwrite_default(self):
+        extensions = ['json', 'pkl', 'hdf5']
+        for extension in extensions:
+            shutil.rmtree(f"{self.result.outdir}/{self.result.label}_result.{extension}.old", ignore_errors=True)
+            self.result.save_to_file(overwrite=True, extension=extension)
+            self.result.save_to_file(overwrite=True, extension=extension)
+            self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
 
     def test_save_samples(self):
         self.result.save_posterior_samples()
-- 
GitLab