Commit da685c53 authored by Sylvia Biscoveanu's avatar Sylvia Biscoveanu Committed by Moritz Huebner

Jsonify results

parent e7e24774
......@@ -8,6 +8,7 @@ import numpy as np
import deepdish
import pandas as pd
import corner
import json
import scipy.stats
import matplotlib
import matplotlib.pyplot as plt
......@@ -19,7 +20,7 @@ from .utils import (logger, infer_parameters_from_function,
from .prior import Prior, PriorDict, DeltaFunction
def result_file_name(outdir, label):
def result_file_name(outdir, label, extension='json'):
""" Returns the standard filename used for a result file
Parameters
......@@ -28,17 +29,27 @@ def result_file_name(outdir, label):
Name of the output directory
label: str
Naming scheme of the output file
extension: str, optional
Whether to save as `hdf5` or `json`
Returns
-------
str: File name of the output file
"""
return '{}/{}_result.h5'.format(outdir, label)
if extension == 'hdf5':
return '{}/{}_result.h5'.format(outdir, label)
else:
return '{}/{}_result.json'.format(outdir, label)
def read_in_result(filename=None, outdir=None, label=None):
""" Wrapper to bilby.core.result.Result.from_hdf5 """
return Result.from_hdf5(filename=filename, outdir=outdir, label=label)
""" Wrapper to bilby.core.result.Result.from_hdf5
or bilby.core.result.Result.from_json """
try:
result = Result.from_json(filename=filename, outdir=outdir, label=label)
except (IOError, ValueError):
result = Result.from_hdf5(filename=filename, outdir=outdir, label=label)
return result
class Result(object):
......@@ -155,7 +166,7 @@ class Result(object):
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
filename = result_file_name(outdir, label)
filename = result_file_name(outdir, label, extension='hdf5')
if os.path.isfile(filename):
dictionary = deepdish.io.load(filename)
# Some versions of deepdish/pytables return the dictionanary as
......@@ -169,6 +180,50 @@ class Result(object):
else:
raise IOError("No result '{}' found".format(filename))
@classmethod
def from_json(cls, filename=None, outdir=None, label=None):
""" Read in a saved .json data file
Parameters
----------
filename: str
If given, try to load from this filename
outdir, label: str
If given, use the default naming convention for saved results file
Returns
-------
result: bilby.core.result.Result
Raises
-------
ValueError: If no filename is given and either outdir or label is None
If no bilby.core.result.Result is found in the path
"""
if filename is None:
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
filename = result_file_name(outdir, label)
if os.path.isfile(filename):
dictionary = json.load(open(filename, 'r'))
for key in dictionary.keys():
# Convert some dictionaries back to DataFrames
if key in ['posterior', 'nested_samples']:
dictionary[key] = pd.DataFrame.from_dict(dictionary[key])
# Convert the loaded priors to bilby prior type
if key == 'priors':
for param in dictionary[key].keys():
dictionary[key][param] = str(dictionary[key][param])
dictionary[key] = PriorDict(dictionary[key])
try:
return cls(**dictionary)
except TypeError as e:
raise IOError("Unable to load dictionary, error={}".format(e))
else:
raise IOError("No result '{}' found".format(filename))
def __str__(self):
"""Print a summary """
if getattr(self, 'posterior', None) is not None:
......@@ -303,9 +358,9 @@ class Result(object):
pass
return dictionary
def save_to_file(self, overwrite=False, outdir=None):
def save_to_file(self, overwrite=False, outdir=None, extension='json'):
"""
Writes the Result to a deepdish h5 file
Writes the Result to a json or deepdish h5 file
Parameters
----------
......@@ -314,9 +369,11 @@ class Result(object):
default=False
outdir: str, optional
Path to the outdir. Default is the one stored in the result object.
extension: str, optional
Whether to save as hdf5 instead of json
"""
outdir = self._safe_outdir_creation(outdir, self.save_to_file)
file_name = result_file_name(outdir, self.label)
file_name = result_file_name(outdir, self.label, extension)
if os.path.isfile(file_name):
if overwrite:
......@@ -341,8 +398,19 @@ class Result(object):
if hasattr(dictionary['sampler_kwargs'][key], '__call__'):
dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs'])
# Convert to json saveable format
if extension != 'hdf5':
for key in dictionary.keys():
if isinstance(dictionary[key], pd.core.frame.DataFrame):
dictionary[key] = dictionary[key].to_dict()
elif isinstance(dictionary[key], np.ndarray):
dictionary[key] = dictionary[key].tolist()
try:
deepdish.io.save(file_name, dictionary)
if extension == 'hdf5':
deepdish.io.save(file_name, dictionary)
else:
json.dump(dictionary, open(file_name, 'w'), indent=2)
except Exception as e:
logger.error("\n\n Saving the data has failed with the "
"following message:\n {} \n\n".format(e))
......
......@@ -85,6 +85,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
overwritten.
save: bool
If true, save the priors and results to disk.
If hdf5, save as an hdf5 file instead of json.
result_class: bilby.core.result.Result, or child of
The result class to use. By default, `bilby.core.result.Result` is used,
but objects which inherit from this class can be given providing
......@@ -183,7 +184,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
result.samples_to_posterior(likelihood=likelihood, priors=priors,
conversion_function=conversion_function)
if save:
if save == 'hdf5':
result.save_to_file(extension='hdf5')
logger.info("Results saved to {}/".format(outdir))
elif save:
result.save_to_file()
logger.info("Results saved to {}/".format(outdir))
if plot:
......
......@@ -45,10 +45,16 @@ class TestResult(unittest.TestCase):
del self.result
pass
def test_result_file_name(self):
def test_result_file_name_default(self):
outdir = 'outdir'
label = 'label'
self.assertEqual(bilby.core.result.result_file_name(outdir, label),
'{}/{}_result.json'.format(outdir, label))
def test_result_file_name_hdf5(self):
outdir = 'outdir'
label = 'label'
self.assertEqual(bilby.core.result.result_file_name(outdir, label, extension='hdf5'),
'{}/{}_result.h5'.format(outdir, label))
def test_fail_save_and_load(self):
......@@ -104,8 +110,8 @@ class TestResult(unittest.TestCase):
with self.assertRaises(ValueError):
_ = self.result.posterior
def test_save_and_load(self):
self.result.save_to_file()
def test_save_and_load_hdf5(self):
self.result.save_to_file(extension='hdf5')
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label)
self.assertTrue(pd.DataFrame.equals
......@@ -123,23 +129,61 @@ class TestResult(unittest.TestCase):
self.assertEqual(self.result.priors['c'], loaded_result.priors['c'])
self.assertEqual(self.result.priors['d'], loaded_result.priors['d'])
def test_save_and_dont_overwrite(self):
def test_save_and_load_default(self):
self.result.save_to_file()
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label)
self.assertTrue(np.array_equal
(self.result.posterior.sort_values(by=['x']),
loaded_result.posterior.sort_values(by=['x'])))
self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys)
self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys)
self.assertEqual(self.result.meta_data, loaded_result.meta_data)
self.assertEqual(self.result.injection_parameters, loaded_result.injection_parameters)
self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
self.assertEqual(self.result.log_noise_evidence, loaded_result.log_noise_evidence)
self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
self.assertEqual(self.result.priors['x'], loaded_result.priors['x'])
self.assertEqual(self.result.priors['y'], loaded_result.priors['y'])
self.assertEqual(self.result.priors['c'], loaded_result.priors['c'])
self.assertEqual(self.result.priors['d'], loaded_result.priors['d'])
def test_save_and_dont_overwrite_default(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=False)
self.result.save_to_file(overwrite=False)
self.assertTrue(os.path.isfile(
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label)))
def test_save_and_dont_overwrite_hdf5(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=False, extension='hdf5')
self.result.save_to_file(overwrite=False, extension='hdf5')
self.assertTrue(os.path.isfile(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
def test_save_and_overwrite(self):
def test_save_and_overwrite_hdf5(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=True, extension='hdf5')
self.result.save_to_file(overwrite=True, extension='hdf5')
self.assertFalse(os.path.isfile(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
def test_save_and_overwrite_default(self):
shutil.rmtree(
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=True)
self.result.save_to_file(overwrite=True)
self.assertFalse(os.path.isfile(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label)))
def test_save_samples(self):
self.result.save_posterior_samples()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment