From 09ea2cb3091312550b5723c215b9829fc9a8c023 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Thu, 10 Feb 2022 14:15:42 +0000 Subject: [PATCH] Allow user to provide `variable_names` in hyper Model --- bilby/hyper/model.py | 47 ++++++++++++++++++++++++++++++------- test/hyper/hyper_pe_test.py | 31 ++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py index e5f4cb7bf..359274079 100644 --- a/bilby/hyper/model.py +++ b/bilby/hyper/model.py @@ -2,10 +2,14 @@ from ..core.utils import infer_args_from_function_except_n_args class Model(object): - """ - Population model + r""" + Population model that combines a set of factorizable models. This should take population parameters and return the probability. + + .. math:: + + p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda) """ def __init__(self, model_functions=None): @@ -13,7 +17,11 @@ class Model(object): Parameters ========== model_functions: list - List of functions to compute. + List of callables to compute the probability. + If this includes classes, the `__call__` method should return the + probability. + The requires variables are chosen at run time based on either + inspection or querying a :code:`variable_names` attribute. """ self.models = model_functions self._cached_parameters = {model: None for model in self.models} @@ -22,6 +30,21 @@ class Model(object): self.parameters = dict() def prob(self, data, **kwargs): + """ + Compute the total population probability for the provided data given + the keyword arguments. + + Parameters + ========== + data: dict + Dictionary containing the points at which to evaluate the + population model. + kwargs: dict + The population parameters. These cannot include any of + :code:`["dataset", "data", "self", "cls"]` unless the + :code:`variable_names` attribute is available for the relevant + model. + """ probability = 1.0 for ii, function in enumerate(self.models): function_parameters = self._get_function_parameters(function) @@ -37,11 +60,17 @@ class Model(object): return probability def _get_function_parameters(self, func): - """If the function is a class method we need to remove more arguments""" - param_keys = infer_args_from_function_except_n_args(func, n=0) - ignore = ['dataset', 'self', 'cls'] - for key in ignore: - if key in param_keys: - del param_keys[param_keys.index(key)] + """ + If the function is a class method we need to remove more arguments or + have the variable names provided in the class. + """ + if hasattr(func, "variable_names"): + param_keys = func.variable_names + else: + param_keys = infer_args_from_function_except_n_args(func, n=0) + ignore = ["dataset", "data", "self", "cls"] + for key in ignore: + if key in param_keys: + del param_keys[param_keys.index(key)] parameters = {key: self.parameters[key] for key in param_keys} return parameters diff --git a/test/hyper/hyper_pe_test.py b/test/hyper/hyper_pe_test.py index 127738020..4ca58927d 100644 --- a/test/hyper/hyper_pe_test.py +++ b/test/hyper/hyper_pe_test.py @@ -1,9 +1,28 @@ import unittest import numpy as np import pandas as pd +from parameterized import parameterized + import bilby.hyper as hyp +def _toy_function(data, dataset, self, cls, a, b, c): + return a + + +class _ToyClassNoVariableNames: + def __call__(self, a, b, c): + return a + + +class _ToyClassVariableNames: + + variable_names = ["a", "b", "c"] + + def __call__(self, **kwargs): + return kwargs.get("a", 1) + + class TestHyperLikelihood(unittest.TestCase): def setUp(self): self.keys = ["a", "b", "c"] @@ -38,6 +57,18 @@ class TestHyperLikelihood(unittest.TestCase): ) self.assertTrue(np.isnan(like.evidence_factor)) + @parameterized.expand([ + ("func", _toy_function), + ("class_no_names", _ToyClassNoVariableNames()), + ("class_with_names", _ToyClassVariableNames()), + ]) + def test_get_function_parameters(self, _, model): + expected = dict(a=1, b=2, c=3) + model = hyp.model.Model([model]) + model.parameters.update(expected) + result = model._get_function_parameters(model.models[0]) + self.assertDictEqual(expected, result) + def test_len_samples_with_max_samples(self): like = hyp.likelihood.HyperparameterLikelihood( self.posteriors, -- GitLab