Skip to content
Snippets Groups Projects
Commit edcf3efd authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'update-hdf5' into 'master'

HDF5: Fix loading meta data, booleans, string lists

See merge request lscsoft/bilby!941
parents 4e3e6892 992469d6
No related branches found
No related tags found
1 merge request!941HDF5: Fix loading meta data, booleans, string lists
Pipeline #218630 passed with warnings
......@@ -20,6 +20,7 @@ from .utils import (
decode_bilby_json, docstring,
recursively_save_dict_contents_to_group,
recursively_load_dict_contents_from_group,
recursively_decode_bilby_json,
)
from .prior import Prior, PriorDict, DeltaFunction
......@@ -560,6 +561,17 @@ class Result(object):
else:
return ''
@property
def meta_data(self):
return self._meta_data
@meta_data.setter
def meta_data(self, meta_data):
if meta_data is None:
meta_data = dict()
meta_data = recursively_decode_bilby_json(meta_data)
self._meta_data = meta_data
@property
def priors(self):
if self._priors is not None:
......
......@@ -1108,6 +1108,28 @@ def decode_astropy_quantity(dct):
return dct
def recursively_decode_bilby_json(dct):
"""
Recursively call `bilby_decode_json`
Parameters
----------
dct: dict
The dictionary to decode
Returns
-------
dct: dict
The original dictionary with all the elements decode if possible
"""
dct = decode_bilby_json(dct)
if isinstance(dct, dict):
for key in dct:
if isinstance(dct[key], dict):
dct[key] = recursively_decode_bilby_json(dct[key])
return dct
def reflect(u):
"""
Iteratively reflect a number until it is contained in [0, 1].
......@@ -1307,10 +1329,14 @@ def decode_from_hdf5(item):
elif isinstance(item, (bytes, bytearray)):
output = item.decode()
elif isinstance(item, np.ndarray):
if "|S" in str(item.dtype) or isinstance(item[0], bytes):
if item.size == 0:
output = item
elif "|S" in str(item.dtype) or isinstance(item[0], bytes):
output = [it.decode() for it in item]
else:
output = item
elif isinstance(item, np.bool_):
output = bool(item)
else:
output = item
return output
......@@ -1364,9 +1390,11 @@ def encode_for_hdf5(item):
elif isinstance(item, pd.Series):
output = item.to_dict()
elif inspect.isfunction(item) or inspect.isclass(item):
output = dict(__module__=item.__module__, __name__=item.__name__)
output = dict(__module__=item.__module__, __name__=item.__name__, __class__=True)
elif isinstance(item, dict):
output = item.copy()
elif isinstance(item, tuple):
output = {str(ii): elem for ii, elem in enumerate(item)}
else:
raise ValueError(f'Cannot save {type(item)} type')
return output
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment