Skip to content
Snippets Groups Projects
Commit 3e3c5cd5 authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'prior-load-functions' into 'master'

FEATURE: allow prior arguments to be functions

See merge request lscsoft/bilby!1144
parents ad2c9415 20a6c9f5
No related branches found
No related tags found
1 merge request!1144FEATURE: allow prior arguments to be functions
Pipeline #477247 passed
......@@ -408,6 +408,10 @@ class Prior(object):
The string is interpreted as a call to instantiate another prior
class, Bilby will attempt to recursively construct that prior,
e.g., Uniform(minimum=0, maximum=1), my.custom.PriorClass(**kwargs).
- Else If the string contains a ".":
It is treated as a path to a Python function and imported, e.g.,
"some_module.some_function" returns
:code:`import some_module; return some_module.some_function`
- Else:
Try to evaluate the string using `eval`. Only built-in functions
and numpy methods can be used, e.g., np.pi / 2, 1.57.
......@@ -448,10 +452,17 @@ class Prior(object):
try:
val = eval(val, dict(), dict(np=np, inf=np.inf, pi=np.pi))
except NameError:
raise TypeError(
"Cannot evaluate prior, "
"failed to parse argument {}".format(val)
)
if "." in val:
module = '.'.join(val.split('.')[:-1])
func = val.split('.')[-1]
new_val = getattr(import_module(module), func, val)
if val == new_val:
raise TypeError(
"Cannot evaluate prior, "
f"failed to parse argument {val}"
)
else:
val = new_val
return val
......
......@@ -464,6 +464,17 @@ class TestLoadPrior(unittest.TestCase):
prior = bilby.core.prior.PriorDict(filename)
self.assertTrue(isinstance(prior["logA"], bilby.core.prior.Uniform))
def test_load_prior_with_function(self):
filename = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"prior_files/prior_with_function.prior",
)
prior = bilby.core.prior.ConditionalPriorDict(filename)
self.assertTrue("mass_1" in prior)
self.assertTrue("mass_2" in prior)
samples = prior.sample(10000)
self.assertTrue(all(samples["mass_1"] > samples["mass_2"]))
class TestCreateDefaultPrior(unittest.TestCase):
def test_none_behaviour(self):
......
mass_1 = Uniform(name='mass_1', minimum=5, maximum=100, unit='$M_{\odot}$', boundary=None)
mass_2 = ConditionalUniform(name="mass_1", minimum=5, maximum=100, condition_func="bilby.gw.prior.secondary_mass_condition_function")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment