diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py index 9b6e45b53251daccffc3bc5845d51718e1317ba5..925525e68607fd2aaeb65377ae5d84464879e4d3 100644 --- a/bilby/hyper/model.py +++ b/bilby/hyper/model.py @@ -16,13 +16,24 @@ class Model(object): List of functions to compute. """ self.models = model_functions + self._cached_parameters = {model: None for model in self.models} + self._cached_probability = {model: None for model in self.models} self.parameters = dict() def prob(self, data, **kwargs): probability = 1.0 for ii, function in enumerate(self.models): - probability *= function(data, **self._get_function_parameters(function)) + function_parameters = self._get_function_parameters(function) + if self._cached_parameters[function] == function_parameters: + new_probability = self._cached_probability[function] + else: + new_probability = function( + data, **self._get_function_parameters(function) + ) + self._cached_parameters[function] = function_parameters + self._cached_probability[function] = new_probability + probability *= new_probability return probability def _get_function_parameters(self, func):