Commit 49588418 authored by Colm Talbot's avatar Colm Talbot Committed by Gregory Ashton

make aligned spin priors load and cosmological safer

parent 1740dd5a
......@@ -104,6 +104,11 @@ class PriorDict(OrderedDict):
Notes
-----
Lines beginning with '#' or empty lines will be ignored.
Priors can be loaded from:
bilby.core.prior as, e.g., foo = Uniform(minimum=0, maximum=1)
floats, e.g., foo = 1
bilby.gw.prior as, e.g., foo = bilby.gw.prior.AlignedSpin()
other external modules, e.g., foo = my.module.CustomPrior(...)
"""
comments = ['#', '\n']
......@@ -118,6 +123,13 @@ class PriorDict(OrderedDict):
val = '='.join(elements[1:]).strip()
cls = val.split('(')[0]
args = '('.join(val.split('(')[1:])[:-1]
try:
prior[key] = DeltaFunction(peak=float(cls))
logger.debug("{} converted ot DeltaFunction prior".format(
key))
continue
except ValueError:
pass
if "." in cls:
module = '.'.join(cls.split('.')[:-1])
cls = cls.split('.')[-1]
......@@ -676,6 +688,11 @@ class Prior(object):
@classmethod
def from_repr(cls, string):
"""Generate the prior from it's __repr__"""
return cls._from_repr(string)
@classmethod
def _from_repr(cls, string):
subclass_args = infer_args_from_method(cls.__init__)
string = string.replace(' ', '')
......@@ -696,20 +713,25 @@ class Prior(object):
remove = list()
for ii, key in enumerate(args):
if '(' in key:
args[ii] = ','.join([args[ii], args[ii + 1]]).strip()
remove.append(ii + 1)
jj = ii
while ')' not in args[jj]:
jj += 1
args[ii] = ','.join([args[ii], args[jj]]).strip()
remove.append(jj)
remove.reverse()
for ii in remove:
del args[ii]
kwargs = dict()
for ii, arg in enumerate(args):
try:
key, val = arg.split('=')
except ValueError:
if '=' not in arg:
logger.debug(
'Reading priors with non-keyword arguments is dangerous!')
key = subclass_args[ii]
val = arg
else:
split_arg = arg.split('=')
key = split_arg[0]
val = '='.join(split_arg[1:])
kwargs[key] = val
return kwargs
......
......@@ -149,6 +149,17 @@ class Cosmological(Interped):
def _get_redshift_arrays(self):
raise NotImplementedError
@classmethod
def from_repr(cls, string):
if "FlatLambdaCDM" in string:
logger.warning(
"Cosmological priors cannot be loaded from a string. "
"If the prior has a name, use that instead."
)
return string
else:
return cls._from_repr(string)
class UniformComovingVolume(Cosmological):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment