diff --git a/test/waveform_generator_tests.py b/test/waveform_generator_tests.py index 16c13c5dbb0113b38fa6f9df34f2ae05d10d84e5..bc4a36ca85d374d18a00eb289ed058b477286c0c 100644 --- a/test/waveform_generator_tests.py +++ b/test/waveform_generator_tests.py @@ -12,15 +12,25 @@ def gaussian_frequency_domain_strain(frequency_array, amplitude, mu, sigma, ra, return ht +def dummy_function_with_array_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): + return amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi + + +def dummy_function_with_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): + ht = {'plus': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi, + 'cross': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi} + return ht + + def gaussian_frequency_domain_strain_2(frequency_array, a, m, s, ra, dec, geocent_time, psi, **kwargs): ht = {'plus': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2), 'cross': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2)} return ht -def gaussian_time_domain_strain_2(time_array, a, m, s, ra, dec, geocent_time, psi, **kwargs): - ht = {'plus': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2), - 'cross': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2)} +def gaussian_time_domain_strain_2(time_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): + ht = {'plus': amplitude * np.exp(-(mu - time_array) ** 2 / sigma ** 2 / 2), + 'cross': amplitude * np.exp(-(mu - time_array) ** 2 / sigma ** 2 / 2)} return ht @@ -126,9 +136,9 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = \ - tupak.gw.waveform_generator.WaveformGenerator(1, 4096, + tupak.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096, frequency_domain_source_model=gaussian_frequency_domain_strain) - self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1, + self.simulation_parameters = dict(amplitude=1e-2, mu=100, sigma=1, ra=1.375, dec=-1.2108, geocent_time=1126259642.413, @@ -144,30 +154,47 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase): self.waveform_generator.frequency_domain_strain() def test_frequency_domain_source_model_call(self): - self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=3) - self.assertEqual(3, self.waveform_generator.frequency_domain_strain()) + self.waveform_generator.parameters = self.simulation_parameters + expected = self.waveform_generator.frequency_domain_source_model(self.waveform_generator.frequency_array, + self.simulation_parameters['amplitude'], + self.simulation_parameters['mu'], + self.simulation_parameters['sigma'], + self.simulation_parameters['ra'], + self.simulation_parameters['dec'], + self.simulation_parameters['geocent_time'], + self.simulation_parameters['psi']) + actual = self.waveform_generator.frequency_domain_strain() + self.assertTrue(np.array_equal(expected['plus'], actual['plus'])) + self.assertTrue(np.array_equal(expected['cross'], actual['cross'])) def test_time_domain_source_model_call_with_ndarray(self): self.waveform_generator.frequency_domain_source_model = None - self.waveform_generator.time_domain_source_model = MagicMock(return_value=np.array([1, 2, 3])) + self.waveform_generator.time_domain_source_model = dummy_function_with_array_return_value + self.waveform_generator.parameters = self.simulation_parameters def side_effect(value, value2): return value with mock.patch('tupak.core.utils.nfft') as m: m.side_effect = side_effect - self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.frequency_domain_strain())) + expected = self.waveform_generator.time_domain_strain() + actual = self.waveform_generator.frequency_domain_strain() + self.assertTrue(np.array_equal(expected, actual)) def test_time_domain_source_model_call_with_dict(self): self.waveform_generator.frequency_domain_source_model = None - self.waveform_generator.time_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2)) + self.waveform_generator.time_domain_source_model = dummy_function_with_dict_return_value + self.waveform_generator.parameters = self.simulation_parameters def side_effect(value, value2): - return value, value2 + return value, self.waveform_generator.frequency_array with mock.patch('tupak.core.utils.nfft') as m: m.side_effect = side_effect - self.assertDictEqual(dict(plus=1, cross=2), self.waveform_generator.frequency_domain_strain()) + expected = self.waveform_generator.time_domain_strain() + actual = self.waveform_generator.frequency_domain_strain() + self.assertTrue(np.array_equal(expected['plus'], actual['plus'])) + self.assertTrue(np.array_equal(expected['cross'], actual['cross'])) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None @@ -211,32 +238,47 @@ class TestTimeDomainStrainMethod(unittest.TestCase): self.waveform_generator.time_domain_strain() def test_time_domain_source_model_call(self): - self.waveform_generator.time_domain_source_model = MagicMock(return_value=3) - self.assertEqual(3, self.waveform_generator.time_domain_strain()) + self.waveform_generator.parameters = self.simulation_parameters + expected = self.waveform_generator.time_domain_source_model(self.waveform_generator.time_array, + self.simulation_parameters['amplitude'], + self.simulation_parameters['mu'], + self.simulation_parameters['sigma'], + self.simulation_parameters['ra'], + self.simulation_parameters['dec'], + self.simulation_parameters['geocent_time'], + self.simulation_parameters['psi']) + actual = self.waveform_generator.time_domain_strain() + self.assertTrue(np.array_equal(expected['plus'], actual['plus'])) + self.assertTrue(np.array_equal(expected['cross'], actual['cross'])) def test_frequency_domain_source_model_call_with_ndarray(self): self.waveform_generator.time_domain_source_model = None - self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=np.array([1, 2, 3])) + self.waveform_generator.frequency_domain_source_model = dummy_function_with_array_return_value + self.waveform_generator.parameters = self.simulation_parameters def side_effect(value, value2): return value with mock.patch('tupak.core.utils.infft') as m: m.side_effect = side_effect - self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.time_domain_strain())) + expected = self.waveform_generator.frequency_domain_strain() + actual = self.waveform_generator.time_domain_strain() + self.assertTrue(np.array_equal(expected, actual)) def test_frequency_domain_source_model_call_with_dict(self): self.waveform_generator.time_domain_source_model = None - self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2)) + self.waveform_generator.frequency_domain_source_model = dummy_function_with_dict_return_value + self.waveform_generator.parameters = self.simulation_parameters def side_effect(value, value2): - return value, value2 + return value with mock.patch('tupak.core.utils.infft') as m: m.side_effect = side_effect - self.assertDictEqual(dict(plus=(1, self.waveform_generator.sampling_frequency), - cross=(2, self.waveform_generator.sampling_frequency)), - self.waveform_generator.time_domain_strain()) + expected = self.waveform_generator.frequency_domain_strain() + actual = self.waveform_generator.time_domain_strain() + self.assertTrue(np.array_equal(expected['plus'], actual['plus'])) + self.assertTrue(np.array_equal(expected['cross'], actual['cross'])) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None @@ -245,7 +287,9 @@ class TestTimeDomainStrainMethod(unittest.TestCase): self.waveform_generator.time_domain_strain() def test_key_popping(self): - self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(a=1e-21, m=100, s=1, + self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(amplitude=1e-2, + mu=100, + sigma=1, ra=1.375, dec=-1.2108, geocent_time=1126259642.413, psi=2.659, c=None, d=None), @@ -255,7 +299,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase): except RuntimeError: pass self.assertListEqual(sorted(self.waveform_generator.parameters.keys()), - sorted(['a', 'm', 's', 'ra', 'dec', 'geocent_time', 'psi'])) + sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi'])) if __name__ == '__main__':