Skip to content
Snippets Groups Projects
Commit b30462d0 authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner
Browse files

add basic waveform caching

parent 971deb31
No related branches found
No related tags found
No related merge requests found
......@@ -62,6 +62,7 @@ class WaveformGenerator(object):
self.waveform_arguments = dict()
if isinstance(parameters, dict):
self.parameters = parameters
self._cache = dict(parameters=None, waveform=None)
def __repr__(self):
if self.frequency_domain_source_model is not None:
......@@ -147,6 +148,8 @@ class WaveformGenerator(object):
transformed_model_data_points, parameters):
if parameters is not None:
self.parameters = parameters
if self.parameters == self._cache['parameters']:
return self._cache['waveform']
if model is not None:
model_strain = self._strain_from_model(model_data_points, model)
elif transformed_model is not None:
......@@ -154,6 +157,8 @@ class WaveformGenerator(object):
transformation_function)
else:
raise RuntimeError("No source model given")
self._cache['waveform'] = model_strain
self._cache['parameters'] = self.parameters.copy()
return model_strain
def _strain_from_model(self, model_data_points, model):
......
......@@ -263,6 +263,19 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
def test_caching_with_parameters(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.assertDictEqual(original_waveform, new_waveform)
def test_caching_without_parameters(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.frequency_domain_strain()
self.assertDictEqual(original_waveform, new_waveform)
class TestTimeDomainStrainMethod(unittest.TestCase):
......@@ -354,6 +367,20 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
def test_caching_with_parameters(self):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
self.assertDictEqual(original_waveform, new_waveform)
def test_caching_without_parameters(self):
original_waveform = self.waveform_generator.time_domain_strain(
parameters=self.simulation_parameters)
new_waveform = self.waveform_generator.time_domain_strain()
self.assertDictEqual(original_waveform, new_waveform)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment