Commit 680b7d53 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Resolve "Replace inspect.getargspec references"

parent 5daa5b4c
# This script is an edited version of the example found at # This script is an edited version of the example found at
# https://git.ligo.org/lscsoft/example-ci-project/blob/python/.gitlab-ci.yml # https://git.ligo.org/lscsoft/example-ci-project/blob/python/.gitlab-ci.yml
# Each 0th-indendation level is a job that will be run within GitLab CI # Each 0th-indentation level is a job that will be run within GitLab CI
# The only exception are a short list of reserved keywords # The only exception are a short list of reserved keywords
# #
# https://docs.gitlab.com/ee/ci/yaml/#gitlab-ci-yml # https://docs.gitlab.com/ee/ci/yaml/#gitlab-ci-yml
...@@ -25,7 +25,7 @@ python-2: ...@@ -25,7 +25,7 @@ python-2:
script: script:
- python setup.py install - python setup.py install
# Run tests without finding coverage # Run tests without finding coverage
- pytest - pytest --ignore=test/utils_py3_test.py
# test example on python 3 # test example on python 3
python-3: python-3:
......
...@@ -9,10 +9,9 @@ import os ...@@ -9,10 +9,9 @@ import os
from collections import OrderedDict from collections import OrderedDict
from future.utils import iteritems from future.utils import iteritems
from .utils import logger from .utils import logger, infer_args_from_method
from . import utils from . import utils
import bilby # noqa import bilby # noqa
import inspect
class PriorSet(OrderedDict): class PriorSet(OrderedDict):
...@@ -426,8 +425,7 @@ class Prior(object): ...@@ -426,8 +425,7 @@ class Prior(object):
str: A string representation of this instance str: A string representation of this instance
""" """
subclass_args = inspect.getargspec(self.__init__).args subclass_args = infer_args_from_method(self.__init__)
subclass_args.pop(0)
prior_name = self.__class__.__name__ prior_name = self.__class__.__name__
property_names = [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)] property_names = [p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)]
......
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
from collections import OrderedDict from collections import OrderedDict
import inspect
import numpy as np import numpy as np
from ..utils import derivatives, logger from ..utils import derivatives, logger, infer_args_from_method
from ..prior import Prior from ..prior import Prior
from ..result import Result from ..result import Result
from .base_sampler import Sampler, MCMCSampler from .base_sampler import Sampler, MCMCSampler
...@@ -437,7 +436,7 @@ class Pymc3(MCMCSampler): ...@@ -437,7 +436,7 @@ class Pymc3(MCMCSampler):
# then use that log_likelihood function, with the assumption that it # then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines # takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager # the likelihood within that context manager
likeargs = inspect.getargspec(self.likelihood.log_likelihood).args likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs: if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self) self.likelihood.log_likelihood(sampler=self)
else: else:
...@@ -480,7 +479,7 @@ class Pymc3(MCMCSampler): ...@@ -480,7 +479,7 @@ class Pymc3(MCMCSampler):
for key in self.priors: for key in self.priors:
# if the prior contains ln_prob method that takes a 'sampler' argument # if the prior contains ln_prob method that takes a 'sampler' argument
# then try using that # then try using that
lnprobargs = inspect.getargspec(self.priors[key].ln_prob).args lnprobargs = infer_args_from_method(self.priors[key].ln_prob)
if 'sampler' in lnprobargs: if 'sampler' in lnprobargs:
try: try:
self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self) self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self)
...@@ -500,7 +499,7 @@ class Pymc3(MCMCSampler): ...@@ -500,7 +499,7 @@ class Pymc3(MCMCSampler):
if pymc3distname not in pymc3.__dict__: if pymc3distname not in pymc3.__dict__:
raise ValueError("Prior '{}' is not a known PyMC3 distribution.".format(pymc3distname)) raise ValueError("Prior '{}' is not a known PyMC3 distribution.".format(pymc3distname))
reqargs = inspect.getargspec(pymc3.__dict__[pymc3distname].__init__).args[1:] reqargs = infer_args_from_method(pymc3.__dict__[pymc3distname].__init__)
# set keyword arguments # set keyword arguments
priorkwargs = {} priorkwargs = {}
......
...@@ -18,10 +18,39 @@ radius_of_earth = 6371 * 1e3 # metres ...@@ -18,10 +18,39 @@ radius_of_earth = 6371 * 1e3 # metres
def infer_parameters_from_function(func): def infer_parameters_from_function(func):
""" Infers the arguments of function (except the first arg which is """ Infers the arguments of a function
assumed to be the dep. variable) (except the first arg which is assumed to be the dep. variable).
Throws out *args and **kwargs type arguments
Can deal with type hinting!
Returns
---------
list: A list of strings with the parameters
""" """
parameters = inspect.getargspec(func).args return _infer_args_from_function_except_for_first_arg(func=func)
def infer_args_from_method(method):
""" Infers all arguments of a method except for 'self'
Throws out *args and **kwargs type arguments.
Can deal with type hinting!
Returns
---------
list: A list of strings with the parameters
"""
return _infer_args_from_function_except_for_first_arg(func=method)
def _infer_args_from_function_except_for_first_arg(func):
try:
parameters = inspect.getfullargspec(func).args
except AttributeError:
parameters = inspect.getargspec(func).args
parameters.pop(0) parameters.pop(0)
return parameters return parameters
......
import inspect from bilby.core.utils import infer_parameters_from_function
class Model(object): class Model(object):
...@@ -18,8 +18,9 @@ class Model(object): ...@@ -18,8 +18,9 @@ class Model(object):
self.models = model_functions self.models = model_functions
self.parameters = dict() self.parameters = dict()
for function in self.models: for func in self.models:
for key in inspect.getargspec(function).args[1:]: param_keys = infer_parameters_from_function(func)
for key in param_keys:
self.parameters[key] = None self.parameters[key] = None
def prob(self, data): def prob(self, data):
...@@ -30,6 +31,7 @@ class Model(object): ...@@ -30,6 +31,7 @@ class Model(object):
probability *= function(data, **self._get_function_parameters(function)) probability *= function(data, **self._get_function_parameters(function))
return probability return probability
def _get_function_parameters(self, function): def _get_function_parameters(self, func):
parameters = {key: self.parameters[key] for key in inspect.getargspec(function).args[1:]} param_keys = infer_parameters_from_function(func)
parameters = {key: self.parameters[key] for key in param_keys}
return parameters return parameters
from __future__ import absolute_import, division
import unittest
from bilby.core import utils
class TestInferParameters(unittest.TestCase):
def setUp(self):
def source_function1(freqs, a, b: int):
return None
def source_function2(freqs, a, b, *args, **kwargs):
return None
def source_function3(freqs, a, b: int, *args, **kwargs):
return None
class TestClass:
def test_method(self, a, b: int, *args, **kwargs):
pass
self.source1 = source_function1
self.source2 = source_function2
self.source3 = source_function3
test_obj = TestClass()
self.source4 = test_obj.test_method
def tearDown(self):
del self.source1
del self.source2
del self.source3
del self.source4
def test_type_hinting(self):
expected = ['a', 'b']
actual = utils.infer_parameters_from_function(self.source1)
self.assertListEqual(expected, actual)
def test_args_kwargs_handling(self):
expected = ['a', 'b']
actual = utils.infer_parameters_from_function(self.source2)
self.assertListEqual(expected, actual)
def test_both(self):
expected = ['a', 'b']
actual = utils.infer_parameters_from_function(self.source3)
self.assertListEqual(expected, actual)
def test_self_handling(self):
expected = ['a', 'b']
actual = utils.infer_args_from_method(self.source4)
self.assertListEqual(expected, actual)
from __future__ import absolute_import, division from __future__ import absolute_import, division
import bilby
import unittest import unittest
import numpy as np import numpy as np
import matplotlib.pyplot as plt
import bilby
from bilby.core import utils
class TestFFT(unittest.TestCase): class TestFFT(unittest.TestCase):
...@@ -30,5 +31,34 @@ class TestFFT(unittest.TestCase): ...@@ -30,5 +31,34 @@ class TestFFT(unittest.TestCase):
self.assertTrue(np.all(np.abs((tds - tds2) / tds) < 1e-12)) self.assertTrue(np.all(np.abs((tds - tds2) / tds) < 1e-12))
class TestInferParameters(unittest.TestCase):
def setUp(self):
def source_function(freqs, a, b, *args, **kwargs):
return None
class TestClass:
def test_method(self, a, b, *args, **kwargs):
pass
self.source1 = source_function
test_obj = TestClass()
self.source2 = test_obj.test_method
def tearDown(self):
del self.source1
del self.source2
def test_args_kwargs_handling(self):
expected = ['a', 'b']
actual = utils.infer_parameters_from_function(self.source1)
self.assertListEqual(expected, actual)
def test_self_handling(self):
expected = ['a', 'b']
actual = utils.infer_args_from_method(self.source2)
self.assertListEqual(expected, actual)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment