Skip to content
Snippets Groups Projects
Commit 18707791 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Simplified tests and added pickle

parent 8cc36985
No related branches found
No related tags found
1 merge request!932Fix pickle
Pipeline #221132 failed
......@@ -108,6 +108,22 @@ class TestResult(unittest.TestCase):
"{}/{}_result.hdf5".format(outdir, label),
)
def test_result_file_name_pkl(self):
outdir = "outdir"
label = "label"
self.assertEqual(
bilby.core.result.result_file_name(outdir, label, extension="pkl"),
"{}/{}_result.pkl".format(outdir, label),
)
def test_result_file_name_pickle(self):
outdir = "outdir"
label = "label"
self.assertEqual(
bilby.core.result.result_file_name(outdir, label, extension="pickle"),
"{}/{}_result.pkl".format(outdir, label),
)
def test_fail_save_and_load(self):
with self.assertRaises(ValueError):
bilby.core.result.read_in_result()
......@@ -176,163 +192,56 @@ class TestResult(unittest.TestCase):
with self.assertRaises(ValueError):
_ = self.result.posterior
def test_save_and_load_hdf5(self):
self.result.save_to_file(extension="hdf5")
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label, extension="hdf5"
)
self.assertTrue(
pd.DataFrame.equals(self.result.posterior, loaded_result.posterior)
)
self.assertTrue(
self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
)
self.assertTrue(
self.result.search_parameter_keys == loaded_result.search_parameter_keys
)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(
self.result.injection_parameters, loaded_result.injection_parameters
)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(
self.result.log_noise_evidence, loaded_result.log_noise_evidence
)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
def test_save_and_load_default(self):
self.result.save_to_file()
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label
)
self.assertTrue(
np.array_equal(
self.result.posterior.sort_values(by=["x"]),
loaded_result.posterior.sort_values(by=["x"]),
def test_save_and_load(self):
gzips = [True, False, False, False]
extensions = ['json', 'json', 'pkl', 'hdf5']
for gzip, extension in zip(gzips, extensions):
self.result.save_to_file(extension=extension, gzip=gzip)
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label, extension=extension, gzip=gzip
)
)
self.assertTrue(
self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
)
self.assertTrue(
self.result.search_parameter_keys == loaded_result.search_parameter_keys
)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(
self.result.injection_parameters, loaded_result.injection_parameters
)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(
self.result.log_noise_evidence, loaded_result.log_noise_evidence
)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
def test_save_and_load_gzip(self):
self.result.save_to_file(gzip=True)
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label, gzip=True
)
self.assertTrue(
np.array_equal(
self.result.posterior.sort_values(by=["x"]),
loaded_result.posterior.sort_values(by=["x"]),
self.assertTrue(
np.array_equal(
self.result.posterior.sort_values(by=["x"]),
loaded_result.posterior.sort_values(by=["x"]),
)
)
)
self.assertTrue(
self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
)
self.assertTrue(
self.result.search_parameter_keys == loaded_result.search_parameter_keys
)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(
self.result.injection_parameters, loaded_result.injection_parameters
)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(
self.result.log_noise_evidence, loaded_result.log_noise_evidence
)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
def test_save_and_dont_overwrite_default(self):
shutil.rmtree(
"{}/{}_result.json.old".format(self.result.outdir, self.result.label),
ignore_errors=True,
)
self.result.save_to_file(overwrite=False)
self.result.save_to_file(overwrite=False)
self.assertTrue(
os.path.isfile(
"{}/{}_result.json.old".format(self.result.outdir, self.result.label)
self.assertTrue(
self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys
)
)
def test_save_and_dont_overwrite_hdf5(self):
shutil.rmtree(
"{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label),
ignore_errors=True,
)
self.result.save_to_file(overwrite=False, extension="hdf5")
self.result.save_to_file(overwrite=False, extension="hdf5")
self.assertTrue(
os.path.isfile(
"{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label)
self.assertTrue(
self.result.search_parameter_keys == loaded_result.search_parameter_keys
)
)
def test_save_and_overwrite_hdf5(self):
shutil.rmtree(
"{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label),
ignore_errors=True,
)
self.result.save_to_file(overwrite=True, extension="hdf5")
self.result.save_to_file(overwrite=True, extension="hdf5")
self.assertFalse(
os.path.isfile(
"{}/{}_result.hdf5.old".format(self.result.outdir, self.result.label)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(
self.result.injection_parameters, loaded_result.injection_parameters
)
)
def test_save_and_overwrite_default(self):
shutil.rmtree(
"{}/{}_result.json.old".format(self.result.outdir, self.result.label),
ignore_errors=True,
)
self.result.save_to_file(overwrite=True, extension="hdf5")
self.result.save_to_file(overwrite=True, extension="hdf5")
self.assertFalse(
os.path.isfile(
"{}/{}_result.h5.old".format(self.result.outdir, self.result.label)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(
self.result.log_noise_evidence, loaded_result.log_noise_evidence
)
)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors["x"], loaded_result.priors["x"])
self.assertEqual(self.result.priors["y"], loaded_result.priors["y"])
self.assertEqual(self.result.priors["c"], loaded_result.priors["c"])
self.assertEqual(self.result.priors["d"], loaded_result.priors["d"])
def test_save_and_dont_overwrite(self):
extensions = ['json', 'pkl', 'hdf5']
for extension in extensions:
shutil.rmtree(f"{self.result.outdir}/{self.result.label}_result.{extension}.old", ignore_errors=True)
self.result.save_to_file(overwrite=False)
self.result.save_to_file(overwrite=False)
self.assertTrue(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
def test_save_and_overwrite_default_2(self):
shutil.rmtree(
"{}/{}_result.json.old".format(self.result.outdir, self.result.label),
ignore_errors=True,
)
self.result.save_to_file(overwrite=True)
self.result.save_to_file(overwrite=True)
self.assertFalse(
os.path.isfile(
"{}/{}_result.json.old".format(self.result.outdir, self.result.label)
)
)
def test_save_and_overwrite_default(self):
extensions = ['json', 'pkl', 'hdf5']
for extension in extensions:
shutil.rmtree(f"{self.result.outdir}/{self.result.label}_result.{extension}.old", ignore_errors=True)
self.result.save_to_file(overwrite=True, extension=extension)
self.result.save_to_file(overwrite=True, extension=extension)
self.assertFalse(os.path.isfile(f"{self.result.outdir}/{self.result.label}_result.{extension}.old"))
def test_save_samples(self):
self.result.save_posterior_samples()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment