Skip to content
Snippets Groups Projects
Commit 90e75e1c authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'series_cache_fix' into 'master'

Fixes a bug with the cache when calling frequency_domain_strain and time domain strain after each other

See merge request !466
parents aae0d15b 35110a99
No related branches found
No related tags found
1 merge request!466Fixes a bug with the cache when calling frequency_domain_strain and time domain strain after each other
Pipeline #60754 passed
......@@ -66,7 +66,7 @@ class WaveformGenerator(object):
self.waveform_arguments = dict()
if isinstance(parameters, dict):
self.parameters = parameters
self._cache = dict(parameters=None, waveform=None)
self._cache = dict(parameters=None, waveform=None, model=None)
def __repr__(self):
if self.frequency_domain_source_model is not None:
......@@ -152,7 +152,7 @@ class WaveformGenerator(object):
transformed_model_data_points, parameters):
if parameters is not None:
self.parameters = parameters
if self.parameters == self._cache['parameters']:
if self.parameters == self._cache['parameters'] and self._cache['model'] == model:
return self._cache['waveform']
if model is not None:
model_strain = self._strain_from_model(model_data_points, model)
......@@ -163,6 +163,7 @@ class WaveformGenerator(object):
raise RuntimeError("No source model given")
self._cache['waveform'] = model_strain
self._cache['parameters'] = self.parameters.copy()
self._cache['model'] = model
return model_strain
def _strain_from_model(self, model_data_points, model):
......
......@@ -276,6 +276,32 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
new_waveform = self.waveform_generator.frequency_domain_strain()
self.assertDictEqual(original_waveform, new_waveform)
def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.time_domain_strain()
self.assertNotEqual(original_waveform, new_waveform)
def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
self.assertNotEqual(original_waveform, new_waveform)
def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters(self):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.frequency_domain_strain()
self.assertNotEqual(original_waveform, new_waveform)
def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters(self):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.assertNotEqual(original_waveform, new_waveform)
class TestTimeDomainStrainMethod(unittest.TestCase):
......@@ -380,6 +406,31 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
new_waveform = self.waveform_generator.time_domain_strain()
self.assertDictEqual(original_waveform, new_waveform)
def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.time_domain_strain()
self.assertNotEqual(original_waveform, new_waveform)
def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
self.assertNotEqual(original_waveform, new_waveform)
def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters(self):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.frequency_domain_strain()
self.assertNotEqual(original_waveform, new_waveform)
def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters(self):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.assertNotEqual(original_waveform, new_waveform)
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment