diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py index 79b9522cfca601fe0c617fec6282029f5bd4824d..ca2eed960589fe5cb8d0bdfc05ca470fa00dfce5 100644 --- a/bilby/core/prior/dict.py +++ b/bilby/core/prior/dict.py @@ -1,14 +1,12 @@ from importlib import import_module from io import open as ioopen import json -import numpy as np import os from future.utils import iteritems from matplotlib.cbook import flatten +import numpy as np -# keep 'import *' to make eval() statement further down work consistently -from bilby.core.prior.analytical import * # noqa from bilby.core.prior.analytical import DeltaFunction from bilby.core.prior.base import Prior, Constraint from bilby.core.prior.joint import JointPrior @@ -141,7 +139,6 @@ class PriorDict(dict): comments = ['#', '\n'] prior = dict() - mvgdict = dict(inf=np.inf) # evaluate inf as np.inf with ioopen(filename, 'r', encoding='unicode_escape') as f: for line in f: if line[0] in comments: @@ -150,39 +147,8 @@ class PriorDict(dict): elements = line.split('=') key = elements[0].replace(' ', '') val = '='.join(elements[1:]).strip() - cls = val.split('(')[0] - args = '('.join(val.split('(')[1:])[:-1] - try: - prior[key] = DeltaFunction(peak=float(cls)) - logger.debug("{} converted to DeltaFunction prior".format( - key)) - continue - except ValueError: - pass - if "." in cls: - module = '.'.join(cls.split('.')[:-1]) - cls = cls.split('.')[-1] - else: - module = __name__.replace('.' + os.path.basename(__file__).replace('.py', ''), '') - cls = getattr(import_module(module), cls, cls) - if key.lower() in ["conversion_function", "condition_func"]: - setattr(self, key, cls) - elif (cls.__name__ in ['MultivariateGaussianDist', - 'MultivariateNormalDist']): - if key not in mvgdict: - mvgdict[key] = eval(val, None, mvgdict) - elif (cls.__name__ in ['MultivariateGaussian', - 'MultivariateNormal']): - prior[key] = eval(val, None, mvgdict) - else: - try: - prior[key] = cls.from_repr(args) - except TypeError as e: - raise TypeError( - "Unable to parse dictionary file {}, bad line: {} " - "= {}. Error message {}".format( - filename, key, val, e)) - self.update(prior) + prior[key] = val + self.from_dictionary(prior) @classmethod def _get_from_json_dict(cls, prior_dict): @@ -217,22 +183,61 @@ class PriorDict(dict): return obj def from_dictionary(self, dictionary): + eval_dict = dict(inf=np.inf) for key, val in iteritems(dictionary): - if isinstance(val, str): + if isinstance(val, Prior): + continue + elif isinstance(val, (int, float)): + dictionary[key] = DeltaFunction(peak=val) + elif isinstance(val, str): + cls = val.split('(')[0] + args = '('.join(val.split('(')[1:])[:-1] try: - prior = eval(val) - if isinstance(prior, (Prior, float, int, str)): - val = prior - except (NameError, SyntaxError, TypeError): - logger.debug( - "Failed to load dictionary value {} correctly" - .format(key)) + dictionary[key] = DeltaFunction(peak=float(cls)) + logger.debug("{} converted to DeltaFunction prior".format(key)) + continue + except ValueError: pass + if "." in cls: + module = '.'.join(cls.split('.')[:-1]) + cls = cls.split('.')[-1] + else: + module = __name__.replace( + '.' + os.path.basename(__file__).replace('.py', ''), '' + ) + cls = getattr(import_module(module), cls, cls) + if key.lower() in ["conversion_function", "condition_func"]: + setattr(self, key, cls) + elif isinstance(cls, str): + if "(" in val: + raise TypeError("Unable to parse prior class {}".format(cls)) + else: + continue + elif (cls.__name__ in ['MultivariateGaussianDist', + 'MultivariateNormalDist']): + if key not in eval_dict: + eval_dict[key] = eval(val, None, eval_dict) + elif (cls.__name__ in ['MultivariateGaussian', + 'MultivariateNormal']): + dictionary[key] = eval(val, None, eval_dict) + else: + try: + dictionary[key] = cls.from_repr(args) + except TypeError as e: + raise TypeError( + "Unable to parse prior, bad entry: {} " + "= {}. Error message {}".format(key, val, e) + ) elif isinstance(val, dict): logger.warning( 'Cannot convert {} into a prior object. ' 'Leaving as dictionary.'.format(key)) - self[key] = val + else: + raise TypeError( + "Unable to parse prior, bad entry: {} " + "= {} of type {}".format(key, val, type(val)) + ) + self.update(dictionary) def convert_floats_to_delta_functions(self): """ Convert all float parameters to delta functions """ diff --git a/test/prior_test.py b/test/prior_test.py index d579e31a64ec1319c4856ee76795aa7c47c24922..832956928bc2c44217395afd287cc564868d4325 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -929,7 +929,7 @@ class TestFillPrior(unittest.TestCase): self.likelihood = Mock() self.likelihood.parameters = dict(a=0, b=0, c=0, d=0, asdf=0, ra=1) self.likelihood.non_standard_sampling_parameter_keys = dict(t=8) - self.priors = dict(a=1, b=1.1, c='string', d=bilby.core.prior.Uniform(0, 1)) + self.priors = dict(a=1, b=1.1, c="string", d=bilby.core.prior.Uniform(0, 1)) self.priors = bilby.core.prior.PriorDict(dictionary=self.priors) self.default_prior_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'prior_files/binary_black_holes.prior')