From 4d5d6640386e51882ca41b03b6596868150d27b4 Mon Sep 17 00:00:00 2001 From: Moritz Huebner <moritz.huebner@ligo.org> Date: Mon, 9 Dec 2019 18:13:13 -0600 Subject: [PATCH] Fix waveform generator caching issue --- bilby/gw/waveform_generator.py | 4 +++- test/waveform_generator_test.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index 80d2334c6..ace070353 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -152,7 +152,8 @@ class WaveformGenerator(object): transformed_model_data_points, parameters): if parameters is not None: self.parameters = parameters - if self.parameters == self._cache['parameters'] and self._cache['model'] == model: + if self.parameters == self._cache['parameters'] and self._cache['model'] == model and \ + self._cache['transformed_model'] == transformed_model: return self._cache['waveform'] if model is not None: model_strain = self._strain_from_model(model_data_points, model) @@ -164,6 +165,7 @@ class WaveformGenerator(object): self._cache['waveform'] = model_strain self._cache['parameters'] = self.parameters.copy() self._cache['model'] = model + self._cache['transformed_model'] = transformed_model return model_strain def _strain_from_model(self, model_data_points, model): diff --git a/test/waveform_generator_test.py b/test/waveform_generator_test.py index af45d2e63..2d55e27eb 100644 --- a/test/waveform_generator_test.py +++ b/test/waveform_generator_test.py @@ -16,6 +16,10 @@ def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, return ht +def dummy_func_array_return_value_2(array, amplitude, mu, sigma, ra, dec, geocent_time, psi): + return dict(plus=np.array(array), cross=np.array(array)) + + class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase): def setUp(self): @@ -302,6 +306,25 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase): parameters=self.simulation_parameters) self.assertNotEqual(original_waveform, new_waveform) + def test_frequency_domain_caching_changing_model(self): + original_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value_2 + new_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus'])) + + def test_time_domain_caching_changing_model(self): + self.waveform_generator = \ + bilby.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096, + time_domain_source_model=dummy_func_dict_return_value) + original_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.waveform_generator.time_domain_source_model = dummy_func_array_return_value_2 + new_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus'])) + class TestTimeDomainStrainMethod(unittest.TestCase): -- GitLab