Skip to content
Snippets Groups Projects
Commit 4d5d6640 authored by Moritz Huebner's avatar Moritz Huebner Committed by Gregory Ashton
Browse files

Fix waveform generator caching issue

parent e3ba9810
No related branches found
No related tags found
No related merge requests found
......@@ -152,7 +152,8 @@ class WaveformGenerator(object):
transformed_model_data_points, parameters):
if parameters is not None:
self.parameters = parameters
if self.parameters == self._cache['parameters'] and self._cache['model'] == model:
if self.parameters == self._cache['parameters'] and self._cache['model'] == model and \
self._cache['transformed_model'] == transformed_model:
return self._cache['waveform']
if model is not None:
model_strain = self._strain_from_model(model_data_points, model)
......@@ -164,6 +165,7 @@ class WaveformGenerator(object):
self._cache['waveform'] = model_strain
self._cache['parameters'] = self.parameters.copy()
self._cache['model'] = model
self._cache['transformed_model'] = transformed_model
return model_strain
def _strain_from_model(self, model_data_points, model):
......
......@@ -16,6 +16,10 @@ def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec,
return ht
def dummy_func_array_return_value_2(array, amplitude, mu, sigma, ra, dec, geocent_time, psi):
return dict(plus=np.array(array), cross=np.array(array))
class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase):
def setUp(self):
......@@ -302,6 +306,25 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
parameters=self.simulation_parameters)
self.assertNotEqual(original_waveform, new_waveform)
def test_frequency_domain_caching_changing_model(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value_2
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus']))
def test_time_domain_caching_changing_model(self):
self.waveform_generator = \
bilby.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096,
time_domain_source_model=dummy_func_dict_return_value)
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.waveform_generator.time_domain_source_model = dummy_func_array_return_value_2
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus']))
class TestTimeDomainStrainMethod(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment