Skip to content
Snippets Groups Projects
Commit 41cb703c authored by Colm Talbot's avatar Colm Talbot
Browse files

FEAT: enable caching to be disabled in hyper.model.Model

parent e37c308a
No related branches found
No related tags found
No related merge requests found
from ..core.utils import infer_args_from_function_except_n_args
class Model(object):
class Model:
r"""
Population model that combines a set of factorizable models.
......@@ -12,18 +12,24 @@ class Model(object):
p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
"""
def __init__(self, model_functions=None):
def __init__(self, model_functions=None, cache=True):
"""
Parameters
==========
model_functions: list
List of callables to compute the probability.
If this includes classes, the `__call__` method should return the
probability.
If this includes classes, the :code:`__call__`: method
should return the probability.
The requires variables are chosen at run time based on either
inspection or querying a :code:`variable_names` attribute.
cache: bool
Whether to cache the value returned by the model functions,
default=:code:`True`. The caching only looks at the parameters
not the data, so should be used with caution. The caching also
breaks :code:`jax` JIT compilation.
"""
self.models = model_functions
self.cache = cache
self._cached_parameters = {model: None for model in self.models}
self._cached_probability = {model: None for model in self.models}
......@@ -48,14 +54,18 @@ class Model(object):
probability = 1.0
for ii, function in enumerate(self.models):
function_parameters = self._get_function_parameters(function)
if self._cached_parameters[function] == function_parameters:
if (
self.cache
and self._cached_parameters[function] == function_parameters
):
new_probability = self._cached_probability[function]
else:
new_probability = function(
data, **self._get_function_parameters(function)
)
self._cached_parameters[function] = function_parameters
self._cached_probability[function] = new_probability
if self.cache:
self._cached_parameters[function] = function_parameters
self._cached_probability[function] = new_probability
probability *= new_probability
return probability
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment