From b30462d0ddf03ad772fac9d14f6c6e44f849319a Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Tue, 9 Apr 2019 22:11:05 -0500
Subject: [PATCH] add basic waveform caching

---
 bilby/gw/waveform_generator.py  |  5 +++++
 test/waveform_generator_test.py | 27 +++++++++++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py
index 00e24bde6..98d069146 100644
--- a/bilby/gw/waveform_generator.py
+++ b/bilby/gw/waveform_generator.py
@@ -62,6 +62,7 @@ class WaveformGenerator(object):
             self.waveform_arguments = dict()
         if isinstance(parameters, dict):
             self.parameters = parameters
+        self._cache = dict(parameters=None, waveform=None)
 
     def __repr__(self):
         if self.frequency_domain_source_model is not None:
@@ -147,6 +148,8 @@ class WaveformGenerator(object):
                           transformed_model_data_points, parameters):
         if parameters is not None:
             self.parameters = parameters
+        if self.parameters == self._cache['parameters']:
+            return self._cache['waveform']
         if model is not None:
             model_strain = self._strain_from_model(model_data_points, model)
         elif transformed_model is not None:
@@ -154,6 +157,8 @@ class WaveformGenerator(object):
                                                                transformation_function)
         else:
             raise RuntimeError("No source model given")
+        self._cache['waveform'] = model_strain
+        self._cache['parameters'] = self.parameters.copy()
         return model_strain
 
     def _strain_from_model(self, model_data_points, model):
diff --git a/test/waveform_generator_test.py b/test/waveform_generator_test.py
index 8eba859ac..15aa65e46 100644
--- a/test/waveform_generator_test.py
+++ b/test/waveform_generator_test.py
@@ -263,6 +263,19 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
         self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
                              sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
 
+    def test_caching_with_parameters(self):
+        original_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        new_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        self.assertDictEqual(original_waveform, new_waveform)
+
+    def test_caching_without_parameters(self):
+        original_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        new_waveform = self.waveform_generator.frequency_domain_strain()
+        self.assertDictEqual(original_waveform, new_waveform)
+
 
 class TestTimeDomainStrainMethod(unittest.TestCase):
 
@@ -354,6 +367,20 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
         self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
                              sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
 
+    def test_caching_with_parameters(self):
+        original_waveform = self.waveform_generator.time_domain_strain(
+            parameters=self.simulation_parameters)
+        new_waveform = self.waveform_generator.time_domain_strain(
+            parameters=self.simulation_parameters)
+        self.assertDictEqual(original_waveform, new_waveform)
+
+    def test_caching_without_parameters(self):
+        original_waveform = self.waveform_generator.time_domain_strain(
+            parameters=self.simulation_parameters)
+        new_waveform = self.waveform_generator.time_domain_strain()
+        self.assertDictEqual(original_waveform, new_waveform)
+
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab