Skip to content
Snippets Groups Projects
Commit 1f8a575b authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner
Browse files

Improvements to the JSON interface

parent 1db36052
No related branches found
No related tags found
No related merge requests found
......@@ -36,19 +36,47 @@ def result_file_name(outdir, label, extension='json'):
-------
str: File name of the output file
"""
if extension == 'hdf5':
return '{}/{}_result.h5'.format(outdir, label)
if extension in ['json', 'hdf5']:
return '{}/{}_result.{}'.format(outdir, label, extension)
else:
return '{}/{}_result.json'.format(outdir, label)
raise ValueError("Extension type {} not understood".format(extension))
def read_in_result(filename=None, outdir=None, label=None):
""" Wrapper to bilby.core.result.Result.from_hdf5
or bilby.core.result.Result.from_json """
try:
result = Result.from_json(filename=filename, outdir=outdir, label=label)
except (IOError, ValueError):
result = Result.from_hdf5(filename=filename, outdir=outdir, label=label)
def _determine_file_name(filename, outdir, label, extension):
""" Helper method to determine the filename """
if filename is not None:
return filename
else:
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
return result_file_name(outdir, label, extension)
def read_in_result(filename=None, outdir=None, label=None, extension='json'):
""" Reads in a stored bilby result object
Parameters
----------
filename: str
Path to the file to be read (alternative to giving the outdir and label)
outdir, label, extension: str
Name of the output directory, label and extension used for the default
naming scheme.
"""
filename = _determine_file_name(filename, outdir, label, extension)
# Get the actual extension (may differ from the default extension if the filename is given)
extension = os.path.splitext(filename)[1].lstrip('.')
if 'json' in extension:
result = Result.from_json(filename=filename)
elif ('hdf5' in extension) or ('h5' in extension):
result = Result.from_hdf5(filename=filename)
elif extension is None:
raise ValueError("No filetype extension provided")
else:
raise ValueError("Filetype {} not understood".format(extension))
return result
......@@ -162,15 +190,12 @@ class Result(object):
If no bilby.core.result.Result is found in the path
"""
if filename is None:
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
filename = result_file_name(outdir, label, extension='hdf5')
filename = _determine_file_name(filename, outdir, label, 'hdf5')
if os.path.isfile(filename):
dictionary = deepdish.io.load(filename)
# Some versions of deepdish/pytables return the dictionanary as
# a dictionary with a kay 'data'
# a dictionary with a key 'data'
if len(dictionary) == 1 and 'data' in dictionary:
dictionary = dictionary['data']
try:
......@@ -201,17 +226,12 @@ class Result(object):
If no bilby.core.result.Result is found in the path
"""
if filename is None:
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
filename = result_file_name(outdir, label)
filename = _determine_file_name(filename, outdir, label, 'json')
if os.path.isfile(filename):
dictionary = json.load(open(filename, 'r'))
with open(filename, 'r') as file:
dictionary = json.load(file, object_hook=decode_bilby_json_result)
for key in dictionary.keys():
# Convert some dictionaries back to DataFrames
if key in ['posterior', 'nested_samples']:
dictionary[key] = pd.DataFrame.from_dict(dictionary[key])
# Convert the loaded priors to bilby prior type
if key == 'priors':
for param in dictionary[key].keys():
......@@ -369,8 +389,8 @@ class Result(object):
default=False
outdir: str, optional
Path to the outdir. Default is the one stored in the result object.
extension: str, optional
Whether to save as hdf5 instead of json
extension: str, optional {json, hdf5}
Determines the method to use to store the data
"""
outdir = self._safe_outdir_creation(outdir, self.save_to_file)
file_name = result_file_name(outdir, self.label, extension)
......@@ -392,25 +412,20 @@ class Result(object):
if dictionary.get('priors', False):
dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors}
# Convert callable sampler_kwargs to strings to avoid pickling issues
# Convert callable sampler_kwargs to strings
if dictionary.get('sampler_kwargs', None) is not None:
for key in dictionary['sampler_kwargs']:
if hasattr(dictionary['sampler_kwargs'][key], '__call__'):
dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs'])
# Convert to json saveable format
if extension != 'hdf5':
for key in dictionary.keys():
if isinstance(dictionary[key], pd.core.frame.DataFrame):
dictionary[key] = dictionary[key].to_dict()
elif isinstance(dictionary[key], np.ndarray):
dictionary[key] = dictionary[key].tolist()
try:
if extension == 'hdf5':
if extension == 'json':
with open(file_name, 'w') as file:
json.dump(dictionary, file, indent=2, cls=BilbyResultJsonEncoder)
elif extension == 'hdf5':
deepdish.io.save(file_name, dictionary)
else:
json.dump(dictionary, open(file_name, 'w'), indent=2)
raise ValueError("Extension type {} not understood".format(extension))
except Exception as e:
logger.error("\n\n Saving the data has failed with the "
"following message:\n {} \n\n".format(e))
......@@ -1107,6 +1122,27 @@ class Result(object):
return outdir
class BilbyResultJsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return {'__array__': True, 'content': obj.tolist()}
if isinstance(obj, complex):
return {'__complex__': True, 'real': obj.real, 'imag': obj.imag}
if isinstance(obj, pd.core.frame.DataFrame):
return {'__dataframe__': True, 'content': obj.to_dict(orient='list')}
return json.JSONEncoder.default(self, obj)
def decode_bilby_json_result(dct):
if dct.get("__array__", False):
return np.asarray(dct["content"])
if dct.get("__complex__", False):
return complex(dct["real"], dct["imag"])
if dct.get("__dataframe__", False):
return pd.DataFrame(dct['content'])
return dct
def plot_multiple(results, filename=None, labels=None, colours=None,
save=True, evidences=False, **kwargs):
""" Generate a corner plot overlaying two sets of results
......
from __future__ import absolute_import, division
import bilby
import unittest
import numpy as np
import pandas as pd
import shutil
import os
import json
import bilby
class TestJson(unittest.TestCase):
def test_list_encoding(self):
data = dict(x=[1, 2, 3.4])
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
def test_array_encoding(self):
data = dict(x=np.array([1, 2, 3.4]))
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
def test_complex_encoding(self):
data = dict(x=1 + 3j)
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['x']), type(decoded['x']))
self.assertTrue(np.all(data['x']==decoded['x']))
def test_dataframe_encoding(self):
data = dict(data=pd.DataFrame(dict(x=[3, 4, 5], y=[5, 6, 7])))
encoded = json.dumps(data, cls=bilby.core.result.BilbyResultJsonEncoder)
decoded = json.loads(encoded, object_hook=bilby.core.result.decode_bilby_json_result)
self.assertEqual(data.keys(), decoded.keys())
self.assertEqual(type(data['data']), type(decoded['data']))
self.assertTrue(np.all(data['data']['x']==decoded['data']['x']))
self.assertTrue(np.all(data['data']['y']==decoded['data']['y']))
class TestResult(unittest.TestCase):
......@@ -55,14 +92,17 @@ class TestResult(unittest.TestCase):
outdir = 'outdir'
label = 'label'
self.assertEqual(bilby.core.result.result_file_name(outdir, label, extension='hdf5'),
'{}/{}_result.h5'.format(outdir, label))
'{}/{}_result.hdf5'.format(outdir, label))
def test_fail_save_and_load(self):
with self.assertRaises(ValueError):
bilby.core.result.read_in_result()
with self.assertRaises(ValueError):
bilby.core.result.read_in_result(filename='no_file_extension')
with self.assertRaises(IOError):
bilby.core.result.read_in_result(filename='not/a/file')
bilby.core.result.read_in_result(filename='not/a/file.json')
def test_unset_priors(self):
result = bilby.core.result.Result(
......@@ -113,7 +153,7 @@ class TestResult(unittest.TestCase):
def test_save_and_load_hdf5(self):
self.result.save_to_file(extension='hdf5')
loaded_result = bilby.core.result.read_in_result(
outdir=self.result.outdir, label=self.result.label)
outdir=self.result.outdir, label=self.result.label, extension='hdf5')
self.assertTrue(pd.DataFrame.equals
(self.result.posterior, loaded_result.posterior))
self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys)
......@@ -160,16 +200,25 @@ class TestResult(unittest.TestCase):
def test_save_and_dont_overwrite_hdf5(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
'{}/{}_result.hdf5.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=False, extension='hdf5')
self.result.save_to_file(overwrite=False, extension='hdf5')
self.assertTrue(os.path.isfile(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
'{}/{}_result.hdf5.old'.format(self.result.outdir, self.result.label)))
def test_save_and_overwrite_hdf5(self):
shutil.rmtree(
'{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
'{}/{}_result.hdf5.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=True, extension='hdf5')
self.result.save_to_file(overwrite=True, extension='hdf5')
self.assertFalse(os.path.isfile(
'{}/{}_result.hdf5.old'.format(self.result.outdir, self.result.label)))
def test_save_and_overwrite_default(self):
shutil.rmtree(
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
ignore_errors=True)
self.result.save_to_file(overwrite=True, extension='hdf5')
self.result.save_to_file(overwrite=True, extension='hdf5')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment