Skip to content
Snippets Groups Projects
Commit 729c1804 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'read_custom_prior' into 'master'

Add ability to load results produced with custom priors

See merge request lscsoft/bilby!1010
parents 8a735040 09e3c9f3
No related branches found
No related tags found
1 merge request!1010Add ability to load results produced with custom priors
Pipeline #282305 canceled
...@@ -25,25 +25,29 @@ def check_directory_exists_and_if_not_mkdir(directory): ...@@ -25,25 +25,29 @@ def check_directory_exists_and_if_not_mkdir(directory):
class BilbyJsonEncoder(json.JSONEncoder): class BilbyJsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
from ..prior import MultivariateGaussianDist, Prior, PriorDict from ..prior import MultivariateGaussianDist, Prior, PriorDict
from ...gw.prior import HealPixMapPriorDist from ...gw.prior import HealPixMapPriorDist
from ...bilby_mcmc.proposals import ProposalCycle from ...bilby_mcmc.proposals import ProposalCycle
if isinstance(obj, np.integer): if isinstance(obj, np.integer):
return int(obj) return int(obj)
if isinstance(obj, np.floating): if isinstance(obj, np.floating):
return float(obj) return float(obj)
if isinstance(obj, PriorDict): if isinstance(obj, PriorDict):
return {'__prior_dict__': True, 'content': obj._get_json_dict()} return {"__prior_dict__": True, "content": obj._get_json_dict()}
if isinstance(obj, (MultivariateGaussianDist, HealPixMapPriorDist, Prior)): if isinstance(obj, (MultivariateGaussianDist, HealPixMapPriorDist, Prior)):
return {'__prior__': True, '__module__': obj.__module__, return {
'__name__': obj.__class__.__name__, "__prior__": True,
'kwargs': dict(obj.get_instantiation_dict())} "__module__": obj.__module__,
"__name__": obj.__class__.__name__,
"kwargs": dict(obj.get_instantiation_dict()),
}
if isinstance(obj, ProposalCycle): if isinstance(obj, ProposalCycle):
return str(obj) return str(obj)
try: try:
from astropy import cosmology as cosmo, units from astropy import cosmology as cosmo, units
if isinstance(obj, cosmo.FLRW): if isinstance(obj, cosmo.FLRW):
return encode_astropy_cosmology(obj) return encode_astropy_cosmology(obj)
if isinstance(obj, units.Quantity): if isinstance(obj, units.Quantity):
...@@ -53,82 +57,104 @@ class BilbyJsonEncoder(json.JSONEncoder): ...@@ -53,82 +57,104 @@ class BilbyJsonEncoder(json.JSONEncoder):
except ImportError: except ImportError:
logger.debug("Cannot import astropy, cannot write cosmological priors") logger.debug("Cannot import astropy, cannot write cosmological priors")
if isinstance(obj, np.ndarray): if isinstance(obj, np.ndarray):
return {'__array__': True, 'content': obj.tolist()} return {"__array__": True, "content": obj.tolist()}
if isinstance(obj, complex): if isinstance(obj, complex):
return {'__complex__': True, 'real': obj.real, 'imag': obj.imag} return {"__complex__": True, "real": obj.real, "imag": obj.imag}
if isinstance(obj, pd.DataFrame): if isinstance(obj, pd.DataFrame):
return {'__dataframe__': True, 'content': obj.to_dict(orient='list')} return {"__dataframe__": True, "content": obj.to_dict(orient="list")}
if isinstance(obj, pd.Series): if isinstance(obj, pd.Series):
return {'__series__': True, 'content': obj.to_dict()} return {"__series__": True, "content": obj.to_dict()}
if inspect.isfunction(obj): if inspect.isfunction(obj):
return {"__function__": True, "__module__": obj.__module__, "__name__": obj.__name__} return {
"__function__": True,
"__module__": obj.__module__,
"__name__": obj.__name__,
}
if inspect.isclass(obj): if inspect.isclass(obj):
return {"__class__": True, "__module__": obj.__module__, "__name__": obj.__name__} return {
"__class__": True,
"__module__": obj.__module__,
"__name__": obj.__name__,
}
return json.JSONEncoder.default(self, obj) return json.JSONEncoder.default(self, obj)
def encode_astropy_cosmology(obj): def encode_astropy_cosmology(obj):
cls_name = obj.__class__.__name__ cls_name = obj.__class__.__name__
dct = {key: getattr(obj, key) for dct = {key: getattr(obj, key) for key in infer_args_from_method(obj.__init__)}
key in infer_args_from_method(obj.__init__)} dct["__cosmology__"] = True
dct['__cosmology__'] = True dct["__name__"] = cls_name
dct['__name__'] = cls_name
return dct return dct
def encode_astropy_quantity(dct): def encode_astropy_quantity(dct):
dct = dict(__astropy_quantity__=True, value=dct.value, unit=str(dct.unit)) dct = dict(__astropy_quantity__=True, value=dct.value, unit=str(dct.unit))
if isinstance(dct['value'], np.ndarray): if isinstance(dct["value"], np.ndarray):
dct['value'] = list(dct['value']) dct["value"] = list(dct["value"])
return dct return dct
def decode_astropy_cosmology(dct): def decode_astropy_cosmology(dct):
try: try:
from astropy import cosmology as cosmo from astropy import cosmology as cosmo
cosmo_cls = getattr(cosmo, dct['__name__'])
del dct['__cosmology__'], dct['__name__'] cosmo_cls = getattr(cosmo, dct["__name__"])
del dct["__cosmology__"], dct["__name__"]
return cosmo_cls(**dct) return cosmo_cls(**dct)
except ImportError: except ImportError:
logger.debug("Cannot import astropy, cosmological priors may not be " logger.debug(
"properly loaded.") "Cannot import astropy, cosmological priors may not be " "properly loaded."
)
return dct return dct
def decode_astropy_quantity(dct): def decode_astropy_quantity(dct):
try: try:
from astropy import units from astropy import units
if dct['value'] is None:
if dct["value"] is None:
return None return None
else: else:
del dct['__astropy_quantity__'] del dct["__astropy_quantity__"]
return units.Quantity(**dct) return units.Quantity(**dct)
except ImportError: except ImportError:
logger.debug("Cannot import astropy, cosmological priors may not be " logger.debug(
"properly loaded.") "Cannot import astropy, cosmological priors may not be " "properly loaded."
)
return dct return dct
def load_json(filename, gzip): def load_json(filename, gzip):
if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz': if gzip or os.path.splitext(filename)[1].lstrip(".") == "gz":
import gzip import gzip
with gzip.GzipFile(filename, 'r') as file:
json_str = file.read().decode('utf-8') with gzip.GzipFile(filename, "r") as file:
json_str = file.read().decode("utf-8")
dictionary = json.loads(json_str, object_hook=decode_bilby_json) dictionary = json.loads(json_str, object_hook=decode_bilby_json)
else: else:
with open(filename, 'r') as file: with open(filename, "r") as file:
dictionary = json.load(file, object_hook=decode_bilby_json) dictionary = json.load(file, object_hook=decode_bilby_json)
return dictionary return dictionary
def decode_bilby_json(dct): def decode_bilby_json(dct):
if dct.get("__prior_dict__", False): if dct.get("__prior_dict__", False):
cls = getattr(import_module(dct['__module__']), dct['__name__']) cls = getattr(import_module(dct["__module__"]), dct["__name__"])
obj = cls._get_from_json_dict(dct) obj = cls._get_from_json_dict(dct)
return obj return obj
if dct.get("__prior__", False): if dct.get("__prior__", False):
cls = getattr(import_module(dct['__module__']), dct['__name__']) try:
obj = cls(**dct['kwargs']) cls = getattr(import_module(dct["__module__"]), dct["__name__"])
except AttributeError:
logger.debug(
"Unknown prior class for parameter {}, defaulting to base Prior object".format(
dct["kwargs"]["name"]
)
)
from ..prior import Prior
cls = Prior
obj = cls(**dct["kwargs"])
return obj return obj
if dct.get("__cosmology__", False): if dct.get("__cosmology__", False):
return decode_astropy_cosmology(dct) return decode_astropy_cosmology(dct)
...@@ -139,9 +165,9 @@ def decode_bilby_json(dct): ...@@ -139,9 +165,9 @@ def decode_bilby_json(dct):
if dct.get("__complex__", False): if dct.get("__complex__", False):
return complex(dct["real"], dct["imag"]) return complex(dct["real"], dct["imag"])
if dct.get("__dataframe__", False): if dct.get("__dataframe__", False):
return pd.DataFrame(dct['content']) return pd.DataFrame(dct["content"])
if dct.get("__series__", False): if dct.get("__series__", False):
return pd.Series(dct['content']) return pd.Series(dct["content"])
if dct.get("__function__", False) or dct.get("__class__", False): if dct.get("__function__", False) or dct.get("__class__", False):
default = ".".join([dct["__module__"], dct["__name__"]]) default = ".".join([dct["__module__"], dct["__name__"]])
return getattr(import_module(dct["__module__"]), dct["__name__"], default) return getattr(import_module(dct["__module__"]), dct["__name__"], default)
...@@ -225,6 +251,7 @@ def encode_for_hdf5(key, item): ...@@ -225,6 +251,7 @@ def encode_for_hdf5(key, item):
Input item converted into HDF5 saveable format Input item converted into HDF5 saveable format
""" """
from ..prior.dict import PriorDict from ..prior.dict import PriorDict
if isinstance(item, np.int_): if isinstance(item, np.int_):
item = int(item) item = int(item)
elif isinstance(item, np.float_): elif isinstance(item, np.float_):
...@@ -258,7 +285,9 @@ def encode_for_hdf5(key, item): ...@@ -258,7 +285,9 @@ def encode_for_hdf5(key, item):
elif isinstance(item, pd.Series): elif isinstance(item, pd.Series):
output = item.to_dict() output = item.to_dict()
elif inspect.isfunction(item) or inspect.isclass(item): elif inspect.isfunction(item) or inspect.isclass(item):
output = dict(__module__=item.__module__, __name__=item.__name__, __class__=True) output = dict(
__module__=item.__module__, __name__=item.__name__, __class__=True
)
elif isinstance(item, dict): elif isinstance(item, dict):
output = item.copy() output = item.copy()
elif isinstance(item, tuple): elif isinstance(item, tuple):
...@@ -287,12 +316,15 @@ def recursively_load_dict_contents_from_group(h5file, path): ...@@ -287,12 +316,15 @@ def recursively_load_dict_contents_from_group(h5file, path):
The contents of the HDF5 file unpacked into the dictionary. The contents of the HDF5 file unpacked into the dictionary.
""" """
import h5py import h5py
output = dict() output = dict()
for key, item in h5file[path].items(): for key, item in h5file[path].items():
if isinstance(item, h5py.Dataset): if isinstance(item, h5py.Dataset):
output[key] = decode_from_hdf5(item[()]) output[key] = decode_from_hdf5(item[()])
elif isinstance(item, h5py.Group): elif isinstance(item, h5py.Group):
output[key] = recursively_load_dict_contents_from_group(h5file, path + key + '/') output[key] = recursively_load_dict_contents_from_group(
h5file, path + key + "/"
)
return output return output
...@@ -314,7 +346,7 @@ def recursively_save_dict_contents_to_group(h5file, path, dic): ...@@ -314,7 +346,7 @@ def recursively_save_dict_contents_to_group(h5file, path, dic):
for key, item in dic.items(): for key, item in dic.items():
item = encode_for_hdf5(key, item) item = encode_for_hdf5(key, item)
if isinstance(item, dict): if isinstance(item, dict):
recursively_save_dict_contents_to_group(h5file, path + key + '/', item) recursively_save_dict_contents_to_group(h5file, path + key + "/", item)
else: else:
h5file[path + key] = item h5file[path + key] = item
...@@ -351,24 +383,23 @@ def move_old_file(filename, overwrite=False): ...@@ -351,24 +383,23 @@ def move_old_file(filename, overwrite=False):
""" """
if os.path.isfile(filename): if os.path.isfile(filename):
if overwrite: if overwrite:
logger.debug('Removing existing file {}'.format(filename)) logger.debug("Removing existing file {}".format(filename))
os.remove(filename) os.remove(filename)
else: else:
logger.debug( logger.debug(
'Renaming existing file {} to {}.old'.format(filename, "Renaming existing file {} to {}.old".format(filename, filename)
filename)) )
shutil.move(filename, filename + '.old') shutil.move(filename, filename + ".old")
logger.debug("Saving result to {}".format(filename)) logger.debug("Saving result to {}".format(filename))
def safe_save_figure(fig, filename, **kwargs): def safe_save_figure(fig, filename, **kwargs):
check_directory_exists_and_if_not_mkdir(os.path.dirname(filename)) check_directory_exists_and_if_not_mkdir(os.path.dirname(filename))
from matplotlib import rcParams from matplotlib import rcParams
try: try:
fig.savefig(fname=filename, **kwargs) fig.savefig(fname=filename, **kwargs)
except RuntimeError: except RuntimeError:
logger.debug( logger.debug("Failed to save plot with tex labels turning off tex.")
"Failed to save plot with tex labels turning off tex."
)
rcParams["text.usetex"] = False rcParams["text.usetex"] = False
fig.savefig(fname=filename, **kwargs) fig.savefig(fname=filename, **kwargs)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment