Skip to content
Snippets Groups Projects
Commit 8ea2ec37 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'fix_waveform_generator_caching_issue' into 'master'

Fix waveform generator caching issue

See merge request !630
parents e3ba9810 4d5d6640
No related branches found
No related tags found
1 merge request!630Fix waveform generator caching issue
Pipeline #92707 passed with warnings
......@@ -152,7 +152,8 @@ class WaveformGenerator(object):
transformed_model_data_points, parameters):
if parameters is not None:
self.parameters = parameters
if self.parameters == self._cache['parameters'] and self._cache['model'] == model:
if self.parameters == self._cache['parameters'] and self._cache['model'] == model and \
self._cache['transformed_model'] == transformed_model:
return self._cache['waveform']
if model is not None:
model_strain = self._strain_from_model(model_data_points, model)
......@@ -164,6 +165,7 @@ class WaveformGenerator(object):
self._cache['waveform'] = model_strain
self._cache['parameters'] = self.parameters.copy()
self._cache['model'] = model
self._cache['transformed_model'] = transformed_model
return model_strain
def _strain_from_model(self, model_data_points, model):
......
......@@ -16,6 +16,10 @@ def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec,
return ht
def dummy_func_array_return_value_2(array, amplitude, mu, sigma, ra, dec, geocent_time, psi):
return dict(plus=np.array(array), cross=np.array(array))
class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase):
def setUp(self):
......@@ -302,6 +306,25 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
parameters=self.simulation_parameters)
self.assertNotEqual(original_waveform, new_waveform)
def test_frequency_domain_caching_changing_model(self):
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value_2
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus']))
def test_time_domain_caching_changing_model(self):
self.waveform_generator = \
bilby.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096,
time_domain_source_model=dummy_func_dict_return_value)
original_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.waveform_generator.time_domain_source_model = dummy_func_array_return_value_2
new_waveform = self.waveform_generator.frequency_domain_strain(
parameters=self.simulation_parameters)
self.assertFalse(np.array_equal(original_waveform['plus'], new_waveform['plus']))
class TestTimeDomainStrainMethod(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment