From 09e3c9f3bd5c4120a283c34d7666572abe8174d6 Mon Sep 17 00:00:00 2001
From: Sylvia Biscoveanu <sylvia.biscoveanu@ligo.org>
Date: Fri, 3 Sep 2021 14:42:40 +0000
Subject: [PATCH] Add ability to load results produced with custom priors

---
 bilby/core/utils/io.py | 119 ++++++++++++++++++++++++++---------------
 1 file changed, 75 insertions(+), 44 deletions(-)

diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py
index 6bfa73de7..85095dc6a 100644
--- a/bilby/core/utils/io.py
+++ b/bilby/core/utils/io.py
@@ -25,25 +25,29 @@ def check_directory_exists_and_if_not_mkdir(directory):
 
 
 class BilbyJsonEncoder(json.JSONEncoder):
-
     def default(self, obj):
         from ..prior import MultivariateGaussianDist, Prior, PriorDict
         from ...gw.prior import HealPixMapPriorDist
         from ...bilby_mcmc.proposals import ProposalCycle
+
         if isinstance(obj, np.integer):
             return int(obj)
         if isinstance(obj, np.floating):
             return float(obj)
         if isinstance(obj, PriorDict):
-            return {'__prior_dict__': True, 'content': obj._get_json_dict()}
+            return {"__prior_dict__": True, "content": obj._get_json_dict()}
         if isinstance(obj, (MultivariateGaussianDist, HealPixMapPriorDist, Prior)):
-            return {'__prior__': True, '__module__': obj.__module__,
-                    '__name__': obj.__class__.__name__,
-                    'kwargs': dict(obj.get_instantiation_dict())}
+            return {
+                "__prior__": True,
+                "__module__": obj.__module__,
+                "__name__": obj.__class__.__name__,
+                "kwargs": dict(obj.get_instantiation_dict()),
+            }
         if isinstance(obj, ProposalCycle):
             return str(obj)
         try:
             from astropy import cosmology as cosmo, units
+
             if isinstance(obj, cosmo.FLRW):
                 return encode_astropy_cosmology(obj)
             if isinstance(obj, units.Quantity):
@@ -53,82 +57,104 @@ class BilbyJsonEncoder(json.JSONEncoder):
         except ImportError:
             logger.debug("Cannot import astropy, cannot write cosmological priors")
         if isinstance(obj, np.ndarray):
-            return {'__array__': True, 'content': obj.tolist()}
+            return {"__array__": True, "content": obj.tolist()}
         if isinstance(obj, complex):
-            return {'__complex__': True, 'real': obj.real, 'imag': obj.imag}
+            return {"__complex__": True, "real": obj.real, "imag": obj.imag}
         if isinstance(obj, pd.DataFrame):
-            return {'__dataframe__': True, 'content': obj.to_dict(orient='list')}
+            return {"__dataframe__": True, "content": obj.to_dict(orient="list")}
         if isinstance(obj, pd.Series):
-            return {'__series__': True, 'content': obj.to_dict()}
+            return {"__series__": True, "content": obj.to_dict()}
         if inspect.isfunction(obj):
-            return {"__function__": True, "__module__": obj.__module__, "__name__": obj.__name__}
+            return {
+                "__function__": True,
+                "__module__": obj.__module__,
+                "__name__": obj.__name__,
+            }
         if inspect.isclass(obj):
-            return {"__class__": True, "__module__": obj.__module__, "__name__": obj.__name__}
+            return {
+                "__class__": True,
+                "__module__": obj.__module__,
+                "__name__": obj.__name__,
+            }
         return json.JSONEncoder.default(self, obj)
 
 
 def encode_astropy_cosmology(obj):
     cls_name = obj.__class__.__name__
-    dct = {key: getattr(obj, key) for
-           key in infer_args_from_method(obj.__init__)}
-    dct['__cosmology__'] = True
-    dct['__name__'] = cls_name
+    dct = {key: getattr(obj, key) for key in infer_args_from_method(obj.__init__)}
+    dct["__cosmology__"] = True
+    dct["__name__"] = cls_name
     return dct
 
 
 def encode_astropy_quantity(dct):
     dct = dict(__astropy_quantity__=True, value=dct.value, unit=str(dct.unit))
-    if isinstance(dct['value'], np.ndarray):
-        dct['value'] = list(dct['value'])
+    if isinstance(dct["value"], np.ndarray):
+        dct["value"] = list(dct["value"])
     return dct
 
 
 def decode_astropy_cosmology(dct):
     try:
         from astropy import cosmology as cosmo
-        cosmo_cls = getattr(cosmo, dct['__name__'])
-        del dct['__cosmology__'], dct['__name__']
+
+        cosmo_cls = getattr(cosmo, dct["__name__"])
+        del dct["__cosmology__"], dct["__name__"]
         return cosmo_cls(**dct)
     except ImportError:
-        logger.debug("Cannot import astropy, cosmological priors may not be "
-                     "properly loaded.")
+        logger.debug(
+            "Cannot import astropy, cosmological priors may not be " "properly loaded."
+        )
         return dct
 
 
 def decode_astropy_quantity(dct):
     try:
         from astropy import units
-        if dct['value'] is None:
+
+        if dct["value"] is None:
             return None
         else:
-            del dct['__astropy_quantity__']
+            del dct["__astropy_quantity__"]
             return units.Quantity(**dct)
     except ImportError:
