diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py index e11f7862ad7360737760eaa272cd6157f85d8ba6..1917fb3d0b86e36a923b4aaa646611724aa2b061 100644 --- a/bilby/core/prior/base.py +++ b/bilby/core/prior/base.py @@ -408,6 +408,10 @@ class Prior(object): The string is interpreted as a call to instantiate another prior class, Bilby will attempt to recursively construct that prior, e.g., Uniform(minimum=0, maximum=1), my.custom.PriorClass(**kwargs). + - Else If the string contains a ".": + It is treated as a path to a Python function and imported, e.g., + "some_module.some_function" returns + :code:`import some_module; return some_module.some_function` - Else: Try to evaluate the string using `eval`. Only built-in functions and numpy methods can be used, e.g., np.pi / 2, 1.57. @@ -448,10 +452,17 @@ class Prior(object): try: val = eval(val, dict(), dict(np=np, inf=np.inf, pi=np.pi)) except NameError: - raise TypeError( - "Cannot evaluate prior, " - "failed to parse argument {}".format(val) - ) + if "." in val: + module = '.'.join(val.split('.')[:-1]) + func = val.split('.')[-1] + new_val = getattr(import_module(module), func, val) + if val == new_val: + raise TypeError( + "Cannot evaluate prior, " + f"failed to parse argument {val}" + ) + else: + val = new_val return val diff --git a/test/core/prior/dict_test.py b/test/core/prior/dict_test.py index 0bfc02ad096bf01ecea0ee3c1bb7a04daf493686..f2045ed227eddbc63d2038bc806a4ea442b63c08 100644 --- a/test/core/prior/dict_test.py +++ b/test/core/prior/dict_test.py @@ -464,6 +464,17 @@ class TestLoadPrior(unittest.TestCase): prior = bilby.core.prior.PriorDict(filename) self.assertTrue(isinstance(prior["logA"], bilby.core.prior.Uniform)) + def test_load_prior_with_function(self): + filename = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "prior_files/prior_with_function.prior", + ) + prior = bilby.core.prior.ConditionalPriorDict(filename) + self.assertTrue("mass_1" in prior) + self.assertTrue("mass_2" in prior) + samples = prior.sample(10000) + self.assertTrue(all(samples["mass_1"] > samples["mass_2"])) + class TestCreateDefaultPrior(unittest.TestCase): def test_none_behaviour(self): diff --git a/test/core/prior/prior_files/prior_with_function.prior b/test/core/prior/prior_files/prior_with_function.prior new file mode 100644 index 0000000000000000000000000000000000000000..6ecb9f801178829d52f6623b212f3666a20f02f7 --- /dev/null +++ b/test/core/prior/prior_files/prior_with_function.prior @@ -0,0 +1,2 @@ +mass_1 = Uniform(name='mass_1', minimum=5, maximum=100, unit='$M_{\odot}$', boundary=None) +mass_2 = ConditionalUniform(name="mass_1", minimum=5, maximum=100, condition_func="bilby.gw.prior.secondary_mass_condition_function")