Skip to content
Snippets Groups Projects
Commit 09e3c9f3 authored by Sylvia Biscoveanu's avatar Sylvia Biscoveanu Committed by Gregory Ashton
Browse files

Add ability to load results produced with custom priors

parent b3c8e741
No related branches found
No related tags found
1 merge request!1010Add ability to load results produced with custom priors
......@@ -25,25 +25,29 @@ def check_directory_exists_and_if_not_mkdir(directory):
class BilbyJsonEncoder(json.JSONEncoder):
def default(self, obj):
from ..prior import MultivariateGaussianDist, Prior, PriorDict
from ...gw.prior import HealPixMapPriorDist
from ...bilby_mcmc.proposals import ProposalCycle
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, PriorDict):
return {'__prior_dict__': True, 'content': obj._get_json_dict()}
return {"__prior_dict__": True, "content": obj._get_json_dict()}
if isinstance(obj, (MultivariateGaussianDist, HealPixMapPriorDist, Prior)):
return {'__prior__': True, '__module__': obj.__module__,
'__name__': obj.__class__.__name__,
'kwargs': dict(obj.get_instantiation_dict())}
return {
"__prior__": True,
"__module__": obj.__module__,
"__name__": obj.__class__.__name__,
"kwargs": dict(obj.get_instantiation_dict()),
}
if isinstance(obj, ProposalCycle):
return str(obj)
try:
from astropy import cosmology as cosmo, units
if isinstance(obj, cosmo.FLRW):
return encode_astropy_cosmology(obj)
if isinstance(obj, units.Quantity):
......@@ -53,82 +57,104 @@ class BilbyJsonEncoder(json.JSONEncoder):
except ImportError:
logger.debug("Cannot import astropy, cannot write cosmological priors")
if isinstance(obj, np.ndarray):
return {'__array__': True, 'content': obj.tolist()}
return {"__array__": True, "content": obj.tolist()}
if isinstance(obj, complex):
return {'__complex__': True, 'real': obj.real, 'imag': obj.imag}
return {"__complex__": True, "real": obj.real, "imag": obj.imag}
if isinstance(obj, pd.DataFrame):
return {'__dataframe__': True, 'content': obj.to_dict(orient='list')}
return {"__dataframe__": True, "content": obj.to_dict(orient="list")}
if isinstance(obj, pd.Series):
return {'__series__': True, 'content': obj.to_dict()}
return {"__series__": True, "content": obj.to_dict()}
if inspect.isfunction(obj):
return {"__function__": True, "__module__": obj.__module__, "__name__": obj.__name__}
return {
"__function__": True,
"__module__": obj.__module__,
"__name__": obj.__name__,
}
if inspect.isclass(obj):
return {"__class__": True, "__module__": obj.__module__, "__name__": obj.__name__}
return {
"__class__": True,
"__module__": obj.__module__,
"__name__": obj.__name__,
}
return json.JSONEncoder.default(self, obj)
def encode_astropy_cosmology(obj):
cls_name = obj.__class__.__name__
dct = {key: getattr(obj, key) for
key in infer_args_from_method(obj.__init__)}
dct['__cosmology__'] = True
dct['__name__'] = cls_name
dct = {key: getattr(obj, key) for key in infer_args_from_method(obj.__init__)}
dct["__cosmology__"] = True
dct["__name__"] = cls_name
return dct
def encode_astropy_quantity(dct):
dct = dict(__astropy_quantity__=True, value=dct.value, unit=str(dct.unit))
if isinstance(dct['value'], np.ndarray):
dct['value'] = list(dct['value'])
if isinstance(dct["value"], np.ndarray):
dct["value"] = list(dct["value"])
return dct
def decode_astropy_cosmology(dct):
try:
from astropy import cosmology as cosmo
cosmo_cls = getattr(cosmo, dct['__name__'])
del dct['__cosmology__'], dct['__name__']
cosmo_cls = getattr(cosmo, dct["__name__"])
del dct["__cosmology__"], dct["__name__"]
return cosmo_cls(**dct)
except ImportError:
logger.debug("Cannot import astropy, cosmological priors may not be "
"properly loaded.")
logger.debug(
"Cannot import astropy, cosmological priors may not be " "properly loaded."
)
return dct
def decode_astropy_quantity(dct):
try:
from astropy import units
if dct['value'] is None:
if dct["value"] is None:
return None
else:
del dct['__astropy_quantity__']
del dct["__astropy_quantity__"]
return units.Quantity(**dct)
except ImportError:
logger.debug("Cannot import astropy, cosmological priors may not be "
"properly loaded.")
logger.debug(
"Cannot import astropy, cosmological priors may not be " "properly loaded."
)
return dct
def load_json(filename, gzip):
if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz':
if gzip or os.path.splitext(filename)[1].lstrip(".") == "gz":
import gzip
with gzip.GzipFile(filename, 'r') as file:
json_str = file.read().decode('utf-8')
with gzip.GzipFile(filename, "r") as file:
json_str = file.read().decode("utf-8")
dictionary = json.loads(json_str, object_hook=decode_bilby_json)
else:
with open(filename, 'r') as file:
with open(filename, "r") as file:
dictionary = json.load(file, object_hook=decode_bilby_json)
return dictionary
def decode_bilby_json(dct):
if dct.get("__prior_dict__", False):
cls = getattr(import_module(dct['__module__']), dct['__name__'])
cls = getattr(import_module(dct["__module__"]), dct["__name__"])
obj = cls._get_from_json_dict(dct)
return obj
if dct.get("__prior__", False):
cls = getattr(import_module(dct['__module__']), dct['__name__'])
obj = cls(**dct['kwargs'])
try:
cls = getattr(import_module(dct["__module__"]), dct["__name__"])
except AttributeError:
logger.debug(
"Unknown prior class for parameter {}, defaulting to base Prior object".format(
dct["kwargs"]["name"]
)
)
from ..prior import Prior
cls = Prior
obj = cls(**dct["kwargs"])
return obj
if dct.get("__cosmology__", False):
return decode_astropy_cosmology(dct)
......@@ -139,9 +165,9 @@ def decode_bilby_json(dct):
if dct.get("__complex__", False):
return complex(dct["real"], dct["imag"])
if dct.get("__dataframe__", False):
return pd.DataFrame(dct['content'])
return pd.DataFrame(dct["content"])
if dct.get("__series__", False):
return pd.Series(dct['content'])
return pd.Series(dct["content"])
if dct.get("__function__", False) or dct.get("__class__", False):
default = ".".join([dct["__module__"], dct["__name__"]])
return getattr(import_module(dct["__module__"]), dct["__name__"], default)
......@@ -225,6 +251,7 @@ def encode_for_hdf5(key, item):
Input item converted into HDF5 saveable format
"""
from ..prior.dict import PriorDict
if isinstance(item, np.int_):
item = int(item)
elif isinstance(item, np.float_):
......@@ -258,7 +285,9 @@ def encode_for_hdf5(key, item):
elif isinstance(item, pd.Series):
output = item.to_dict()
elif inspect.isfunction(item) or inspect.isclass(item):
output = dict(__module__=item.__module__, __name__=item.__name__, __class__=True)
output = dict(
__module__=item.__module__, __name__=item.__name__, __class__=True
)
elif isinstance(item, dict):
output = item.copy()
elif isinstance(item, tuple):
......@@ -287,12 +316,15 @@ def recursively_load_dict_contents_from_group(h5file, path):
The contents of the HDF5 file unpacked into the dictionary.
"""
import h5py
output = dict()
for key, item in h5file[path].items():
if isinstance(item, h5py.Dataset):
output[key] = decode_from_hdf5(item[()])
elif isinstance(item, h5py.Group):
output[key] = recursively_load_dict_contents_from_group(h5file, path + key + '/')
output[key] = recursively_load_dict_contents_from_group(
h5file, path + key + "/"
)
return output
......@@ -314,7 +346,7 @@ def recursively_save_dict_contents_to_group(h5file, path, dic):
for key, item in dic.items():
item = encode_for_hdf5(key, item)
if isinstance(item, dict):
recursively_save_dict_contents_to_group(h5file, path + key + '/', item)
recursively_save_dict_contents_to_group(h5file, path + key + "/", item)
else:
h5file[path + key] = item
......@@ -351,24 +383,23 @@ def move_old_file(filename, overwrite=False):
"""
if os.path.isfile(filename):
if overwrite:
logger.debug('Removing existing file {}'.format(filename))
logger.debug("Removing existing file {}".format(filename))
os.remove(filename)
else:
logger.debug(
'Renaming existing file {} to {}.old'.format(filename,
filename))
shutil.move(filename, filename + '.old')
"Renaming existing file {} to {}.old".format(filename, filename)
)
shutil.move(filename, filename + ".old")
logger.debug("Saving result to {}".format(filename))
def safe_save_figure(fig, filename, **kwargs):
check_directory_exists_and_if_not_mkdir(os.path.dirname(filename))
from matplotlib import rcParams
try:
fig.savefig(fname=filename, **kwargs)
except RuntimeError:
logger.debug(
"Failed to save plot with tex labels turning off tex."
)
logger.debug("Failed to save plot with tex labels turning off tex.")
rcParams["text.usetex"] = False
fig.savefig(fname=filename, **kwargs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment