From 41cb703c5a0c283843095ab94750c5f9ee5ba491 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 8 Aug 2024 12:40:53 +0000
Subject: [PATCH] FEAT: enable caching to be disabled in hyper.model.Model

---
 bilby/hyper/model.py | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py
index 359274079..3e84fbb6b 100644
--- a/bilby/hyper/model.py
+++ b/bilby/hyper/model.py
@@ -1,7 +1,7 @@
 from ..core.utils import infer_args_from_function_except_n_args
 
 
-class Model(object):
+class Model:
     r"""
     Population model that combines a set of factorizable models.
 
@@ -12,18 +12,24 @@ class Model(object):
         p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
     """
 
-    def __init__(self, model_functions=None):
+    def __init__(self, model_functions=None, cache=True):
         """
         Parameters
         ==========
         model_functions: list
             List of callables to compute the probability.
-            If this includes classes, the `__call__` method should return the
-            probability.
+            If this includes classes, the :code:`__call__`: method
+            should return the probability.
             The requires variables are chosen at run time based on either
             inspection or querying a :code:`variable_names` attribute.
+        cache: bool
+            Whether to cache the value returned by the model functions,
+            default=:code:`True`. The caching only looks at the parameters
+            not the data, so should be used with caution. The caching also
+            breaks :code:`jax` JIT compilation.
         """
         self.models = model_functions
+        self.cache = cache
         self._cached_parameters = {model: None for model in self.models}
         self._cached_probability = {model: None for model in self.models}
 
@@ -48,14 +54,18 @@ class Model(object):
         probability = 1.0
         for ii, function in enumerate(self.models):
             function_parameters = self._get_function_parameters(function)
-            if self._cached_parameters[function] == function_parameters:
+            if (
+                self.cache
+                and self._cached_parameters[function] == function_parameters
+            ):
                 new_probability = self._cached_probability[function]
             else:
                 new_probability = function(
                     data, **self._get_function_parameters(function)
                 )
-                self._cached_parameters[function] = function_parameters
-                self._cached_probability[function] = new_probability
+                if self.cache:
+                    self._cached_parameters[function] = function_parameters
+                    self._cached_probability[function] = new_probability
             probability *= new_probability
         return probability
 
-- 
GitLab