Commit 1f8a575b authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner

Improvements to the JSON interface

parent 1db36052
......@@ -36,19 +36,47 @@ def result_file_name(outdir, label, extension='json'):
-------
str: File name of the output file
"""
if extension == 'hdf5':
return '{}/{}_result.h5'.format(outdir, label)
if extension in ['json', 'hdf5']:
return '{}/{}_result.{}'.format(outdir, label, extension)
else:
return '{}/{}_result.json'.format(outdir, label)
raise ValueError("Extension type {} not understood".format(extension))
def read_in_result(filename=None, outdir=None, label=None):
""" Wrapper to bilby.core.result.Result.from_hdf5
or bilby.core.result.Result.from_json """
try:
result = Result.from_json(filename=filename, outdir=outdir, label=label)
except (IOError, ValueError):
result = Result.from_hdf5(filename=filename, outdir=outdir, label=label)
def _determine_file_name(filename, outdir, label, extension):
""" Helper method to determine the filename """
if filename is not None:
return filename
else:
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
return result_file_name(outdir, label, extension)
def read_in_result(filename=None, outdir=None, label=None, extension='json'):
""" Reads in a stored bilby result object
Parameters
----------
filename: str
Path to the file to be read (alternative to giving the outdir and label)
outdir, label, extension: str
Name of the output directory, label and extension used for the default
naming scheme.
"""
filename = _determine_file_name(filename, outdir, label, extension)
# Get the actual extension (may differ from the default extension if the filename is given)
extension = os.path.splitext(filename)[1].lstrip('.')
if 'json' in extension:
result = Result.from_json(filename=filename)
elif ('hdf5' in extension) or ('h5' in extension):
result = Result.from_hdf5(filename=filename)
elif extension is None:
raise ValueError("No filetype extension provided")
else:
raise ValueError("Filetype {} not understood".format(extension))
return result
......@@ -162,15 +190,12 @@ class Result(object):
If no bilby.core.result.Result is found in the path
"""
if filename is None:
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
filename = result_file_name(outdir, label, extension='hdf5')
filename = _determine_file_name(filename, outdir, label, 'hdf5')
if os.path.isfile(filename):
dictionary = deepdish.io.load(filename)
# Some versions of deepdish/pytables return the dictionanary as
# a dictionary with a kay 'data'
# a dictionary with a key 'data'
if len(dictionary) == 1 and 'data' in dictionary:
dictionary = dictionary['data']
try:
......@@ -201,17 +226,12 @@ class Result(object):
If no bilby.core.result.Result is found in the path
"""
if filename is None:
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
filename = result_file_name(outdir, label)
filename = _determine_file_name(filename, outdir, label, 'json')
if os.path.isfile(filename):
dictionary = json.load(open(filename, 'r'))
with open(filename, 'r') as file:
dictionary = json.load(file, object_hook=decode_bilby_json_result)
for key in dictionary.keys():
# Convert some dictionaries back to DataFrames
if key in ['posterior', 'nested_samples']:
dictionary[key] = pd.DataFrame.from_dict(dictionary[key])
# Convert the loaded priors to bilby prior type
if key == 'priors':
for param in dictionary[key].keys():
......@@ -369,8 +389,8 @@ class Result(object):
default=False
outdir: str, optional
Path to the outdir. Default is the one stored in the result object.
extension: str, optional
Whether to save as hdf5 instead of json
extension: str, optional {json, hdf5}
Determines the method to use to store the data
"""
outdir = self._safe_outdir_creation(outdir, self.save_to_file)
file_name = result_file_name(outdir, self.label, extension)
......@@ -392,25 +412,20 @@ class Result(object):
if dictionary.get('priors', False):
dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors}
# Convert callable sampler_kwargs to strings to avoid pickling issues
# Convert callable sampler_kwargs to strings
if dictionary.get('sampler_kwargs', None) is not None:
for key in dictionary['sampler_kwargs']:
if hasattr(dictionary['sampler_kwargs'][key], '__call__'):
dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs'])
# Convert to json saveable format
if extension != 'hdf5':
for key in dictionary.keys():
if isinstance(dictionary[key], pd.core.frame.DataFrame):
dictionary[key] = dictionary[key].to_dict()
elif isinstance(dictionary[key], np.ndarray):
dictionary[key] = dictionary[key].tolist()
try:
if extension == 'hdf5':
if extension == 'json':
with open(file_name, 'w') as file:
json.dump(dictionary, file, indent=2, cls=BilbyResultJsonEncoder)
elif extension == 'hdf5':
deepdish.io.save(file_name, dictionary)
else:
json.dump(dictionary, open(file_name, 'w'), indent=2)
raise ValueError("Extension type {} not understood".format(extension))
except Exception as e:
logger.error("\n\n Saving the data has failed with the "
"following message:\n {} \n\n".format(e))
......@@ -1107,6 +1122,27 @@ class Result(object):
return outdir
class BilbyResultJsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return {'__array__': True, 'content': obj.tolist()}
if isinstance(obj, complex):
return {'__complex__': True, 'real': obj.real, 'imag': obj.imag}
if isinstance(obj, pd.core.frame.DataFrame):
return {'__dataframe__': True, 'content': obj.to_dict(orient='list')}
return json.JSONEncoder.default(self, obj)
def decode_bilby_json_result(dct):
if dct.get("__array__", False):
return np.asarray(dct["content"])
if dct.get("__complex__", False):
return complex(dct["real"], dct["imag"])
if dct.get("__dataframe__", False):
return pd.DataFrame(dct['content'])
return dct
def plot_multiple(results, filename=None, labels=None, colours=None,
save=True, evidences=False, **kwargs):
""" Generate a corner plot overlaying two sets of results
......
from __future__ import absolute_import, division
import bilby
import unittest
import numpy as np
import pandas as pd
import shutil
import os
import json
import bilby
class TestJson(unittest.TestCase):
def test_list_encoding(self):
data = dict(x=[1, 2, 3.4])
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
def test_array_encoding(self):
data = dict(x=np.array([1, 2, 3.4]))
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
def test_complex_encoding(self):
data = dict(x=1 + 3j)
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
def test_dataframe_encoding(self):
data = dict(data=pd.DataFrame(dict(x=[3, 4, 5], y=[5, 6, 7])))
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['data']), type(decoded['data']))
self.assertTrue(np.all(data['data']['x']==decoded['data']['x']))
self.assertTrue(np.all(data['data']['y']==decoded['data']['y']))
class TestResult(unittest.TestCase):
......@@ -55,14 +92,17 @@ class TestResult(unittest.TestCase):
outdir = 'outdir'
label = 'label'
self.assertEqual(bilby.core.result.result_file_name(outdir, label, extension='hdf5'),
'{}/{}_result.h5'.format(outdir, label))
'{}/{}_result.hdf5'.format(outdir, label))
def test_fail_save_and_load(self):
with self.assertRaises(ValueError):
bilby.core.result.read_in_result()
with self.assertRaises(ValueError):
bilby.core.result.read_in_result(filename='no_file_extension')
with self.assertRaises(IOError):
bilby.core.result.read_in_result(filename='not/a/file')
bilby.core.result.read_in_result(filename='not/a/file.json')
def test_unset_priors(self):
result = bilby.core.result.Result(
......@@ -113,7 +153,7 @@ class TestResult(unittest.TestCase):
def test_save_and_load_hdf5(self):
self.result.save_to_file(extension='hdf5')
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label)
outdir=self.result.outdir, label=self.result.label, extension='hdf5')
self.assertTrue(pd.DataFrame.equals
(self.result.posterior, loaded_result.posterior))
self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys)
......@@ -160,16 +200,25 @@ class TestResult(unittest.TestCase):
def test_save_and_dont_overwrite_hdf5(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
'{}/{}_result.hdf5.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=False, extension='hdf5')
self.result.save_to_file(overwrite=False, extension='hdf5')
self.assertTrue(os.path.isfile(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
'{}/{}_result.hdf5.old'.format(self.result.outdir, self.result.label)))
def test_save_and_overwrite_hdf5(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
'{}/{}_result.hdf5.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=True, extension='hdf5')
self.result.save_to_file(overwrite=True, extension='hdf5')
self.assertFalse(os.path.isfile(
'{}/{}_result.hdf5.old'.format(self.result.outdir, self.result.label)))
def test_save_and_overwrite_default(self):
shutil.rmtree(
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=True, extension='hdf5')
self.result.save_to_file(overwrite=True, extension='hdf5')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment