diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index f6ceb56370c370db095756dbc954cbdde6b791f4..a7951562cb3b317a492365f78ab1dc867e79bec3 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,6 +1,6 @@ # This script is an edited version of the example found at # https://git.ligo.org/lscsoft/example-ci-project/blob/python/.gitlab-ci.yml -# Each 0th-indendation level is a job that will be run within GitLab CI +# Each 0th-indentation level is a job that will be run within GitLab CI # The only exception are a short list of reserved keywords # # https://docs.gitlab.com/ee/ci/yaml/#gitlab-ci-yml @@ -25,7 +25,7 @@ python-2: script: - python setup.py install # Run tests without finding coverage - - pytest + - pytest --ignore=test/utils_py3_test.py # test example on python 3 python-3: diff --git a/bilby/core/prior.py b/bilby/core/prior.py index 82e97ec869fa6b860e11c396d8f4890c9e39ae2e..8e25a09538f84ae47340c52317b84c54f6038233 100644 --- a/bilby/core/prior.py +++ b/bilby/core/prior.py @@ -9,10 +9,9 @@ import os from collections import OrderedDict from future.utils import iteritems -from .utils import logger +from .utils import logger, infer_args_from_method from . import utils import bilby # noqa -import inspect class PriorSet(OrderedDict): @@ -426,8 +425,7 @@ class Prior(object): str: A string representation of this instance """ - subclass_args = inspect.getargspec(self.__init__).args - subclass_args.pop(0) + subclass_args = infer_args_from_method(self.__init__) prior_name = self.__class__.__name__ property_names = [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)] diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index a797836f26617118acffa67915388d9cbd90fe29..bf6435ab8002969d662e49ea22391fff5227076d 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -1,10 +1,9 @@ from __future__ import absolute_import, print_function from collections import OrderedDict -import inspect import numpy as np -from ..utils import derivatives, logger +from ..utils import derivatives, logger, infer_args_from_method from ..prior import Prior from ..result import Result from .base_sampler import Sampler, MCMCSampler @@ -437,7 +436,7 @@ class Pymc3(MCMCSampler): # then use that log_likelihood function, with the assumption that it # takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines # the likelihood within that context manager - likeargs = inspect.getargspec(self.likelihood.log_likelihood).args + likeargs = infer_args_from_method(self.likelihood.log_likelihood) if 'sampler' in likeargs: self.likelihood.log_likelihood(sampler=self) else: @@ -480,7 +479,7 @@ class Pymc3(MCMCSampler): for key in self.priors: # if the prior contains ln_prob method that takes a 'sampler' argument # then try using that - lnprobargs = inspect.getargspec(self.priors[key].ln_prob).args + lnprobargs = infer_args_from_method(self.priors[key].ln_prob) if 'sampler' in lnprobargs: try: self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self) @@ -500,7 +499,7 @@ class Pymc3(MCMCSampler): if pymc3distname not in pymc3.__dict__: raise ValueError("Prior '{}' is not a known PyMC3 distribution.".format(pymc3distname)) - reqargs = inspect.getargspec(pymc3.__dict__[pymc3distname].__init__).args[1:] + reqargs = infer_args_from_method(pymc3.__dict__[pymc3distname].__init__) # set keyword arguments priorkwargs = {} diff --git a/bilby/core/utils.py b/bilby/core/utils.py index 555bedad089c8b5c47c6faeae45b34179c823617..4c9474b8b34c1f7702a4191405e01c0e0f05e21e 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -18,10 +18,39 @@ radius_of_earth = 6371 * 1e3 # metres def infer_parameters_from_function(func): - """ Infers the arguments of function (except the first arg which is - assumed to be the dep. variable) + """ Infers the arguments of a function + (except the first arg which is assumed to be the dep. variable). + + Throws out *args and **kwargs type arguments + + Can deal with type hinting! + + Returns + --------- + list: A list of strings with the parameters """ - parameters = inspect.getargspec(func).args + return _infer_args_from_function_except_for_first_arg(func=func) + + +def infer_args_from_method(method): + """ Infers all arguments of a method except for 'self' + + Throws out *args and **kwargs type arguments. + + Can deal with type hinting! + + Returns + --------- + list: A list of strings with the parameters + """ + return _infer_args_from_function_except_for_first_arg(func=method) + + +def _infer_args_from_function_except_for_first_arg(func): + try: + parameters = inspect.getfullargspec(func).args + except AttributeError: + parameters = inspect.getargspec(func).args parameters.pop(0) return parameters diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py index d4859dc7feb3894cfcdc97f7fe16fbbc3a962f69..6fe300cf9a83d427fbce6a5993a2c07060460a03 100644 --- a/bilby/hyper/model.py +++ b/bilby/hyper/model.py @@ -1,4 +1,4 @@ -import inspect +from bilby.core.utils import infer_parameters_from_function class Model(object): @@ -18,8 +18,9 @@ class Model(object): self.models = model_functions self.parameters = dict() - for function in self.models: - for key in inspect.getargspec(function).args[1:]: + for func in self.models: + param_keys = infer_parameters_from_function(func) + for key in param_keys: self.parameters[key] = None def prob(self, data): @@ -30,6 +31,7 @@ class Model(object): probability *= function(data, **self._get_function_parameters(function)) return probability - def _get_function_parameters(self, function): - parameters = {key: self.parameters[key] for key in inspect.getargspec(function).args[1:]} + def _get_function_parameters(self, func): + param_keys = infer_parameters_from_function(func) + parameters = {key: self.parameters[key] for key in param_keys} return parameters diff --git a/test/utils_py3_test.py b/test/utils_py3_test.py new file mode 100644 index 0000000000000000000000000000000000000000..da7aeab899ea92413277323b1f6d11d4bd562434 --- /dev/null +++ b/test/utils_py3_test.py @@ -0,0 +1,54 @@ +from __future__ import absolute_import, division + +import unittest + +from bilby.core import utils + + +class TestInferParameters(unittest.TestCase): + + def setUp(self): + def source_function1(freqs, a, b: int): + return None + + def source_function2(freqs, a, b, *args, **kwargs): + return None + + def source_function3(freqs, a, b: int, *args, **kwargs): + return None + + class TestClass: + def test_method(self, a, b: int, *args, **kwargs): + pass + + self.source1 = source_function1 + self.source2 = source_function2 + self.source3 = source_function3 + test_obj = TestClass() + self.source4 = test_obj.test_method + + def tearDown(self): + del self.source1 + del self.source2 + del self.source3 + del self.source4 + + def test_type_hinting(self): + expected = ['a', 'b'] + actual = utils.infer_parameters_from_function(self.source1) + self.assertListEqual(expected, actual) + + def test_args_kwargs_handling(self): + expected = ['a', 'b'] + actual = utils.infer_parameters_from_function(self.source2) + self.assertListEqual(expected, actual) + + def test_both(self): + expected = ['a', 'b'] + actual = utils.infer_parameters_from_function(self.source3) + self.assertListEqual(expected, actual) + + def test_self_handling(self): + expected = ['a', 'b'] + actual = utils.infer_args_from_method(self.source4) + self.assertListEqual(expected, actual) diff --git a/test/utils_test.py b/test/utils_test.py index bea0d19e73d3e91af2a69ef5a5c12880842ec4d5..9b3abf9414755f6e641397caa653a1e452a4ae69 100644 --- a/test/utils_test.py +++ b/test/utils_test.py @@ -1,9 +1,10 @@ from __future__ import absolute_import, division -import bilby import unittest import numpy as np -import matplotlib.pyplot as plt + +import bilby +from bilby.core import utils class TestFFT(unittest.TestCase): @@ -30,5 +31,34 @@ class TestFFT(unittest.TestCase): self.assertTrue(np.all(np.abs((tds - tds2) / tds) < 1e-12)) +class TestInferParameters(unittest.TestCase): + + def setUp(self): + def source_function(freqs, a, b, *args, **kwargs): + return None + + class TestClass: + def test_method(self, a, b, *args, **kwargs): + pass + + self.source1 = source_function + test_obj = TestClass() + self.source2 = test_obj.test_method + + def tearDown(self): + del self.source1 + del self.source2 + + def test_args_kwargs_handling(self): + expected = ['a', 'b'] + actual = utils.infer_parameters_from_function(self.source1) + self.assertListEqual(expected, actual) + + def test_self_handling(self): + expected = ['a', 'b'] + actual = utils.infer_args_from_method(self.source2) + self.assertListEqual(expected, actual) + + if __name__ == '__main__': unittest.main()