Skip to content
Snippets Groups Projects
Commit e73737a0 authored by Colm Talbot's avatar Colm Talbot Committed by Gregory Ashton
Browse files

make prior reading go through from_dict class

parent 18d7219f
No related branches found
No related tags found
No related merge requests found
from importlib import import_module from importlib import import_module
from io import open as ioopen from io import open as ioopen
import json import json
import numpy as np
import os import os
from future.utils import iteritems from future.utils import iteritems
from matplotlib.cbook import flatten from matplotlib.cbook import flatten
import numpy as np
# keep 'import *' to make eval() statement further down work consistently
from bilby.core.prior.analytical import * # noqa
from bilby.core.prior.analytical import DeltaFunction from bilby.core.prior.analytical import DeltaFunction
from bilby.core.prior.base import Prior, Constraint from bilby.core.prior.base import Prior, Constraint
from bilby.core.prior.joint import JointPrior from bilby.core.prior.joint import JointPrior
...@@ -141,7 +139,6 @@ class PriorDict(dict): ...@@ -141,7 +139,6 @@ class PriorDict(dict):
comments = ['#', '\n'] comments = ['#', '\n']
prior = dict() prior = dict()
mvgdict = dict(inf=np.inf) # evaluate inf as np.inf
with ioopen(filename, 'r', encoding='unicode_escape') as f: with ioopen(filename, 'r', encoding='unicode_escape') as f:
for line in f: for line in f:
if line[0] in comments: if line[0] in comments:
...@@ -150,39 +147,8 @@ class PriorDict(dict): ...@@ -150,39 +147,8 @@ class PriorDict(dict):
elements = line.split('=') elements = line.split('=')
key = elements[0].replace(' ', '') key = elements[0].replace(' ', '')
val = '='.join(elements[1:]).strip() val = '='.join(elements[1:]).strip()
cls = val.split('(')[0] prior[key] = val
args = '('.join(val.split('(')[1:])[:-1] self.from_dictionary(prior)
try:
prior[key] = DeltaFunction(peak=float(cls))
logger.debug("{} converted to DeltaFunction prior".format(
key))
continue
except ValueError:
pass
if "." in cls:
module = '.'.join(cls.split('.')[:-1])
cls = cls.split('.')[-1]
else:
module = __name__.replace('.' + os.path.basename(__file__).replace('.py', ''), '')
cls = getattr(import_module(module), cls, cls)
if key.lower() in ["conversion_function", "condition_func"]:
setattr(self, key, cls)
elif (cls.__name__ in ['MultivariateGaussianDist',
'MultivariateNormalDist']):
if key not in mvgdict:
mvgdict[key] = eval(val, None, mvgdict)
elif (cls.__name__ in ['MultivariateGaussian',
'MultivariateNormal']):
prior[key] = eval(val, None, mvgdict)
else:
try:
prior[key] = cls.from_repr(args)
except TypeError as e:
raise TypeError(
"Unable to parse dictionary file {}, bad line: {} "
"= {}. Error message {}".format(
filename, key, val, e))
self.update(prior)
@classmethod @classmethod
def _get_from_json_dict(cls, prior_dict): def _get_from_json_dict(cls, prior_dict):
...@@ -217,22 +183,61 @@ class PriorDict(dict): ...@@ -217,22 +183,61 @@ class PriorDict(dict):
return obj return obj
def from_dictionary(self, dictionary): def from_dictionary(self, dictionary):
eval_dict = dict(inf=np.inf)
for key, val in iteritems(dictionary): for key, val in iteritems(dictionary):
if isinstance(val, str): if isinstance(val, Prior):
continue
elif isinstance(val, (int, float)):
dictionary[key] = DeltaFunction(peak=val)
elif isinstance(val, str):
cls = val.split('(')[0]
args = '('.join(val.split('(')[1:])[:-1]
try: try:
prior = eval(val) dictionary[key] = DeltaFunction(peak=float(cls))
if isinstance(prior, (Prior, float, int, str)): logger.debug("{} converted to DeltaFunction prior".format(key))
val = prior continue
except (NameError, SyntaxError, TypeError): except ValueError:
logger.debug(
"Failed to load dictionary value {} correctly"
.format(key))
pass pass
if "." in cls:
module = '.'.join(cls.split('.')[:-1])
cls = cls.split('.')[-1]
else:
module = __name__.replace(
'.' + os.path.basename(__file__).replace('.py', ''), ''
)
cls = getattr(import_module(module), cls, cls)
if key.lower() in ["conversion_function", "condition_func"]:
setattr(self, key, cls)
elif isinstance(cls, str):
if "(" in val:
raise TypeError("Unable to parse prior class {}".format(cls))
else:
continue
elif (cls.__name__ in ['MultivariateGaussianDist',
'MultivariateNormalDist']):
if key not in eval_dict:
eval_dict[key] = eval(val, None, eval_dict)
elif (cls.__name__ in ['MultivariateGaussian',
'MultivariateNormal']):
dictionary[key] = eval(val, None, eval_dict)
else:
try:
dictionary[key] = cls.from_repr(args)
except TypeError as e:
raise TypeError(
"Unable to parse prior, bad entry: {} "
"= {}. Error message {}".format(key, val, e)
)
elif isinstance(val, dict): elif isinstance(val, dict):
logger.warning( logger.warning(
'Cannot convert {} into a prior object. ' 'Cannot convert {} into a prior object. '
'Leaving as dictionary.'.format(key)) 'Leaving as dictionary.'.format(key))
self[key] = val else:
raise TypeError(
"Unable to parse prior, bad entry: {} "
"= {} of type {}".format(key, val, type(val))
)
self.update(dictionary)
def convert_floats_to_delta_functions(self): def convert_floats_to_delta_functions(self):
""" Convert all float parameters to delta functions """ """ Convert all float parameters to delta functions """
......
...@@ -929,7 +929,7 @@ class TestFillPrior(unittest.TestCase): ...@@ -929,7 +929,7 @@ class TestFillPrior(unittest.TestCase):
self.likelihood = Mock() self.likelihood = Mock()
self.likelihood.parameters = dict(a=0, b=0, c=0, d=0, asdf=0, ra=1) self.likelihood.parameters = dict(a=0, b=0, c=0, d=0, asdf=0, ra=1)
self.likelihood.non_standard_sampling_parameter_keys = dict(t=8) self.likelihood.non_standard_sampling_parameter_keys = dict(t=8)
self.priors = dict(a=1, b=1.1, c='string', d=bilby.core.prior.Uniform(0, 1)) self.priors = dict(a=1, b=1.1, c="string", d=bilby.core.prior.Uniform(0, 1))
self.priors = bilby.core.prior.PriorDict(dictionary=self.priors) self.priors = bilby.core.prior.PriorDict(dictionary=self.priors)
self.default_prior_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), self.default_prior_file = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'prior_files/binary_black_holes.prior') 'prior_files/binary_black_holes.prior')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment