From 0b57cbe8938c40edd1ef4711d0b03e1c7265f7bf Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Fri, 10 Nov 2023 15:34:47 +0000 Subject: [PATCH] TEST: fix numpy array testing --- test/gw/waveform_generator_test.py | 105 ++++++++++++++++------------- 1 file changed, 57 insertions(+), 48 deletions(-) diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py index c4bd5729f..ce809140d 100644 --- a/test/gw/waveform_generator_test.py +++ b/test/gw/waveform_generator_test.py @@ -438,42 +438,42 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase): def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters( self, ): - original_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.time_domain_strain() - self.assertNotEqual(original_waveform, new_waveform) + self.assertFalse(_test_caching_different_domain( + self.waveform_generator.frequency_domain_strain, + self.waveform_generator.time_domain_strain, + self.simulation_parameters, + None, + )) def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters( self, ): - original_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - self.assertNotEqual(original_waveform, new_waveform) + self.assertFalse(_test_caching_different_domain( + self.waveform_generator.frequency_domain_strain, + self.waveform_generator.time_domain_strain, + self.simulation_parameters, + self.simulation_parameters, + )) def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters( self, ): - original_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.frequency_domain_strain() - self.assertNotEqual(original_waveform, new_waveform) + self.assertFalse(_test_caching_different_domain( + self.waveform_generator.time_domain_strain, + self.waveform_generator.frequency_domain_strain, + self.simulation_parameters, + None, + )) def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters( self, ): - original_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - self.assertNotEqual(original_waveform, new_waveform) + self.assertFalse(_test_caching_different_domain( + self.waveform_generator.time_domain_strain, + self.waveform_generator.frequency_domain_strain, + self.simulation_parameters, + self.simulation_parameters, + )) def test_frequency_domain_caching_changing_model(self): original_waveform = self.waveform_generator.frequency_domain_strain( @@ -648,42 +648,51 @@ class TestTimeDomainStrainMethod(unittest.TestCase): def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters( self, ): - original_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.time_domain_strain() - self.assertNotEqual(original_waveform, new_waveform) + self.assertFalse(_test_caching_different_domain( + self.waveform_generator.frequency_domain_strain, + self.waveform_generator.time_domain_strain, + self.simulation_parameters, + None, + )) def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters( self, ): - original_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - self.assertNotEqual(original_waveform, new_waveform) + self.assertFalse(_test_caching_different_domain( + self.waveform_generator.frequency_domain_strain, + self.waveform_generator.time_domain_strain, + self.simulation_parameters, + self.simulation_parameters, + )) def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters( self, ): - original_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.frequency_domain_strain() - self.assertNotEqual(original_waveform, new_waveform) + self.assertFalse(_test_caching_different_domain( + self.waveform_generator.time_domain_strain, + self.waveform_generator.frequency_domain_strain, + self.simulation_parameters, + None, + )) def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters( self, ): - original_waveform = self.waveform_generator.time_domain_strain( - parameters=self.simulation_parameters - ) - new_waveform = self.waveform_generator.frequency_domain_strain( - parameters=self.simulation_parameters - ) - self.assertNotEqual(original_waveform, new_waveform) + self.assertFalse(_test_caching_different_domain( + self.waveform_generator.time_domain_strain, + self.waveform_generator.frequency_domain_strain, + self.simulation_parameters, + self.simulation_parameters, + )) + + +def _test_caching_different_domain(func1, func2, params1, params2): + original_waveform = func1(parameters=params1) + new_waveform = func2(parameters=params2) + output = True + for key in original_waveform: + output &= np.array_equal(original_waveform[key], new_waveform[key]) + return output if __name__ == "__main__": -- GitLab