diff --git a/bilby/core/grid.py b/bilby/core/grid.py index a9b957f96a9ce649b4a4b1e29807f73347d433d9..d53bfa1af29313d9700e754e26dd7ff24c0ce2cc 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -7,7 +7,8 @@ from collections import OrderedDict from .prior import Prior, PriorDict from .utils import (logtrapzexp, check_directory_exists_and_if_not_mkdir, - logger, BilbyJsonEncoder, decode_bilby_json) + logger) +from .utils import BilbyJsonEncoder, decode_bilby_json from .result import FileMovedError @@ -406,12 +407,10 @@ class Grid(object): logger.debug("Saving result to {}".format(filename)) - # Convert the prior to a string representation for saving on disk dictionary = self._get_save_data_dictionary() - if dictionary.get('priors', False): - dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors} try: + dictionary["priors"] = dictionary["priors"]._get_json_dict() if gzip or (os.path.splitext(filename)[-1] == '.gz'): import gzip # encode to a string @@ -468,12 +467,6 @@ class Grid(object): else: with open(fname, 'r') as file: dictionary = json.load(file, object_hook=decode_bilby_json) - for key in dictionary.keys(): - # Convert the loaded priors to bilby prior type - if key == 'priors': - for param in dictionary[key].keys(): - dictionary[key][param] = str(dictionary[key][param]) - dictionary[key] = PriorDict(dictionary[key]) try: grid = cls(likelihood=None, priors=dictionary['priors'], grid_size=dictionary['sample_points'], diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 2ba628560947b6d577880d4fe34ea3af6047b150..9ea0b1934ddf35e4d021329fe6d4a7bd2036f955 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -5,16 +5,21 @@ from importlib import import_module import os from collections import OrderedDict from future.utils import iteritems -from matplotlib.cbook import flatten +import json import numpy as np import scipy.stats from scipy.integrate import cumtrapz from scipy.interpolate import interp1d from scipy.special import erf, erfinv +from matplotlib.cbook import flatten +# Keep import bilby statement, it is necessary for some eval() statements +from .utils import BilbyJsonEncoder, decode_bilby_json from .utils import ( - logger, infer_args_from_method, check_directory_exists_and_if_not_mkdir) + check_directory_exists_and_if_not_mkdir, + infer_args_from_method, logger +) class PriorDict(OrderedDict): @@ -107,6 +112,22 @@ class PriorDict(OrderedDict): outfile.write( "{} = {}\n".format(key, self[key])) + def _get_json_dict(self): + self.convert_floats_to_delta_functions() + total_dict = {key: json.loads(self[key].to_json()) for key in self} + total_dict["__prior_dict__"] = True + total_dict["__module__"] = self.__module__ + total_dict["__name__"] = self.__class__.__name__ + return total_dict + + def to_json(self, outdir, label): + check_directory_exists_and_if_not_mkdir(outdir) + prior_file = os.path.join(outdir, "{}_prior.json".format(label)) + logger.debug("Writing priors to {}".format(prior_file)) + with open(prior_file, "w") as outfile: + json.dump(self._get_json_dict(), outfile, cls=BilbyJsonEncoder, + indent=2) + def from_file(self, filename): """ Reads in a prior from a file specification @@ -150,7 +171,7 @@ class PriorDict(OrderedDict): cls = cls.split('.')[-1] else: module = __name__ - cls = getattr(import_module(module), cls) + cls = getattr(import_module(module), cls, cls) if key.lower() == "conversion_function": setattr(self, key, cls) elif (cls.__name__ in ['MultivariateGaussianDist', @@ -170,6 +191,38 @@ class PriorDict(OrderedDict): filename, key, val, e)) self.update(prior) + @classmethod + def _get_from_json_dict(cls, prior_dict): + try: + cls == getattr( + import_module(prior_dict["__module__"]), + prior_dict["__name__"]) + except ImportError: + logger.debug("Cannot import prior module {}.{}".format( + prior_dict["__module__"], prior_dict["__name__"] + )) + except KeyError: + logger.debug("Cannot find module name to load") + for key in ["__module__", "__name__", "__prior_dict__"]: + if key in prior_dict: + del prior_dict[key] + obj = cls(dict()) + obj.from_dictionary(prior_dict) + return obj + + @classmethod + def from_json(cls, filename): + """ Reads in a prior from a json file + + Parameters + ---------- + filename: str + Name of the file to be read in + """ + with open(filename, "r") as ff: + obj = json.load(ff, object_hook=decode_bilby_json) + return obj + def from_dictionary(self, dictionary): for key, val in iteritems(dictionary): if isinstance(val, str): @@ -182,6 +235,10 @@ class PriorDict(OrderedDict): "Failed to load dictionary value {} correctly" .format(key)) pass + elif isinstance(val, dict): + logger.warning( + 'Cannot convert {} into a prior object. ' + 'Leaving as dictionary.'.format(key)) self[key] = val def convert_floats_to_delta_functions(self): @@ -628,7 +685,9 @@ class Prior(object): """ prior_name = self.__class__.__name__ - args = ', '.join(['{}={}'.format(key, repr(self._repr_dict[key])) for key in self._repr_dict]) + instantiation_dict = self._get_instantiation_dict() + args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) + for key in instantiation_dict]) return "{}({})".format(prior_name, args) @property @@ -709,6 +768,18 @@ class Prior(object): def maximum(self, maximum): self._maximum = maximum + def _get_instantiation_dict(self): + subclass_args = infer_args_from_method(self.__init__) + property_names = [p for p in dir(self.__class__) + if isinstance(getattr(self.__class__, p), property)] + dict_with_properties = self.__dict__.copy() + for key in property_names: + dict_with_properties[key] = getattr(self, key) + instantiation_dict = OrderedDict() + for key in subclass_args: + instantiation_dict[key] = dict_with_properties[key] + return instantiation_dict + @property def boundary(self): return self._boundary @@ -727,6 +798,13 @@ class Prior(object): label = self.name return label + def to_json(self): + return json.dumps(self, cls=BilbyJsonEncoder) + + @classmethod + def from_json(cls, dct): + return decode_bilby_json(dct) + @classmethod def from_repr(cls, string): """Generate the prior from it's __repr__""" @@ -2914,6 +2992,22 @@ class MultivariateGaussianDist(object): return np.exp(self.ln_prob(samp)) + def _get_instantiation_dict(self): + subclass_args = infer_args_from_method(self.__init__) + property_names = [p for p in dir(self.__class__) + if isinstance(getattr(self.__class__, p), property)] + dict_with_properties = self.__dict__.copy() + for key in property_names: + dict_with_properties[key] = getattr(self, key) + instantiation_dict = OrderedDict() + for key in subclass_args: + if isinstance(dict_with_properties[key], list): + value = np.asarray(dict_with_properties[key]).tolist() + else: + value = dict_with_properties[key] + instantiation_dict[key] = value + return instantiation_dict + def __len__(self): return len(self.names) @@ -2928,23 +3022,10 @@ class MultivariateGaussianDist(object): str: A string representation of this instance """ - subclass_args = infer_args_from_method(self.__init__) dist_name = self.__class__.__name__ - - property_names = [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)] - dict_with_properties = self.__dict__.copy() - for key in property_names: - dict_with_properties[key] = getattr(self, key) - - argslist = [] - for key in subclass_args: - # make sure lists containing arrays are returned just as lists - if isinstance(dict_with_properties[key], list): - argsval = np.asarray(dict_with_properties[key]).tolist() - else: - argsval = dict_with_properties[key] - argslist.append('{}={}'.format(key, repr(argsval))) - args = ', '.join(argslist) + instantiation_dict = self._get_instantiation_dict() + args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key])) + for key in instantiation_dict]) return "{}({})".format(dist_name, args) def __eq__(self, other): diff --git a/bilby/core/result.py b/bilby/core/result.py index 8b4d97dda30734a85eef823b55fa0787f77cc9fb..5ac7615e78e8a6f1204fd138c5ec0728e86b0a8c 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -18,8 +18,8 @@ from scipy.special import logsumexp from . import utils from .utils import (logger, infer_parameters_from_function, - check_directory_exists_and_if_not_mkdir, - BilbyJsonEncoder, decode_bilby_json) + check_directory_exists_and_if_not_mkdir,) +from .utils import BilbyJsonEncoder, decode_bilby_json from .prior import Prior, PriorDict, DeltaFunction @@ -264,12 +264,6 @@ class Result(object): else: with open(filename, 'r') as file: dictionary = json.load(file, object_hook=decode_bilby_json) - for key in dictionary.keys(): - # Convert the loaded priors to bilby prior type - if key == 'priors': - for param in dictionary[key].keys(): - dictionary[key][param] = str(dictionary[key][param]) - dictionary[key] = PriorDict(dictionary[key]) try: return cls(**dictionary) except TypeError as e: @@ -467,8 +461,6 @@ class Result(object): # Convert the prior to a string representation for saving on disk dictionary = self._get_save_data_dictionary() - if dictionary.get('priors', False): - dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors} # Convert callable sampler_kwargs to strings if dictionary.get('sampler_kwargs', None) is not None: @@ -478,6 +470,7 @@ class Result(object): try: if extension == 'json': + dictionary["priors"] = dictionary["priors"]._get_json_dict() if gzip: import gzip # encode to a string diff --git a/bilby/core/utils.py b/bilby/core/utils.py index 5506a443819016c905760c9a0cc55ecfad93341a..6f53601f79d80a942c6317720c838583d2c9c979 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -7,8 +7,9 @@ import argparse import traceback import inspect import subprocess -import json import multiprocessing +from importlib import import_module +import json import numpy as np from scipy.interpolate import interp2d @@ -907,7 +908,25 @@ else: class BilbyJsonEncoder(json.JSONEncoder): + def default(self, obj): + from .prior import MultivariateGaussianDist, Prior, PriorDict + if isinstance(obj, PriorDict): + return {'__prior_dict__': True, 'content': obj._get_json_dict()} + if isinstance(obj, (MultivariateGaussianDist, Prior)): + return {'__prior__': True, '__module__': obj.__module__, + '__name__': obj.__class__.__name__, + 'kwargs': dict(obj._get_instantiation_dict())} + try: + from astropy import cosmology as cosmo, units + if isinstance(obj, cosmo.FLRW): + return encode_astropy_cosmology(obj) + if isinstance(obj, units.Quantity): + return encode_astropy_quantity(obj) + if isinstance(obj, units.PrefixUnit): + return str(obj) + except ImportError: + logger.info("Cannot import astropy, cannot write cosmological priors") if isinstance(obj, np.ndarray): return {'__array__': True, 'content': obj.tolist()} if isinstance(obj, complex): @@ -917,7 +936,35 @@ class BilbyJsonEncoder(json.JSONEncoder): return json.JSONEncoder.default(self, obj) +def encode_astropy_cosmology(obj): + cls_name = obj.__class__.__name__ + dct = {key: getattr(obj, key) for + key in infer_args_from_method(obj.__init__)} + dct['__cosmology__'] = True + dct['__name__'] = cls_name + return dct + + +def encode_astropy_quantity(dct): + dct = dict(__astropy_quantity__=True, value=dct.value, unit=str(dct.unit)) + if isinstance(dct['value'], np.ndarray): + dct['value'] = list(dct['value']) + return dct + + def decode_bilby_json(dct): + if dct.get("__prior_dict__", False): + cls = getattr(import_module(dct['__module__']), dct['__name__']) + obj = cls._get_from_json_dict(dct) + return obj + if dct.get("__prior__", False): + cls = getattr(import_module(dct['__module__']), dct['__name__']) + obj = cls(**dct['kwargs']) + return obj + if dct.get("__cosmology__", False): + return decode_astropy_cosmology(dct) + if dct.get("__astropy_quantity__", False): + return decode_astropy_quantity(dct) if dct.get("__array__", False): return np.asarray(dct["content"]) if dct.get("__complex__", False): @@ -927,5 +974,31 @@ def decode_bilby_json(dct): return dct +def decode_astropy_cosmology(dct): + try: + from astropy import cosmology as cosmo + cosmo_cls = getattr(cosmo, dct['__name__']) + del dct['__cosmology__'], dct['__name__'] + return cosmo_cls(**dct) + except ImportError: + logger.info("Cannot import astropy, cosmological priors may not be " + "properly loaded.") + return dct + + +def decode_astropy_quantity(dct): + try: + from astropy import units + if dct['value'] is None: + return None + else: + del dct['__astropy_quantity__'] + return units.Quantity(**dct) + except ImportError: + logger.info("Cannot import astropy, cosmological priors may not be " + "properly loaded.") + return dct + + class IllegalDurationAndSamplingFrequencyException(Exception): pass diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 73585744ab81af1a86b1980a37d23c2d3b451776..14f14b7646262a641f4a6ade6c79e57478bd8a3e 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -16,10 +16,10 @@ except ImportError: from scipy.special import i0e from ..core import likelihood +from ..core.utils import BilbyJsonEncoder, decode_bilby_json from ..core.utils import ( - logger, UnsortedInterp2d, BilbyJsonEncoder, decode_bilby_json, - create_frequency_series, create_time_series, speed_of_light, - radius_of_earth) + logger, UnsortedInterp2d, create_frequency_series, create_time_series, + speed_of_light, radius_of_earth) from ..core.prior import Interped, Prior, Uniform from .detector import InterferometerList from .prior import BBHPriorDict diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 83dcf925d0fd33581746feee7d055d3eafaa7547..890659561bda732a96e1b01d5798eea159612ead 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -42,8 +42,8 @@ class Cosmological(Interped): if latex_label is not None: label_args['latex_label'] = latex_label if unit is not None: - if isinstance(unit, str): - unit = units.__dict__[unit] + if not isinstance(unit, units.Unit): + unit = units.Unit(unit) label_args['unit'] = unit self.unit = label_args['unit'] self._minimum = dict() diff --git a/test/gw_prior_test.py b/test/gw_prior_test.py index 57748331680547a4cd4551ca90772285da3f24e4..9195aebed61948e8759b26c642f4351eaec17f7a 100644 --- a/test/gw_prior_test.py +++ b/test/gw_prior_test.py @@ -1,4 +1,5 @@ from __future__ import division, absolute_import +from collections import OrderedDict import unittest import os import sys @@ -144,7 +145,7 @@ class TestPackagedPriors(unittest.TestCase): class TestBNSPriorDict(unittest.TestCase): def setUp(self): - self.prior_dict = dict() + self.prior_dict = OrderedDict() self.base_directory =\ '/'.join(os.path.dirname( os.path.abspath(sys.argv[0])).split('/')[:-1]) diff --git a/test/prior_test.py b/test/prior_test.py index a042cdbe7032e5d401e8f9fa40ea164ba6e69750..852031a7286cdf68d55653f49f10a6dd4c8fcc14 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -816,5 +816,65 @@ class TestCreateDefaultPrior(unittest.TestCase): self.assertIsNone(bilby.core.prior.create_default_prior(name='name', default_priors_file=prior_file)) +class TestJsonIO(unittest.TestCase): + + def setUp(self): + mvg = bilby.core.prior.MultivariateGaussianDist(names=['testa', 'testb'], + mus=[1, 1], + covs=np.array([[2., 0.5], [0.5, 2.]]), + weights=1.) + mvn = bilby.core.prior.MultivariateGaussianDist(names=['testa', 'testb'], + mus=[1, 1], + covs=np.array([[2., 0.5], [0.5, 2.]]), + weights=1.) + + self.priors = bilby.core.prior.PriorDict(dict( + a=bilby.core.prior.DeltaFunction(name='test', unit='unit', peak=1), + b=bilby.core.prior.Gaussian(name='test', unit='unit', mu=0, sigma=1), + c=bilby.core.prior.Normal(name='test', unit='unit', mu=0, sigma=1), + d=bilby.core.prior.PowerLaw(name='test', unit='unit', alpha=0, minimum=0, maximum=1), + e=bilby.core.prior.PowerLaw(name='test', unit='unit', alpha=-1, minimum=0.5, maximum=1), + f=bilby.core.prior.PowerLaw(name='test', unit='unit', alpha=2, minimum=1, maximum=1e2), + g=bilby.core.prior.Uniform(name='test', unit='unit', minimum=0, maximum=1), + h=bilby.core.prior.LogUniform(name='test', unit='unit', minimum=5e0, maximum=1e2), + i=bilby.gw.prior.UniformComovingVolume(name='redshift', minimum=0.1, maximum=1.0), + j=bilby.gw.prior.UniformSourceFrame(name='luminosity_distance', minimum=1.0, maximum=1000.0), + k=bilby.core.prior.Sine(name='test', unit='unit'), + l=bilby.core.prior.Cosine(name='test', unit='unit'), + m=bilby.core.prior.Interped(name='test', unit='unit', xx=np.linspace(0, 10, 1000), + yy=np.linspace(0, 10, 1000) ** 4, + minimum=3, maximum=5), + n=bilby.core.prior.TruncatedGaussian(name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1), + o=bilby.core.prior.TruncatedNormal(name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1), + p=bilby.core.prior.HalfGaussian(name='test', unit='unit', sigma=1), + q=bilby.core.prior.HalfNormal(name='test', unit='unit', sigma=1), + r=bilby.core.prior.LogGaussian(name='test', unit='unit', mu=0, sigma=1), + s=bilby.core.prior.LogNormal(name='test', unit='unit', mu=0, sigma=1), + t=bilby.core.prior.Exponential(name='test', unit='unit', mu=1), + u=bilby.core.prior.StudentT(name='test', unit='unit', df=3, mu=0, scale=1), + v=bilby.core.prior.Beta(name='test', unit='unit', alpha=2.0, beta=2.0), + x=bilby.core.prior.Logistic(name='test', unit='unit', mu=0, scale=1), + y=bilby.core.prior.Cauchy(name='test', unit='unit', alpha=0, beta=1), + z=bilby.core.prior.Lorentzian(name='test', unit='unit', alpha=0, beta=1), + aa=bilby.core.prior.Gamma(name='test', unit='unit', k=1, theta=1), + ab=bilby.core.prior.ChiSquared(name='test', unit='unit', nu=2), + ac=bilby.gw.prior.AlignedSpin(name='test', unit='unit'), + ad=bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testa', unit='unit'), + ae=bilby.core.prior.MultivariateGaussian(mvg=mvg, name='testb', unit='unit'), + af=bilby.core.prior.MultivariateNormal(mvg=mvn, name='testa', unit='unit'), + ag=bilby.core.prior.MultivariateNormal(mvg=mvn, name='testb', unit='unit') + )) + + def test_read_write_to_json(self): + """ Interped prior is removed as there is numerical error in the recovered prior.""" + self.priors.to_json(outdir="prior_files", label="json_test") + new_priors = bilby.core.prior.PriorDict.from_json(filename="prior_files/json_test_prior.json") + old_interped = self.priors.pop("m") + new_interped = new_priors.pop("m") + self.assertDictEqual(self.priors, new_priors) + self.assertLess(max(abs(old_interped.xx - new_interped.xx)), 1e-15) + self.assertLess(max(abs(old_interped.yy - new_interped.yy)), 1e-15) + + if __name__ == '__main__': unittest.main()