Skip to content
Snippets Groups Projects
Commit edf8d889 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'waveform-caching' into 'master'

add basic waveform caching

See merge request !427
parents 971deb31 b30462d0
No related branches found
No related tags found
1 merge request!427add basic waveform caching
Pipeline #56991 passed with warnings
......@@ -62,6 +62,7 @@ class WaveformGenerator(object):
self.waveform_arguments = dict()
if isinstance(parameters, dict):
self.parameters = parameters
self._cache = dict(parameters=None, waveform=None)
def __repr__(self):
if self.frequency_domain_source_model is not None:
......@@ -147,6 +148,8 @@ class WaveformGenerator(object):
transformed_model_data_points, parameters):
if parameters is not None:
self.parameters = parameters
if self.parameters == self._cache['parameters']:
return self._cache['waveform']
if model is not None:
model_strain = self._strain_from_model(model_data_points, model)
elif transformed_model is not None:
......@@ -154,6 +157,8 @@ class WaveformGenerator(object):
transformation_function)
else:
raise RuntimeError("No source model given")
self._cache['waveform'] = model_strain
self._cache['parameters'] = self.parameters.copy()
return model_strain
def _strain_from_model(self, model_data_points, model):
......
......@@ -263,6 +263,19 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
def test_caching_with_parameters(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.assertDictEqual(original_waveform, new_waveform)
def test_caching_without_parameters(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.frequency_domain_strain()
self.assertDictEqual(original_waveform, new_waveform)
class TestTimeDomainStrainMethod(unittest.TestCase):
......@@ -354,6 +367,20 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
def test_caching_with_parameters(self):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
self.assertDictEqual(original_waveform, new_waveform)
def test_caching_without_parameters(self):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.time_domain_strain()
self.assertDictEqual(original_waveform, new_waveform)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment