Skip to content
Snippets Groups Projects

Fix pickle

Merged Moritz Huebner requested to merge fix_pickle into master
All threads resolved!
1 file
+ 70
48
Compare changes
  • Side-by-side
  • Inline
+ 70
48
@@ -192,55 +192,77 @@ class TestResult(unittest.TestCase):
with self.assertRaises(ValueError):
_ = self.result.posterior
def test_save_and_load(self):
gzips = [True, False, False, False]
extensions = ['json', 'json', 'pkl', 'hdf5']
for gzip, extension in zip(gzips, extensions):
self.result.save_to_file(extension=extension, gzip=gzip)
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
)
self.assertTrue(
np.array_equal(
self.result.posterior.sort_values(by=["x"]),
loaded_result.posterior.sort_values(by=["x"]),
)
)
self.assertTrue(
self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
)
self.assertTrue(
self.result.search_parameter_keys == loaded_result.search_parameter_keys
)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(
self.result.injection_parameters, loaded_result.injection_parameters
)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(
self.result.log_noise_evidence, loaded_result.log_noise_evidence
def test_save_and_load_json(self):
self._save_and_load_test(extension='json')
def test_save_and_load_json_gzip(self):
self._save_and_load_test(extension='json', gzip=True)
def test_save_and_load_pkl(self):
self._save_and_load_test(extension='pkl')
def test_save_and_load_hdf5(self):
self._save_and_load_test(extension='hdf5')
def _save_and_load_test(self, extension, gzip=False):
self.result.save_to_file(extension=extension, gzip=gzip)
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
)
self.assertTrue(
np.array_equal(
self.result.posterior.sort_values(by=["x"]),
loaded_result.posterior.sort_values(by=["x"]),
)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
def test_save_and_dont_overwrite(self):
extensions = ['json', 'pkl', 'hdf5']
for extension in extensions:
print(extension)
self.result.save_to_file(overwrite=False, extension=extension)
self.result.save_to_file(overwrite=False, extension=extension)
self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
def test_save_and_overwrite_default(self):
extensions = ['json', 'pkl', 'hdf5']
for extension in extensions:
self.result.save_to_file(overwrite=True, extension=extension)
self.result.save_to_file(overwrite=True, extension=extension)
self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
)
self.assertTrue(
self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
)
self.assertTrue(
self.result.search_parameter_keys == loaded_result.search_parameter_keys
)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(
self.result.injection_parameters, loaded_result.injection_parameters
)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(
self.result.log_noise_evidence, loaded_result.log_noise_evidence
)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
def test_save_and_dont_overwrite_json(self):
self._save_and_dont_overwrite_test(extension='json')
def test_save_and_dont_overwrite_pkl(self):
self._save_and_dont_overwrite_test(extension='pkl')
def test_save_and_dont_overwrite_hdf5(self):
self._save_and_dont_overwrite_test(extension='hdf5')
def _save_and_dont_overwrite_test(self, extension):
self.result.save_to_file(overwrite=False, extension=extension)
self.result.save_to_file(overwrite=False, extension=extension)
self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
def test_save_and_overwrite_json(self):
self._save_and_overwrite_test(extension='json')
def test_save_and_overwrite_pkl(self):
self._save_and_overwrite_test(extension='pkl')
def test_save_and_overwrite_hdf5(self):
self._save_and_overwrite_test(extension='hdf5')
def _save_and_overwrite_test(self, extension):
self.result.save_to_file(overwrite=True, extension=extension)
self.result.save_to_file(overwrite=True, extension=extension)
self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
def test_save_samples(self):
self.result.save_posterior_samples()
Loading