diff --git a/bilby/core/prior.py b/bilby/core/prior.py index be34de2031e1f4ed0f9a37a257037b64dde89011..15ece0185011f25ca0ecf52dfd498797b60bd7ee 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -857,9 +857,47 @@ class Prior(object): @classmethod def _parse_argument_string(cls, val): - if re.sub(r'\'.*\'', '', val) in ['r', 'u']: - # If the val is a latex label like "r"\log(x)"' then ignore it - pass + """ + Parse a string into the appropriate type for prior reading. + + Four tests are applied in the following order: + + - If the string is 'None': + `None` is returned. + - Else If the string is a raw string, e.g., r'foo': + A stripped version of the string is returned, e.g., foo. + - Else If the string contains ', e.g., 'foo': + A stripped version of the string is returned, e.g., foo. + - Else If the string contains an open parenthesis, (: + The string is interpreted as a call to instantiate another prior + class, Bilby will attempt to recursively construct that prior, + e.g., Uniform(minimum=0, maximum=1), my.custom.PriorClass(**kwargs). + - Else: + Try to evaluate the string using `eval`. Only built-in functions + and numpy methods can be used, e.g., np.pi / 2, 1.57. + + + Parameters + ---------- + val: str + The string version of the agument + + Returns + ------- + val: object + The parsed version of the argument. + + Raises + ------ + TypeError: + If val cannot be parsed as described above. + """ + if val == 'None': + val = None + elif re.sub(r'\'.*\'', '', val) in ['r', 'u']: + val = val[2:-1] + elif "'" in val: + val = val.strip("'") elif '(' in val: other_cls = val.split('(')[0] vals = '('.join(val.split('(')[1:])[:-1] @@ -870,15 +908,14 @@ class Prior(object): module = __name__ other_cls = getattr(import_module(module), other_cls) val = other_cls.from_repr(vals) - elif "'" in val: - val = val.strip("'") - elif val == 'None': - val = None else: try: val = eval(val, dict(), dict(np=np)) except NameError: - raise TypeError() + raise TypeError( + "Cannot evaluate prior, " + "failed to parse argument {}".format(val) + ) return val diff --git a/test/prior_files/prior_with_parentheses.prior b/test/prior_files/prior_with_parentheses.prior index 2430db3cac0f749a81c6a3dc27a0ca0e5de85fdb..5296497f70f5b51ac0fde4c6633a8a82e9d5ca44 100644 --- a/test/prior_files/prior_with_parentheses.prior +++ b/test/prior_files/prior_with_parentheses.prior @@ -1 +1,2 @@ logA = Uniform(name='logA', minimum=10, maximum=20, latex_label=r'$\log(A_{0})$') +logB = Uniform(name='logB', minimum=10, maximum=20, latex_label='$\log(B_{0})$')