diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py
index e5f4cb7bf4b1da6a807be9e02a173af94f73c171..359274079241b866e2311160306e3b1a2f41da5c 100644
--- a/bilby/hyper/model.py
+++ b/bilby/hyper/model.py
@@ -2,10 +2,14 @@ from ..core.utils import infer_args_from_function_except_n_args
 
 
 class Model(object):
-    """
-    Population model
+    r"""
+    Population model that combines a set of factorizable models.
 
     This should take population parameters and return the probability.
+
+    .. math::
+
+        p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
     """
 
     def __init__(self, model_functions=None):
@@ -13,7 +17,11 @@ class Model(object):
         Parameters
         ==========
         model_functions: list
-            List of functions to compute.
+            List of callables to compute the probability.
+            If this includes classes, the `__call__` method should return the
+            probability.
+            The requires variables are chosen at run time based on either
+            inspection or querying a :code:`variable_names` attribute.
         """
         self.models = model_functions
         self._cached_parameters = {model: None for model in self.models}
@@ -22,6 +30,21 @@ class Model(object):
         self.parameters = dict()
 
     def prob(self, data, **kwargs):
+        """
+        Compute the total population probability for the provided data given
+        the keyword arguments.
+
+        Parameters
+        ==========
+        data: dict
+            Dictionary containing the points at which to evaluate the
+            population model.
+        kwargs: dict
+            The population parameters. These cannot include any of
+            :code:`["dataset", "data", "self", "cls"]` unless the
+            :code:`variable_names` attribute is available for the relevant
+            model.
+        """
         probability = 1.0
         for ii, function in enumerate(self.models):
             function_parameters = self._get_function_parameters(function)
@@ -37,11 +60,17 @@ class Model(object):
         return probability
 
     def _get_function_parameters(self, func):
-        """If the function is a class method we need to remove more arguments"""
-        param_keys = infer_args_from_function_except_n_args(func, n=0)
-        ignore = ['dataset', 'self', 'cls']
-        for key in ignore:
-            if key in param_keys:
-                del param_keys[param_keys.index(key)]
+        """
+        If the function is a class method we need to remove more arguments or
+        have the variable names provided in the class.
+        """
+        if hasattr(func, "variable_names"):
+            param_keys = func.variable_names
+        else:
+            param_keys = infer_args_from_function_except_n_args(func, n=0)
+            ignore = ["dataset", "data", "self", "cls"]
+            for key in ignore:
+                if key in param_keys:
+                    del param_keys[param_keys.index(key)]
         parameters = {key: self.parameters[key] for key in param_keys}
         return parameters
diff --git a/test/hyper/hyper_pe_test.py b/test/hyper/hyper_pe_test.py
index 12773802008d4d64dafed5fc647953472574d50b..4ca58927d1cb0bfceb3ee06fa4db5f8c401bf2eb 100644
--- a/test/hyper/hyper_pe_test.py
+++ b/test/hyper/hyper_pe_test.py
@@ -1,9 +1,28 @@
 import unittest
 import numpy as np
 import pandas as pd
+from parameterized import parameterized
+
 import bilby.hyper as hyp
 
 
+def _toy_function(data, dataset, self, cls, a, b, c):
+    return a
+
+
+class _ToyClassNoVariableNames:
+    def __call__(self, a, b, c):
+        return a
+
+
+class _ToyClassVariableNames:
+
+    variable_names = ["a", "b", "c"]
+
+    def __call__(self, **kwargs):
+        return kwargs.get("a", 1)
+
+
 class TestHyperLikelihood(unittest.TestCase):
     def setUp(self):
         self.keys = ["a", "b", "c"]
@@ -38,6 +57,18 @@ class TestHyperLikelihood(unittest.TestCase):
         )
         self.assertTrue(np.isnan(like.evidence_factor))
 
+    @parameterized.expand([
+        ("func", _toy_function),
+        ("class_no_names", _ToyClassNoVariableNames()),
+        ("class_with_names", _ToyClassVariableNames()),
+    ])
+    def test_get_function_parameters(self, _, model):
+        expected = dict(a=1, b=2, c=3)
+        model = hyp.model.Model([model])
+        model.parameters.update(expected)
+        result = model._get_function_parameters(model.models[0])
+        self.assertDictEqual(expected, result)
+
     def test_len_samples_with_max_samples(self):
         like = hyp.likelihood.HyperparameterLikelihood(
             self.posteriors,