From c2c3776464e3ea62ae7927ce9d3aaa4b24b5778b Mon Sep 17 00:00:00 2001
From: Moritz Huebner <moritz.huebner@ligo.org>
Date: Thu, 29 Apr 2021 04:47:14 +0000
Subject: [PATCH] Fix pickle

---
 bilby/core/result.py           |  14 +--
 bilby/core/sampler/__init__.py |   3 +-
 test/core/result_test.py       | 170 ++++++++++-----------------------
 3 files changed, 59 insertions(+), 128 deletions(-)

diff --git a/bilby/core/result.py b/bilby/core/result.py
index b50c89806..f81558c2a 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -35,7 +35,7 @@ def result_file_name(outdir, label, extension='json', gzip=False):
     label: str
         Naming scheme of the output file
     extension: str, optional
-        Whether to save as `hdf5` or `json`
+        Whether to save as `hdf5`, `json`, or `pickle`
     gzip: bool, optional
         Set to True to append `.gz` to the extension for saving in gzipped format
 
@@ -43,7 +43,9 @@ def result_file_name(outdir, label, extension='json', gzip=False):
     =======
     str: File name of the output file
     """
-    if extension in ['json', 'hdf5']:
+    if extension == 'pickle':
+        extension = 'pkl'
+    if extension in ['json', 'hdf5', 'pkl']:
         if extension == 'json' and gzip:
             return os.path.join(outdir, '{}_result.{}.gz'.format(label, extension))
         else:
@@ -324,7 +326,7 @@ class Result(object):
                  num_likelihood_evaluations=None, walkers=None,
                  max_autocorrelation_time=None, use_ratio=None,
                  parameter_labels=None, parameter_labels_with_unit=None,
-                 gzip=False, version=None):
+                 version=None):
         """ A class to store the results of the sampling run
 
         Parameters
@@ -370,8 +372,6 @@ class Result(object):
             likelihood was used during sampling
         parameter_labels, parameter_labels_with_unit: list
             Lists of the latex-formatted parameter labels
-        gzip: bool
-            Set to True to gzip the results file (if using json format)
         version: str,
             Version information for software used to generate the result. Note,
             this information is generated when the result object is initialized
@@ -737,11 +737,11 @@ class Result(object):
             default=False
         outdir: str, optional
             Path to the outdir. Default is the one stored in the result object.
-        extension: str, optional {json, hdf5, True}
+        extension: str, optional {json, hdf5, pkl, pickle, True}
             Determines the method to use to store the data (if True defaults
             to json)
         gzip: bool, optional
