diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index 00e24bde621c5b989ba3daca27d4f4ebf4ed2267..98d06914695af20473745e88d3916a0367832e2b 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -62,6 +62,7 @@ class WaveformGenerator(object): self.waveform_arguments = dict() if isinstance(parameters, dict): self.parameters = parameters + self._cache = dict(parameters=None, waveform=None) def __repr__(self): if self.frequency_domain_source_model is not None: @@ -147,6 +148,8 @@ class WaveformGenerator(object): transformed_model_data_points, parameters): if parameters is not None: self.parameters = parameters + if self.parameters == self._cache['parameters']: + return self._cache['waveform'] if model is not None: model_strain = self._strain_from_model(model_data_points, model) elif transformed_model is not None: @@ -154,6 +157,8 @@ class WaveformGenerator(object): transformation_function) else: raise RuntimeError("No source model given") + self._cache['waveform'] = model_strain + self._cache['parameters'] = self.parameters.copy() return model_strain def _strain_from_model(self, model_data_points, model): diff --git a/test/waveform_generator_test.py b/test/waveform_generator_test.py index 8eba859ac03f0a1d73ade2a74f0785aff1162953..15aa65e467f2ecccb379effab951417904bed6b7 100644 --- a/test/waveform_generator_test.py +++ b/test/waveform_generator_test.py @@ -263,6 +263,19 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase): self.assertListEqual(sorted(self.waveform_generator.parameters.keys()), sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi'])) + def test_caching_with_parameters(self): + original_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + new_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.assertDictEqual(original_waveform, new_waveform) + + def test_caching_without_parameters(self): + original_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + new_waveform = self.waveform_generator.frequency_domain_strain() + self.assertDictEqual(original_waveform, new_waveform) + class TestTimeDomainStrainMethod(unittest.TestCase): @@ -354,6 +367,20 @@ class TestTimeDomainStrainMethod(unittest.TestCase): self.assertListEqual(sorted(self.waveform_generator.parameters.keys()), sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi'])) + def test_caching_with_parameters(self): + original_waveform = self.waveform_generator.time_domain_strain( + parameters=self.simulation_parameters) + new_waveform = self.waveform_generator.time_domain_strain( + parameters=self.simulation_parameters) + self.assertDictEqual(original_waveform, new_waveform) + + def test_caching_without_parameters(self): + original_waveform = self.waveform_generator.time_domain_strain( + parameters=self.simulation_parameters) + new_waveform = self.waveform_generator.time_domain_strain() + self.assertDictEqual(original_waveform, new_waveform) + + if __name__ == '__main__': unittest.main()