From c2c3776464e3ea62ae7927ce9d3aaa4b24b5778b Mon Sep 17 00:00:00 2001 From: Moritz Huebner <moritz.huebner@ligo.org> Date: Thu, 29 Apr 2021 04:47:14 +0000 Subject: [PATCH] Fix pickle --- bilby/core/result.py | 14 +-- bilby/core/sampler/__init__.py | 3 +- test/core/result_test.py | 170 ++++++++++----------------------- 3 files changed, 59 insertions(+), 128 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index b50c89806..f81558c2a 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -35,7 +35,7 @@ def result_file_name(outdir, label, extension='json', gzip=False): label: str Naming scheme of the output file extension: str, optional - Whether to save as `hdf5` or `json` + Whether to save as `hdf5`, `json`, or `pickle` gzip: bool, optional Set to True to append `.gz` to the extension for saving in gzipped format @@ -43,7 +43,9 @@ def result_file_name(outdir, label, extension='json', gzip=False): ======= str: File name of the output file """ - if extension in ['json', 'hdf5']: + if extension == 'pickle': + extension = 'pkl' + if extension in ['json', 'hdf5', 'pkl']: if extension == 'json' and gzip: return os.path.join(outdir, '{}_result.{}.gz'.format(label, extension)) else: @@ -324,7 +326,7 @@ class Result(object): num_likelihood_evaluations=None, walkers=None, max_autocorrelation_time=None, use_ratio=None, parameter_labels=None, parameter_labels_with_unit=None, - gzip=False, version=None): + version=None): """ A class to store the results of the sampling run Parameters @@ -370,8 +372,6 @@ class Result(object): likelihood was used during sampling parameter_labels, parameter_labels_with_unit: list Lists of the latex-formatted parameter labels - gzip: bool - Set to True to gzip the results file (if using json format) version: str, Version information for software used to generate the result. 
Note, this information is generated when the result object is initialized @@ -737,11 +737,11 @@ class Result(object): default=False outdir: str, optional Path to the outdir. Default is the one stored in the result object. - extension: str, optional {json, hdf5, True} + extension: str, optional {json, hdf5, pkl, pickle, True} Determines the method to use to store the data (if True defaults to json) gzip: bool, optional - If true, and outputing to a json file, this will gzip the resulting + If true, and outputting to a json file, this will gzip the resulting file and add '.gz' to the file extension. """ diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 21184ba60..93202387f 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -92,9 +92,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', saving. For example, if `meta_data={dtype: 'signal'}`. Warning: in case of conflict with keys saved by bilby, the meta_data keys will be overwritten. - save: bool + save: bool, str If true, save the priors and results to disk. If hdf5, save as an hdf5 file instead of json. + If pickle or pkl, save as a pickle file instead of json. gzip: bool If true, and save is true, gzip the saved results file. 
result_class: bilby.core.result.Result, or child of diff --git a/test/core/result_test.py b/test/core/result_test.py index 3f84f703f..b2a7c24c1 100644 --- a/test/core/result_test.py +++ b/test/core/result_test.py @@ -108,6 +108,22 @@ class TestResult(unittest.TestCase): "{}/{}_result.hdf5".format(outdir, label), ) + def test_result_file_name_pkl(self): + outdir = "outdir" + label = "label" + self.assertEqual( + bilby.core.result.result_file_name(outdir, label, extension="pkl"), + "{}/{}_result.pkl".format(outdir, label), + ) + + def test_result_file_name_pickle(self): + outdir = "outdir" + label = "label" + self.assertEqual( + bilby.core.result.result_file_name(outdir, label, extension="pickle"), + "{}/{}_result.pkl".format(outdir, label), + ) + def test_fail_save_and_load(self): with self.assertRaises(ValueError): bilby.core.result.read_in_result() @@ -176,39 +192,22 @@ class TestResult(unittest.TestCase): with self.assertRaises(ValueError): _ = self.result.posterior + def test_save_and_load_json(self): + self._save_and_load_test(extension='json') + + def test_save_and_load_json_gzip(self): + self._save_and_load_test(extension='json', gzip=True) + + def test_save_and_load_pkl(self): + self._save_and_load_test(extension='pkl') + def test_save_and_load_hdf5(self): - self.result.save_to_file(extension="hdf5") - loaded_result = bilby.core.result.read_in_result( - outdir=self.result.outdir, label=self.result.label, extension="hdf5" - ) - self.assertTrue( - pd.DataFrame.equals(self.result.posterior, loaded_result.posterior) - ) - self.assertTrue( - self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys - ) - self.assertTrue( - self.result.search_parameter_keys == loaded_result.search_parameter_keys - ) - self.assertEqual(self.result.meta_data, loaded_result.meta_data) - self.assertEqual( - self.result.injection_parameters, loaded_result.injection_parameters - ) - self.assertEqual(self.result.log_evidence, loaded_result.log_evidence) - self.assertEqual( - 
self.result.log_noise_evidence, loaded_result.log_noise_evidence - ) - self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err) - self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor) - self.assertEqual(self.result.priors["x"], loaded_result.priors["x"]) - self.assertEqual(self.result.priors["y"], loaded_result.priors["y"]) - self.assertEqual(self.result.priors["c"], loaded_result.priors["c"]) - self.assertEqual(self.result.priors["d"], loaded_result.priors["d"]) + self._save_and_load_test(extension='hdf5') - def test_save_and_load_default(self): - self.result.save_to_file() + def _save_and_load_test(self, extension, gzip=False): + self.result.save_to_file(extension=extension, gzip=gzip) loaded_result = bilby.core.result.read_in_result( - outdir=self.result.outdir, label=self.result.label + outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip ) self.assertTrue( np.array_equal( @@ -237,102 +236,33 @@ class TestResult(unittest.TestCase): self.assertEqual(self.result.priors["c"], loaded_result.priors["c"]) self.assertEqual(self.result.priors["d"], loaded_result.priors["d"]) - def test_save_and_load_gzip(self): - self.result.save_to_file(gzip=True) - loaded_result = bilby.core.result.read_in_result( - outdir=self.result.outdir, label=self.result.label, gzip=True - ) - self.assertTrue( - np.array_equal( - self.result.posterior.sort_values(by=["x"]), - loaded_result.posterior.sort_values(by=["x"]), - ) - ) - self.assertTrue( - self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys - ) - self.assertTrue( - self.result.search_parameter_keys == loaded_result.search_parameter_keys - ) - self.assertEqual(self.result.meta_data, loaded_result.meta_data) - self.assertEqual( - self.result.injection_parameters, loaded_result.injection_parameters - ) - self.assertEqual(self.result.log_evidence, loaded_result.log_evidence) - self.assertEqual( - self.result.log_noise_evidence, 
loaded_result.log_noise_evidence - ) - self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err) - self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor) - self.assertEqual(self.result.priors["x"], loaded_result.priors["x"]) - self.assertEqual(self.result.priors["y"], loaded_result.priors["y"]) - self.assertEqual(self.result.priors["c"], loaded_result.priors["c"]) - self.assertEqual(self.result.priors["d"], loaded_result.priors["d"]) + def test_save_and_dont_overwrite_json(self): + self._save_and_dont_overwrite_test(extension='json') - def test_save_and_dont_overwrite_default(self): - shutil.rmtree( - "{}/{}_result.json.old".format(self.result.outdir, self.result.label), - ignore_errors=True, - ) - self.result.save_to_file(overwrite=False) - self.result.save_to_file(overwrite=False) - self.assertTrue( - os.path.isfile( - "{}/{}_result.json.old".format(self.result.outdir, self.result.label) - ) - ) + def test_save_and_dont_overwrite_pkl(self): + self._save_and_dont_overwrite_test(extension='pkl') def test_save_and_dont_overwrite_hdf5(self): - shutil.rmtree( - "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label), - ignore_errors=True, - ) - self.result.save_to_file(overwrite=False, extension="hdf5") - self.result.save_to_file(overwrite=False, extension="hdf5") - self.assertTrue( - os.path.isfile( - "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label) - ) - ) + self._save_and_dont_overwrite_test(extension='hdf5') - def test_save_and_overwrite_hdf5(self): - shutil.rmtree( - "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label), - ignore_errors=True, - ) - self.result.save_to_file(overwrite=True, extension="hdf5") - self.result.save_to_file(overwrite=True, extension="hdf5") - self.assertFalse( - os.path.isfile( - "{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label) - ) - ) + def _save_and_dont_overwrite_test(self, extension): + 
self.result.save_to_file(overwrite=False, extension=extension) + self.result.save_to_file(overwrite=False, extension=extension) + self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old")) - def test_save_and_overwrite_default(self): - shutil.rmtree( - "{}/{}_result.json.old".format(self.result.outdir, self.result.label), - ignore_errors=True, - ) - self.result.save_to_file(overwrite=True, extension="hdf5") - self.result.save_to_file(overwrite=True, extension="hdf5") - self.assertFalse( - os.path.isfile( - "{}/{}_result.h5.old".format(self.result.outdir, self.result.label) - ) - ) + def test_save_and_overwrite_json(self): + self._save_and_overwrite_test(extension='json') - def test_save_and_overwrite_default_2(self): - shutil.rmtree( - "{}/{}_result.json.old".format(self.result.outdir, self.result.label), - ignore_errors=True, - ) - self.result.save_to_file(overwrite=True) - self.result.save_to_file(overwrite=True) - self.assertFalse( - os.path.isfile( - "{}/{}_result.json.old".format(self.result.outdir, self.result.label) - ) - ) + def test_save_and_overwrite_pkl(self): + self._save_and_overwrite_test(extension='pkl') + + def test_save_and_overwrite_hdf5(self): + self._save_and_overwrite_test(extension='hdf5') + + def _save_and_overwrite_test(self, extension): + self.result.save_to_file(overwrite=True, extension=extension) + self.result.save_to_file(overwrite=True, extension=extension) + self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old")) def test_save_samples(self): self.result.save_posterior_samples() -- GitLab