Skip to content
Snippets Groups Projects
Commit 6a2cf961 authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'extend-hyper-model' into 'master'

Allow user to provide `variable_names` in hyper Model

See merge request !1069
parents 7007cd22 09ea2cb3
No related branches found
No related tags found
1 merge request!1069Allow user to provide `variable_names` in hyper Model
Pipeline #355050 passed
......@@ -2,10 +2,14 @@ from ..core.utils import infer_args_from_function_except_n_args
class Model(object):
"""
Population model
r"""
Population model that combines a set of factorizable models.
This should take population parameters and return the probability.
.. math::
p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
"""
def __init__(self, model_functions=None):
......@@ -13,7 +17,11 @@ class Model(object):
Parameters
==========
model_functions: list
List of functions to compute.
List of callables to compute the probability.
If this includes classes, the `__call__` method should return the
probability.
The requires variables are chosen at run time based on either
inspection or querying a :code:`variable_names` attribute.
"""
self.models = model_functions
self._cached_parameters = {model: None for model in self.models}
......@@ -22,6 +30,21 @@ class Model(object):
self.parameters = dict()
def prob(self, data, **kwargs):
"""
Compute the total population probability for the provided data given
the keyword arguments.
Parameters
==========
data: dict
Dictionary containing the points at which to evaluate the
population model.
kwargs: dict
The population parameters. These cannot include any of
:code:`["dataset", "data", "self", "cls"]` unless the
:code:`variable_names` attribute is available for the relevant
model.
"""
probability = 1.0
for ii, function in enumerate(self.models):
function_parameters = self._get_function_parameters(function)
......@@ -37,11 +60,17 @@ class Model(object):
return probability
def _get_function_parameters(self, func):
"""If the function is a class method we need to remove more arguments"""
param_keys = infer_args_from_function_except_n_args(func, n=0)
ignore = ['dataset', 'self', 'cls']
for key in ignore:
if key in param_keys:
del param_keys[param_keys.index(key)]
"""
If the function is a class method we need to remove more arguments or
have the variable names provided in the class.
"""
if hasattr(func, "variable_names"):
param_keys = func.variable_names
else:
param_keys = infer_args_from_function_except_n_args(func, n=0)
ignore = ["dataset", "data", "self", "cls"]
for key in ignore:
if key in param_keys:
del param_keys[param_keys.index(key)]
parameters = {key: self.parameters[key] for key in param_keys}
return parameters
import unittest
import numpy as np
import pandas as pd
from parameterized import parameterized
import bilby.hyper as hyp
def _toy_function(data, dataset, self, cls, a, b, c):
return a
class _ToyClassNoVariableNames:
def __call__(self, a, b, c):
return a
class _ToyClassVariableNames:
variable_names = ["a", "b", "c"]
def __call__(self, **kwargs):
return kwargs.get("a", 1)
class TestHyperLikelihood(unittest.TestCase):
def setUp(self):
self.keys = ["a", "b", "c"]
......@@ -38,6 +57,18 @@ class TestHyperLikelihood(unittest.TestCase):
)
self.assertTrue(np.isnan(like.evidence_factor))
@parameterized.expand([
("func", _toy_function),
("class_no_names", _ToyClassNoVariableNames()),
("class_with_names", _ToyClassVariableNames()),
])
def test_get_function_parameters(self, _, model):
expected = dict(a=1, b=2, c=3)
model = hyp.model.Model([model])
model.parameters.update(expected)
result = model._get_function_parameters(model.models[0])
self.assertDictEqual(expected, result)
def test_len_samples_with_max_samples(self):
like = hyp.likelihood.HyperparameterLikelihood(
self.posteriors,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment