From 0b57cbe8938c40edd1ef4711d0b03e1c7265f7bf Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Fri, 10 Nov 2023 15:34:47 +0000
Subject: [PATCH] TEST: fix numpy array testing

---
 test/gw/waveform_generator_test.py | 105 ++++++++++++++++-------------
 1 file changed, 57 insertions(+), 48 deletions(-)

diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py
index c4bd5729f..ce809140d 100644
--- a/test/gw/waveform_generator_test.py
+++ b/test/gw/waveform_generator_test.py
@@ -438,42 +438,42 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
     def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters(
         self,
     ):
-        original_waveform = self.waveform_generator.frequency_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        new_waveform = self.waveform_generator.time_domain_strain()
-        self.assertNotEqual(original_waveform, new_waveform)
+        self.assertFalse(_test_caching_different_domain(
+            self.waveform_generator.frequency_domain_strain,
+            self.waveform_generator.time_domain_strain,
+            self.simulation_parameters,
+            None,
+        ))
 
     def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters(
         self,
     ):
-        original_waveform = self.waveform_generator.frequency_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        new_waveform = self.waveform_generator.time_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        self.assertNotEqual(original_waveform, new_waveform)
+        self.assertFalse(_test_caching_different_domain(
+            self.waveform_generator.frequency_domain_strain,
+            self.waveform_generator.time_domain_strain,
+            self.simulation_parameters,
+            self.simulation_parameters,
+        ))
 
     def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters(
         self,
     ):
-        original_waveform = self.waveform_generator.time_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        new_waveform = self.waveform_generator.frequency_domain_strain()
-        self.assertNotEqual(original_waveform, new_waveform)
+        self.assertFalse(_test_caching_different_domain(
+            self.waveform_generator.time_domain_strain,
+            self.waveform_generator.frequency_domain_strain,
+            self.simulation_parameters,
+            None,
+        ))
 
     def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters(
         self,
     ):
-        original_waveform = self.waveform_generator.time_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        new_waveform = self.waveform_generator.frequency_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        self.assertNotEqual(original_waveform, new_waveform)
+        self.assertFalse(_test_caching_different_domain(
+            self.waveform_generator.time_domain_strain,
+            self.waveform_generator.frequency_domain_strain,
+            self.simulation_parameters,
+            self.simulation_parameters,
+        ))
 
     def test_frequency_domain_caching_changing_model(self):
         original_waveform = self.waveform_generator.frequency_domain_strain(
@@ -648,42 +648,51 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
     def test_frequency_domain_caching_and_using_time_domain_strain_without_parameters(
         self,
     ):
-        original_waveform = self.waveform_generator.frequency_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        new_waveform = self.waveform_generator.time_domain_strain()
-        self.assertNotEqual(original_waveform, new_waveform)
+        self.assertFalse(_test_caching_different_domain(
+            self.waveform_generator.frequency_domain_strain,
+            self.waveform_generator.time_domain_strain,
+            self.simulation_parameters,
+            None,
+        ))
 
     def test_frequency_domain_caching_and_using_time_domain_strain_with_parameters(
         self,
     ):
-        original_waveform = self.waveform_generator.frequency_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        new_waveform = self.waveform_generator.time_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        self.assertNotEqual(original_waveform, new_waveform)
+        self.assertFalse(_test_caching_different_domain(
+            self.waveform_generator.frequency_domain_strain,
+            self.waveform_generator.time_domain_strain,
+            self.simulation_parameters,
+            self.simulation_parameters,
+        ))
 
     def test_time_domain_caching_and_using_frequency_domain_strain_without_parameters(
         self,
     ):
-        original_waveform = self.waveform_generator.time_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        new_waveform = self.waveform_generator.frequency_domain_strain()
-        self.assertNotEqual(original_waveform, new_waveform)
+        self.assertFalse(_test_caching_different_domain(
+            self.waveform_generator.time_domain_strain,
+            self.waveform_generator.frequency_domain_strain,
+            self.simulation_parameters,
+            None,
+        ))
 
     def test_time_domain_caching_and_using_frequency_domain_strain_with_parameters(
         self,
     ):
-        original_waveform = self.waveform_generator.time_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        new_waveform = self.waveform_generator.frequency_domain_strain(
-            parameters=self.simulation_parameters
-        )
-        self.assertNotEqual(original_waveform, new_waveform)
+        self.assertFalse(_test_caching_different_domain(
+            self.waveform_generator.time_domain_strain,
+            self.waveform_generator.frequency_domain_strain,
+            self.simulation_parameters,
+            self.simulation_parameters,
+        ))
+
+
+def _test_caching_different_domain(func1, func2, params1, params2):
+    original_waveform = func1(parameters=params1)
+    new_waveform = func2(parameters=params2)
+    output = True
+    for key in original_waveform:
+        output &= np.array_equal(original_waveform[key], new_waveform[key])
+    return output
 
 
 if __name__ == "__main__":
-- 
GitLab