Skip to content
Snippets Groups Projects
Commit 49588418 authored by Colm Talbot's avatar Colm Talbot Committed by Gregory Ashton
Browse files

make aligned spin priors load and cosmological safer

parent 1740dd5a
No related branches found
No related tags found
No related merge requests found
......@@ -104,6 +104,11 @@ class PriorDict(OrderedDict):
Notes
-----
Lines beginning with '#' or empty lines will be ignored.
Priors can be loaded from:
bilby.core.prior as, e.g., foo = Uniform(minimum=0, maximum=1)
floats, e.g., foo = 1
bilby.gw.prior as, e.g., foo = bilby.gw.prior.AlignedSpin()
other external modules, e.g., foo = my.module.CustomPrior(...)
"""
comments = ['#', '\n']
......@@ -118,6 +123,13 @@ class PriorDict(OrderedDict):
val = '='.join(elements[1:]).strip()
cls = val.split('(')[0]
args = '('.join(val.split('(')[1:])[:-1]
try:
prior[key] = DeltaFunction(peak=float(cls))
logger.debug("{} converted ot DeltaFunction prior".format(
key))
continue
except ValueError:
pass
if "." in cls:
module = '.'.join(cls.split('.')[:-1])
cls = cls.split('.')[-1]
......@@ -676,6 +688,11 @@ class Prior(object):
@classmethod
def from_repr(cls, string):
"""Generate the prior from it's __repr__"""
return cls._from_repr(string)
@classmethod
def _from_repr(cls, string):
subclass_args = infer_args_from_method(cls.__init__)
string = string.replace(' ', '')
......@@ -696,20 +713,25 @@ class Prior(object):
remove = list()
for ii, key in enumerate(args):
if '(' in key:
args[ii] = ','.join([args[ii], args[ii + 1]]).strip()
remove.append(ii + 1)
jj = ii
while ')' not in args[jj]:
jj += 1
args[ii] = ','.join([args[ii], args[jj]]).strip()
remove.append(jj)
remove.reverse()
for ii in remove:
del args[ii]
kwargs = dict()
for ii, arg in enumerate(args):
try:
key, val = arg.split('=')
except ValueError:
if '=' not in arg:
logger.debug(
'Reading priors with non-keyword arguments is dangerous!')
key = subclass_args[ii]
val = arg
else:
split_arg = arg.split('=')
key = split_arg[0]
val = '='.join(split_arg[1:])
kwargs[key] = val
return kwargs
......
......@@ -149,6 +149,17 @@ class Cosmological(Interped):
def _get_redshift_arrays(self):
raise NotImplementedError
@classmethod
def from_repr(cls, string):
if "FlatLambdaCDM" in string:
logger.warning(
"Cosmological priors cannot be loaded from a string. "
"If the prior has a name, use that instead."
)
return string
else:
return cls._from_repr(string)
class UniformComovingVolume(Cosmological):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment