From 09e3c9f3bd5c4120a283c34d7666572abe8174d6 Mon Sep 17 00:00:00 2001 From: Sylvia Biscoveanu <sylvia.biscoveanu@ligo.org> Date: Fri, 3 Sep 2021 14:42:40 +0000 Subject: [PATCH] Add ability to load results produced with custom priors --- bilby/core/utils/io.py | 119 ++++++++++++++++++++++++++--------------- 1 file changed, 75 insertions(+), 44 deletions(-) diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py index 6bfa73de7..85095dc6a 100644 --- a/bilby/core/utils/io.py +++ b/bilby/core/utils/io.py @@ -25,25 +25,29 @@ def check_directory_exists_and_if_not_mkdir(directory): class BilbyJsonEncoder(json.JSONEncoder): - def default(self, obj): from ..prior import MultivariateGaussianDist, Prior, PriorDict from ...gw.prior import HealPixMapPriorDist from ...bilby_mcmc.proposals import ProposalCycle + if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): return float(obj) if isinstance(obj, PriorDict): - return {'__prior_dict__': True, 'content': obj._get_json_dict()} + return {"__prior_dict__": True, "content": obj._get_json_dict()} if isinstance(obj, (MultivariateGaussianDist, HealPixMapPriorDist, Prior)): - return {'__prior__': True, '__module__': obj.__module__, - '__name__': obj.__class__.__name__, - 'kwargs': dict(obj.get_instantiation_dict())} + return { + "__prior__": True, + "__module__": obj.__module__, + "__name__": obj.__class__.__name__, + "kwargs": dict(obj.get_instantiation_dict()), + } if isinstance(obj, ProposalCycle): return str(obj) try: from astropy import cosmology as cosmo, units + if isinstance(obj, cosmo.FLRW): return encode_astropy_cosmology(obj) if isinstance(obj, units.Quantity): @@ -53,82 +57,104 @@ class BilbyJsonEncoder(json.JSONEncoder): except ImportError: logger.debug("Cannot import astropy, cannot write cosmological priors") if isinstance(obj, np.ndarray): - return {'__array__': True, 'content': obj.tolist()} + return {"__array__": True, "content": obj.tolist()} if isinstance(obj, complex): - return {'__complex__': True, 'real': obj.real, 'imag': obj.imag} + return {"__complex__": True, "real": obj.real, "imag": obj.imag} if isinstance(obj, pd.DataFrame): - return {'__dataframe__': True, 'content': obj.to_dict(orient='list')} + return {"__dataframe__": True, "content": obj.to_dict(orient="list")} if isinstance(obj, pd.Series): - return {'__series__': True, 'content': obj.to_dict()} + return {"__series__": True, "content": obj.to_dict()} if inspect.isfunction(obj): - return {"__function__": True, "__module__": obj.__module__, "__name__": obj.__name__} + return { + "__function__": True, + "__module__": obj.__module__, + "__name__": obj.__name__, + } if inspect.isclass(obj): - return {"__class__": True, "__module__": obj.__module__, "__name__": obj.__name__} + return { + "__class__": True, + "__module__": obj.__module__, + "__name__": obj.__name__, + } return json.JSONEncoder.default(self, obj) def encode_astropy_cosmology(obj): cls_name = obj.__class__.__name__ - dct = {key: getattr(obj, key) for - key in infer_args_from_method(obj.__init__)} - dct['__cosmology__'] = True - dct['__name__'] = cls_name + dct = {key: getattr(obj, key) for key in infer_args_from_method(obj.__init__)} + dct["__cosmology__"] = True + dct["__name__"] = cls_name return dct def encode_astropy_quantity(dct): dct = dict(__astropy_quantity__=True, value=dct.value, unit=str(dct.unit)) - if isinstance(dct['value'], np.ndarray): - dct['value'] = list(dct['value']) + if isinstance(dct["value"], np.ndarray): + dct["value"] = list(dct["value"]) return dct def decode_astropy_cosmology(dct): try: from astropy import cosmology as cosmo - cosmo_cls = getattr(cosmo, dct['__name__']) - del dct['__cosmology__'], dct['__name__'] + + cosmo_cls = getattr(cosmo, dct["__name__"]) + del dct["__cosmology__"], dct["__name__"] return cosmo_cls(**dct) except ImportError: - logger.debug("Cannot import astropy, cosmological priors may not be " - "properly loaded.") + logger.debug( + "Cannot import astropy, cosmological priors may not be " "properly loaded." + ) return dct def decode_astropy_quantity(dct): try: from astropy import units - if dct['value'] is None: + + if dct["value"] is None: return None else: - del dct['__astropy_quantity__'] + del dct["__astropy_quantity__"] return units.Quantity(**dct) except ImportError: - logger.debug("Cannot import astropy, cosmological priors may not be " - "properly loaded.") + logger.debug( + "Cannot import astropy, cosmological priors may not be " "properly loaded." + ) return dct def load_json(filename, gzip): - if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz': + if gzip or os.path.splitext(filename)[1].lstrip(".") == "gz": import gzip - with gzip.GzipFile(filename, 'r') as file: - json_str = file.read().decode('utf-8') + + with gzip.GzipFile(filename, "r") as file: + json_str = file.read().decode("utf-8") dictionary = json.loads(json_str, object_hook=decode_bilby_json) else: - with open(filename, 'r') as file: + with open(filename, "r") as file: dictionary = json.load(file, object_hook=decode_bilby_json) return dictionary def decode_bilby_json(dct): if dct.get("__prior_dict__", False): - cls = getattr(import_module(dct['__module__']), dct['__name__']) + cls = getattr(import_module(dct["__module__"]), dct["__name__"]) obj = cls._get_from_json_dict(dct) return obj if dct.get("__prior__", False): - cls = getattr(import_module(dct['__module__']), dct['__name__']) - obj = cls(**dct['kwargs']) + try: + cls = getattr(import_module(dct["__module__"]), dct["__name__"]) + except AttributeError: + logger.debug( + "Unknown prior class for parameter {}, defaulting to base Prior object".format( + dct["kwargs"]["name"] + ) + ) + from ..prior import Prior + + cls = Prior + obj = cls(**dct["kwargs"]) return obj if dct.get("__cosmology__", False): return decode_astropy_cosmology(dct) @@ -139,9 +165,9 @@ def decode_bilby_json(dct): if dct.get("__complex__", False): return complex(dct["real"], dct["imag"]) if dct.get("__dataframe__", False): - return pd.DataFrame(dct['content']) + return pd.DataFrame(dct["content"]) if dct.get("__series__", False): - return pd.Series(dct['content']) + return pd.Series(dct["content"]) if dct.get("__function__", False) or dct.get("__class__", False): default = ".".join([dct["__module__"], dct["__name__"]]) return getattr(import_module(dct["__module__"]), dct["__name__"], default) @@ -225,6 +251,7 @@ def encode_for_hdf5(key, item): Input item converted into HDF5 saveable format """ from ..prior.dict import PriorDict + if isinstance(item, np.int_): item = int(item) elif isinstance(item, np.float_): @@ -258,7 +285,9 @@ def encode_for_hdf5(key, item): elif isinstance(item, pd.Series): output = item.to_dict() elif inspect.isfunction(item) or inspect.isclass(item): - output = dict(__module__=item.__module__, __name__=item.__name__, __class__=True) + output = dict( + __module__=item.__module__, __name__=item.__name__, __class__=True + ) elif isinstance(item, dict): output = item.copy() elif isinstance(item, tuple): @@ -287,12 +316,15 @@ def recursively_load_dict_contents_from_group(h5file, path): The contents of the HDF5 file unpacked into the dictionary. """ import h5py + output = dict() for key, item in h5file[path].items(): if isinstance(item, h5py.Dataset): output[key] = decode_from_hdf5(item[()]) elif isinstance(item, h5py.Group): - output[key] = recursively_load_dict_contents_from_group(h5file, path + key + '/') + output[key] = recursively_load_dict_contents_from_group( + h5file, path + key + "/" + ) return output @@ -314,7 +346,7 @@ def recursively_save_dict_contents_to_group(h5file, path, dic): for key, item in dic.items(): item = encode_for_hdf5(key, item) if isinstance(item, dict): - recursively_save_dict_contents_to_group(h5file, path + key + '/', item) + recursively_save_dict_contents_to_group(h5file, path + key + "/", item) else: h5file[path + key] = item @@ -351,24 +383,23 @@ def move_old_file(filename, overwrite=False): """ if os.path.isfile(filename): if overwrite: - logger.debug('Removing existing file {}'.format(filename)) + logger.debug("Removing existing file {}".format(filename)) os.remove(filename) else: logger.debug( - 'Renaming existing file {} to {}.old'.format(filename, - filename)) - shutil.move(filename, filename + '.old') + "Renaming existing file {} to {}.old".format(filename, filename) + ) + shutil.move(filename, filename + ".old") logger.debug("Saving result to {}".format(filename)) def safe_save_figure(fig, filename, **kwargs): check_directory_exists_and_if_not_mkdir(os.path.dirname(filename)) from matplotlib import rcParams + try: fig.savefig(fname=filename, **kwargs) except RuntimeError: - logger.debug( - "Failed to save plot with tex labels turning off tex." - ) + logger.debug("Failed to save plot with tex labels turning off tex.") rcParams["text.usetex"] = False fig.savefig(fname=filename, **kwargs) -- GitLab