-        logger.debug("Cannot import astropy, cosmological priors may not be "
-                     "properly loaded.")
+        logger.debug(
+            "Cannot import astropy, cosmological priors may not be " "properly loaded."
+        )
         return dct
 
 
 def load_json(filename, gzip):
-    if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz':
+    if gzip or os.path.splitext(filename)[1].lstrip(".") == "gz":
         import gzip
-        with gzip.GzipFile(filename, 'r') as file:
-            json_str = file.read().decode('utf-8')
+
+        with gzip.GzipFile(filename, "r") as file:
+            json_str = file.read().decode("utf-8")
         dictionary = json.loads(json_str, object_hook=decode_bilby_json)
     else:
-        with open(filename, 'r') as file:
+        with open(filename, "r") as file:
             dictionary = json.load(file, object_hook=decode_bilby_json)
     return dictionary
 
 
 def decode_bilby_json(dct):
     if dct.get("__prior_dict__", False):
-        cls = getattr(import_module(dct['__module__']), dct['__name__'])
+        cls = getattr(import_module(dct["__module__"]), dct["__name__"])
         obj = cls._get_from_json_dict(dct)
         return obj
     if dct.get("__prior__", False):
-        cls = getattr(import_module(dct['__module__']), dct['__name__'])
-        obj = cls(**dct['kwargs'])
+        try:
+            cls = getattr(import_module(dct["__module__"]), dct["__name__"])
+        except AttributeError:
+            logger.debug(
+                "Unknown prior class for parameter {}, defaulting to base Prior object".format(
+                    dct["kwargs"]["name"]
+                )
+            )
+            from ..prior import Prior
+
+            cls = Prior
+        obj = cls(**dct["kwargs"])
         return obj
     if dct.get("__cosmology__", False):
         return decode_astropy_cosmology(dct)
@@ -139,9 +165,9 @@ def decode_bilby_json(dct):
     if dct.get("__complex__", False):
         return complex(dct["real"], dct["imag"])
     if dct.get("__dataframe__", False):
-        return pd.DataFrame(dct['content'])
+        return pd.DataFrame(dct["content"])
     if dct.get("__series__", False):
-        return pd.Series(dct['content'])
+        return pd.Series(dct["content"])
     if dct.get("__function__", False) or dct.get("__class__", False):
         default = ".".join([dct["__module__"], dct["__name__"]])
         return getattr(import_module(dct["__module__"]), dct["__name__"], default)
@@ -225,6 +251,7 @@ def encode_for_hdf5(key, item):
         Input item converted into HDF5 saveable format
     """
     from ..prior.dict import PriorDict
+
     if isinstance(item, np.int_):
         item = int(item)
     elif isinstance(item, np.float_):
@@ -258,7 +285,9 @@ def encode_for_hdf5(key, item):
     elif isinstance(item, pd.Series):
         output = item.to_dict()
     elif inspect.isfunction(item) or inspect.isclass(item):
-        output = dict(__module__=item.__module__, __name__=item.__name__, __class__=True)
+        output = dict(
+            __module__=item.__module__, __name__=item.__name__, __class__=True
+        )
     elif isinstance(item, dict):
         output = item.copy()
     elif isinstance(item, tuple):
@@ -287,12 +316,15 @@ def recursively_load_dict_contents_from_group(h5file, path):
         The contents of the HDF5 file unpacked into the dictionary.
     """
     import h5py
+
     output = dict()
     for key, item in h5file[path].items():
         if isinstance(item, h5py.Dataset):
             output[key] = decode_from_hdf5(item[()])
         elif isinstance(item, h5py.Group):
-            output[key] = recursively_load_dict_contents_from_group(h5file, path + key + '/')
+            output[key] = recursively_load_dict_contents_from_group(
+                h5file, path + key + "/"
+            )
     return output
 
 
@@ -314,7 +346,7 @@ def recursively_save_dict_contents_to_group(h5file, path, dic):
     for key, item in dic.items():
         item = encode_for_hdf5(key, item)
         if isinstance(item, dict):
-            recursively_save_dict_contents_to_group(h5file, path + key + '/', item)
+            recursively_save_dict_contents_to_group(h5file, path + key + "/", item)
         else:
             h5file[path + key] = item
 
@@ -351,24 +383,23 @@ def move_old_file(filename, overwrite=False):
     """
     if os.path.isfile(filename):
         if overwrite:
-            logger.debug('Removing existing file {}'.format(filename))
+            logger.debug("Removing existing file {}".format(filename))
             os.remove(filename)
         else:
             logger.debug(
-                'Renaming existing file {} to {}.old'.format(filename,
-                                                             filename))
-            shutil.move(filename, filename + '.old')
+                "Renaming existing file {} to {}.old".format(filename, filename)
+            )
+            shutil.move(filename, filename + ".old")
     logger.debug("Saving result to {}".format(filename))
 
 
 def safe_save_figure(fig, filename, **kwargs):
     check_directory_exists_and_if_not_mkdir(os.path.dirname(filename))
     from matplotlib import rcParams
+
     try:
         fig.savefig(fname=filename, **kwargs)
     except RuntimeError:
-        logger.debug(
-            "Failed to save plot with tex labels turning off tex."
-        )
+        logger.debug("Failed to save plot with tex labels turning off tex.")
         rcParams["text.usetex"] = False
         fig.savefig(fname=filename, **kwargs)
-- 
GitLab