-            If true, and outputing to a json file, this will gzip the resulting
+            If true, and outputting to a json file, this will gzip the resulting
             file and add '.gz' to the file extension.
         """
 
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index 21184ba60..93202387f 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -92,9 +92,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         saving. For example, if `meta_data={dtype: 'signal'}`. Warning: in case
         of conflict with keys saved by bilby, the meta_data keys will be
         overwritten.
-    save: bool
+    save: bool, str
         If true, save the priors and results to disk.
         If hdf5, save as an hdf5 file instead of json.
+        If pickle or pkl, save as an pickle file instead of json.
     gzip: bool
         If true, and save is true, gzip the saved results file.
     result_class: bilby.core.result.Result, or child of
diff --git a/test/core/result_test.py b/test/core/result_test.py
index 3f84f703f..b2a7c24c1 100644
--- a/test/core/result_test.py
+++ b/test/core/result_test.py
@@ -108,6 +108,22 @@ class TestResult(unittest.TestCase):
             "{}/{}_result.hdf5".format(outdir, label),
         )
 
+    def test_result_file_name_pkl(self):
+        outdir = "outdir"
+        label = "label"
+        self.assertEqual(
+            bilby.core.result.result_file_name(outdir, label, extension="pkl"),
+            "{}/{}_result.pkl".format(outdir, label),
+        )
+
+    def test_result_file_name_pickle(self):
+        outdir = "outdir"
+        label = "label"
+        self.assertEqual(
+            bilby.core.result.result_file_name(outdir, label, extension="pickle"),
+            "{}/{}_result.pkl".format(outdir, label),
+        )
+
     def test_fail_save_and_load(self):
         with self.assertRaises(ValueError):
             bilby.core.result.read_in_result()
@@ -176,39 +192,22 @@ class TestResult(unittest.TestCase):
         with self.assertRaises(ValueError):
             _ = self.result.posterior
 
+    def test_save_and_load_json(self):
+        self._save_and_load_test(extension='json')
+
+    def test_save_and_load_json_gzip(self):
+        self._save_and_load_test(extension='json', gzip=True)
+
+    def test_save_and_load_pkl(self):
+        self._save_and_load_test(extension='pkl')
+
     def test_save_and_load_hdf5(self):
-        self.result.save_to_file(extension="hdf5")
-        loaded_result = bilby.core.result.read_in_result(
-            outdir=self.result.outdir, label=self.result.label, extension="hdf5"
-        )
-        self.assertTrue(
-            pd.DataFrame.equals(self.result.posterior, loaded_result.posterior)
-        )
-        self.assertTrue(
-            self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
-        )
-        self.assertTrue(
-            self.result.search_parameter_keys == loaded_result.search_parameter_keys
-        )
-        self.assertEqual(self.result.meta_data, loaded_result.meta_data)
-        self.assertEqual(
-            self.result.injection_parameters, loaded_result.injection_parameters
-        )
-        self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
-        self.assertEqual(
-            self.result.log_noise_evidence, loaded_result.log_noise_evidence
-        )
-        self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
-        self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
-        self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
-        self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
-        self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
-        self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
+        self._save_and_load_test(extension='hdf5')
 
-    def test_save_and_load_default(self):
-        self.result.save_to_file()
+    def _save_and_load_test(self, extension, gzip=False):
+        self.result.save_to_file(extension=extension, gzip=gzip)
         loaded_result = bilby.core.result.read_in_result(
-            outdir=self.result.outdir, label=self.result.label
+            outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
         )
         self.assertTrue(
             np.array_equal(
@@ -237,102 +236,33 @@ class TestResult(unittest.TestCase):
         self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
         self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
 
-    def test_save_and_load_gzip(self):
-        self.result.save_to_file(gzip=True)
-        loaded_result = bilby.core.result.read_in_result(
-            outdir=self.result.outdir, label=self.result.label, gzip=True
-        )
-        self.assertTrue(
-            np.array_equal(
-                self.result.posterior.sort_values(by=["x"]),
-                loaded_result.posterior.sort_values(by=["x"]),
-            )
-        )
-        self.assertTrue(
-            self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
-        )
-        self.assertTrue(
-            self.result.search_parameter_keys == loaded_result.search_parameter_keys
-        )
-        self.assertEqual(self.result.meta_data, loaded_result.meta_data)
-        self.assertEqual(
-            self.result.injection_parameters, loaded_result.injection_parameters
-        )
-        self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
-        self.assertEqual(
-            self.result.log_noise_evidence, loaded_result.log_noise_evidence
-        )
-        self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
-        self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
-        self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
-        self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
-        self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
-        self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
+    def test_save_and_dont_overwrite_json(self):
+        self._save_and_dont_overwrite_test(extension='json')
 
-    def test_save_and_dont_overwrite_default(self):
-        shutil.rmtree(
-            "{}/{}_result.json.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=False)
-        self.result.save_to_file(overwrite=False)
-        self.assertTrue(
-            os.path.isfile(
-                "{}/{}_result.json.old".format(self.result.outdir, self.result.label)
-            )
-        )
+    def test_save_and_dont_overwrite_pkl(self):
+        self._save_and_dont_overwrite_test(extension='pkl')
 
     def test_save_and_dont_overwrite_hdf5(self):
-        shutil.rmtree(
-            "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=False, extension="hdf5")
-        self.result.save_to_file(overwrite=False, extension="hdf5")
-        self.assertTrue(
-            os.path.isfile(
-                "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label)
-            )
-        )
+        self._save_and_dont_overwrite_test(extension='hdf5')
 
-    def test_save_and_overwrite_hdf5(self):
-        shutil.rmtree(
-            "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=True, extension="hdf5")
-        self.result.save_to_file(overwrite=True, extension="hdf5")
-        self.assertFalse(
-            os.path.isfile(
-                "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label)
-            )
-        )
+    def _save_and_dont_overwrite_test(self, extension):
+        self.result.save_to_file(overwrite=False, extension=extension)
+        self.result.save_to_file(overwrite=False, extension=extension)
+        self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
 
-    def test_save_and_overwrite_default(self):
-        shutil.rmtree(
-            "{}/{}_result.json.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=True, extension="hdf5")
-        self.result.save_to_file(overwrite=True, extension="hdf5")
-        self.assertFalse(
-            os.path.isfile(
-                "{}/{}_result.h5.old".format(self.result.outdir, self.result.label)
-            )
-        )
+    def test_save_and_overwrite_json(self):
+        self._save_and_overwrite_test(extension='json')
 
-    def test_save_and_overwrite_default_2(self):
-        shutil.rmtree(
-            "{}/{}_result.json.old".format(self.result.outdir, self.result.label),
-            ignore_errors=True,
-        )
-        self.result.save_to_file(overwrite=True)
-        self.result.save_to_file(overwrite=True)
-        self.assertFalse(
-            os.path.isfile(
-                "{}/{}_result.json.old".format(self.result.outdir, self.result.label)
-            )
-        )
+    def test_save_and_overwrite_pkl(self):
+        self._save_and_overwrite_test(extension='pkl')
+
+    def test_save_and_overwrite_hdf5(self):
+        self._save_and_overwrite_test(extension='hdf5')
+
+    def _save_and_overwrite_test(self, extension):
+        self.result.save_to_file(overwrite=True, extension=extension)
+        self.result.save_to_file(overwrite=True, extension=extension)
+        self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
 
     def test_save_samples(self):
         self.result.save_posterior_samples()
-- 
GitLab