From 4d5d6640386e51882ca41b03b6596868150d27b4 Mon Sep 17 00:00:00 2001
From: Moritz Huebner <moritz.huebner@ligo.org>
Date: Mon, 9 Dec 2019 18:13:13 -0600
Subject: [PATCH] Fix waveform generator caching issue

---
 bilby/gw/waveform_generator.py  |  4 +++-
 test/waveform_generator_test.py | 23 +++++++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py
index 80d2334c6..ace070353 100644
--- a/bilby/gw/waveform_generator.py
+++ b/bilby/gw/waveform_generator.py
@@ -152,7 +152,8 @@ class WaveformGenerator(object):
                           transformed_model_data_points, parameters):
         if parameters is not None:
             self.parameters = parameters
-        if self.parameters == self._cache['parameters'] and self._cache['model'] == model:
+        if self.parameters == self._cache['parameters'] and self._cache['model'] == model and \
+                self._cache['transformed_model'] == transformed_model:
             return self._cache['waveform']
         if model is not None:
             model_strain = self._strain_from_model(model_data_points, model)
@@ -164,6 +165,7 @@ class WaveformGenerator(object):
         self._cache['waveform'] = model_strain
         self._cache['parameters'] = self.parameters.copy()
         self._cache['model'] = model
+        self._cache['transformed_model'] = transformed_model
         return model_strain
 
     def _strain_from_model(self, model_data_points, model):
diff --git a/test/waveform_generator_test.py b/test/waveform_generator_test.py
index af45d2e63..2d55e27eb 100644
--- a/test/waveform_generator_test.py
+++ b/test/waveform_generator_test.py
@@ -16,6 +16,10 @@ def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec,
     return ht
 
 
+def dummy_func_array_return_value_2(array, amplitude, mu, sigma, ra, dec, geocent_time, psi):
+    return dict(plus=np.array(array), cross=np.array(array))
+
+
 class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase):
 
     def setUp(self):
@@ -302,6 +306,25 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
             parameters=self.simulation_parameters)
         self.assertNotEqual(original_waveform, new_waveform)
 
+    def test_frequency_domain_caching_changing_model(self):
+        original_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value_2
+        new_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus']))
+
+    def test_time_domain_caching_changing_model(self):
+        self.waveform_generator = \
+            bilby.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096,
+                                                          time_domain_source_model=dummy_func_dict_return_value)
+        original_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        self.waveform_generator.time_domain_source_model = dummy_func_array_return_value_2
+        new_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus']))
+
 
 class TestTimeDomainStrainMethod(unittest.TestCase):
 
-- 
GitLab