Skip to content
Snippets Groups Projects
Commit da685c53 authored by Sylvia Biscoveanu's avatar Sylvia Biscoveanu Committed by Moritz Huebner
Browse files

Jsonify results

parent e7e24774
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,7 @@ import numpy as np
import deepdish
import pandas as pd
import corner
import json
import scipy.stats
import matplotlib
import matplotlib.pyplot as plt
......@@ -19,7 +20,7 @@ from .utils import (logger, infer_parameters_from_function,
from .prior import Prior, PriorDict, DeltaFunction
def result_file_name(outdir, label):
def result_file_name(outdir, label, extension='json'):
""" Returns the standard filename used for a result file
Parameters
......@@ -28,17 +29,27 @@ def result_file_name(outdir, label):
Name of the output directory
label: str
Naming scheme of the output file
extension: str, optional
Whether to save as `hdf5` or `json`
Returns
-------
str: File name of the output file
"""
return '{}/{}_result.h5'.format(outdir, label)
if extension == 'hdf5':
return '{}/{}_result.h5'.format(outdir, label)
else:
return '{}/{}_result.json'.format(outdir, label)
def read_in_result(filename=None, outdir=None, label=None):
""" Wrapper to bilby.core.result.Result.from_hdf5 """
return Result.from_hdf5(filename=filename, outdir=outdir, label=label)
""" Wrapper to bilby.core.result.Result.from_hdf5
or bilby.core.result.Result.from_json """
try:
result = Result.from_json(filename=filename, outdir=outdir, label=label)
except (IOError, ValueError):
result = Result.from_hdf5(filename=filename, outdir=outdir, label=label)
return result
class Result(object):
......@@ -155,7 +166,7 @@ class Result(object):
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
filename = result_file_name(outdir, label)
filename = result_file_name(outdir, label, extension='hdf5')
if os.path.isfile(filename):
dictionary = deepdish.io.load(filename)
# Some versions of deepdish/pytables return the dictionanary as
......@@ -169,6 +180,50 @@ class Result(object):
else:
raise IOError("No result '{}' found".format(filename))
@classmethod
def from_json(cls, filename=None, outdir=None, label=None):
""" Read in a saved .json data file
Parameters
----------
filename: str
If given, try to load from this filename
outdir, label: str
If given, use the default naming convention for saved results file
Returns
-------
result: bilby.core.result.Result
Raises
-------
ValueError: If no filename is given and either outdir or label is None
If no bilby.core.result.Result is found in the path
"""
if filename is None:
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
filename = result_file_name(outdir, label)
if os.path.isfile(filename):
dictionary = json.load(open(filename, 'r'))
for key in dictionary.keys():
# Convert some dictionaries back to DataFrames
if key in ['posterior', 'nested_samples']:
dictionary[key] = pd.DataFrame.from_dict(dictionary[key])
# Convert the loaded priors to bilby prior type
if key == 'priors':
for param in dictionary[key].keys():
dictionary[key][param] = str(dictionary[key][param])
dictionary[key] = PriorDict(dictionary[key])
try:
return cls(**dictionary)
except TypeError as e:
raise IOError("Unable to load dictionary, error={}".format(e))
else:
raise IOError("No result '{}' found".format(filename))
def __str__(self):
"""Print a summary """
if getattr(self, 'posterior', None) is not None:
......@@ -303,9 +358,9 @@ class Result(object):
pass
return dictionary
def save_to_file(self, overwrite=False, outdir=None):
def save_to_file(self, overwrite=False, outdir=None, extension='json'):
"""
Writes the Result to a deepdish h5 file
Writes the Result to a json or deepdish h5 file
Parameters
----------
......@@ -314,9 +369,11 @@ class Result(object):
default=False
outdir: str, optional
Path to the outdir. Default is the one stored in the result object.
extension: str, optional
Whether to save as hdf5 instead of json
"""
outdir = self._safe_outdir_creation(outdir, self.save_to_file)
file_name = result_file_name(outdir, self.label)
file_name = result_file_name(outdir, self.label, extension)
if os.path.isfile(file_name):
if overwrite:
......@@ -341,8 +398,19 @@ class Result(object):
if hasattr(dictionary['sampler_kwargs'][key], '__call__'):
dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs'])
# Convert to json saveable format
if extension != 'hdf5':
for key in dictionary.keys():
if isinstance(dictionary[key], pd.core.frame.DataFrame):
dictionary[key] = dictionary[key].to_dict()
elif isinstance(dictionary[key], np.ndarray):
dictionary[key] = dictionary[key].tolist()
try:
deepdish.io.save(file_name, dictionary)
if extension == 'hdf5':
deepdish.io.save(file_name, dictionary)
else:
json.dump(dictionary, open(file_name, 'w'), indent=2)
except Exception as e:
logger.error("\n\n Saving the data has failed with the "
"following message:\n {} \n\n".format(e))
......
......@@ -85,6 +85,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
overwritten.
save: bool
If true, save the priors and results to disk.
If hdf5, save as an hdf5 file instead of json.
result_class: bilby.core.result.Result, or child of
The result class to use. By default, `bilby.core.result.Result` is used,
but objects which inherit from this class can be given providing
......@@ -183,7 +184,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
result.samples_to_posterior(likelihood=likelihood, priors=priors,
conversion_function=conversion_function)
if save:
if save == 'hdf5':
result.save_to_file(extension='hdf5')
logger.info("Results saved to {}/".format(outdir))
elif save:
result.save_to_file()
logger.info("Results saved to {}/".format(outdir))
if plot:
......
......@@ -45,10 +45,16 @@ class TestResult(unittest.TestCase):
del self.result
pass
def test_result_file_name(self):
def test_result_file_name_default(self):
outdir = 'outdir'
label = 'label'
self.assertEqual(bilby.core.result.result_file_name(outdir, label),
'{}/{}_result.json'.format(outdir, label))
def test_result_file_name_hdf5(self):
outdir = 'outdir'
label = 'label'
self.assertEqual(bilby.core.result.result_file_name(outdir, label, extension='hdf5'),
'{}/{}_result.h5'.format(outdir, label))
def test_fail_save_and_load(self):
......@@ -104,8 +110,8 @@ class TestResult(unittest.TestCase):
with self.assertRaises(ValueError):
_ = self.result.posterior
def test_save_and_load(self):
self.result.save_to_file()
def test_save_and_load_hdf5(self):
self.result.save_to_file(extension='hdf5')
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label)
self.assertTrue(pd.DataFrame.equals
......@@ -123,23 +129,61 @@ class TestResult(unittest.TestCase):
self.assertEqual(self.result.priors['c'], loaded_result.priors['c'])
self.assertEqual(self.result.priors['d'], loaded_result.priors['d'])
def test_save_and_dont_overwrite(self):
def test_save_and_load_default(self):
self.result.save_to_file()
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label)
self.assertTrue(np.array_equal
(self.result.posterior.sort_values(by=['x']),
loaded_result.posterior.sort_values(by=['x'])))
self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys)
self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(self.result.injection_parameters, loaded_result.injection_parameters)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(self.result.log_noise_evidence, loaded_result.log_noise_evidence)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors['x'], loaded_result.priors['x'])
self.assertEqual(self.result.priors['y'], loaded_result.priors['y'])
self.assertEqual(self.result.priors['c'], loaded_result.priors['c'])
self.assertEqual(self.result.priors['d'], loaded_result.priors['d'])
def test_save_and_dont_overwrite_default(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=False)
self.result.save_to_file(overwrite=False)
self.assertTrue(os.path.isfile(
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label)))
def test_save_and_dont_overwrite_hdf5(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=False, extension='hdf5')
self.result.save_to_file(overwrite=False, extension='hdf5')
self.assertTrue(os.path.isfile(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
def test_save_and_overwrite(self):
def test_save_and_overwrite_hdf5(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=True, extension='hdf5')
self.result.save_to_file(overwrite=True, extension='hdf5')
self.assertFalse(os.path.isfile(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
def test_save_and_overwrite_default(self):
shutil.rmtree(
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=True)
self.result.save_to_file(overwrite=True)
self.assertFalse(os.path.isfile(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label)))
def test_save_samples(self):
self.result.save_posterior_samples()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment