diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py index 80d2334c63c4644171a87c32085834aabd6de60d..ace0703537afbe525e9abaee9fbd3cb997b7b0ce 100644 --- a/bilby/gw/waveform_generator.py +++ b/bilby/gw/waveform_generator.py @@ -152,7 +152,8 @@ class WaveformGenerator(object): transformed_model_data_points, parameters): if parameters is not None: self.parameters = parameters - if self.parameters == self._cache['parameters'] and self._cache['model'] == model: + if self.parameters == self._cache['parameters'] and self._cache['model'] == model and \ + self._cache['transformed_model'] == transformed_model: return self._cache['waveform'] if model is not None: model_strain = self._strain_from_model(model_data_points, model) @@ -164,6 +165,7 @@ class WaveformGenerator(object): self._cache['waveform'] = model_strain self._cache['parameters'] = self.parameters.copy() self._cache['model'] = model + self._cache['transformed_model'] = transformed_model return model_strain def _strain_from_model(self, model_data_points, model): diff --git a/test/waveform_generator_test.py b/test/waveform_generator_test.py index af45d2e631c943462491ad2f9fb6e08798a4c3cd..2d55e27eb4bce6033cb0413a1ac50f95d7aa76a8 100644 --- a/test/waveform_generator_test.py +++ b/test/waveform_generator_test.py @@ -16,6 +16,10 @@ def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, return ht +def dummy_func_array_return_value_2(array, amplitude, mu, sigma, ra, dec, geocent_time, psi): + return dict(plus=np.array(array), cross=np.array(array)) + + class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase): def setUp(self): @@ -302,6 +306,25 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase): parameters=self.simulation_parameters) self.assertNotEqual(original_waveform, new_waveform) + def test_frequency_domain_caching_changing_model(self): + original_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value_2 + new_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus'])) + + def test_time_domain_caching_changing_model(self): + self.waveform_generator = \ + bilby.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096, + time_domain_source_model=dummy_func_dict_return_value) + original_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.waveform_generator.time_domain_source_model = dummy_func_array_return_value_2 + new_waveform = self.waveform_generator.frequency_domain_strain( + parameters=self.simulation_parameters) + self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus'])) + class TestTimeDomainStrainMethod(unittest.TestCase):