diff --git a/test/waveform_generator_tests.py b/test/waveform_generator_tests.py index 579d61576611bb0a6a952a2a1c2dd96f7ce31114..16c13c5dbb0113b38fa6f9df34f2ae05d10d84e5 100644 --- a/test/waveform_generator_tests.py +++ b/test/waveform_generator_tests.py @@ -2,20 +2,28 @@ from __future__ import absolute_import import unittest import tupak import numpy as np +import mock +from mock import MagicMock -def gaussian_frequency_domain_strain(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi): +def gaussian_frequency_domain_strain(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): ht = {'plus': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2), 'cross': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2)} return ht -def gaussian_frequency_domain_strain_2(frequency_array, a, m, s, ra, dec, geocent_time, psi): +def gaussian_frequency_domain_strain_2(frequency_array, a, m, s, ra, dec, geocent_time, psi, **kwargs): ht = {'plus': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2), 'cross': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2)} return ht +def gaussian_time_domain_strain_2(time_array, a, m, s, ra, dec, geocent_time, psi, **kwargs): + ht = {'plus': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2), + 'cross': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2)} + return ht + + class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase): def setUp(self): @@ -52,6 +60,21 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC sorted(list(self.simulation_parameters.keys()))) +class TestWaveformArgumentsSetting(unittest.TestCase): + def setUp(self): + self.waveform_generator = \ + tupak.gw.waveform_generator.WaveformGenerator(1, 4096, + frequency_domain_source_model=gaussian_frequency_domain_strain, + waveform_arguments=dict(test='test', arguments='arguments')) + + def tearDown(self): + del self.waveform_generator + + def test_waveform_arguments_init_setting(self): + self.assertDictEqual(self.waveform_generator.waveform_arguments, + dict(test='test', arguments='arguments')) + + class TestSetters(unittest.TestCase): def setUp(self): @@ -83,13 +106,28 @@ class TestSetters(unittest.TestCase): self.waveform_generator.frequency_array = new_frequency_array self.assertTrue(np.array_equal(new_frequency_array, self.waveform_generator.frequency_array)) + def test_time_array_setter(self): + new_time_array = np.arange(1, 100) + self.waveform_generator.time_array = new_time_array + self.assertTrue(np.array_equal(new_time_array, self.waveform_generator.time_array)) + + def test_parameters_set_from_frequency_domain_source_model(self): + self.waveform_generator.frequency_domain_source_model = gaussian_frequency_domain_strain_2 + self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())), + sorted(list(self.simulation_parameters.keys()))) + + def test_parameters_set_from_time_domain_source_model(self): + self.waveform_generator.time_domain_source_model = gaussian_time_domain_strain_2 + self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())), + sorted(list(self.simulation_parameters.keys()))) + -class TestSourceModelSetter(unittest.TestCase): +class TestFrequencyDomainStrainMethod(unittest.TestCase): def setUp(self): - self.waveform_generator = tupak.gw.waveform_generator.WaveformGenerator(1, 4096, - frequency_domain_source_model=gaussian_frequency_domain_strain) - self.waveform_generator.frequency_domain_source_model = gaussian_frequency_domain_strain_2 + self.waveform_generator = \ + tupak.gw.waveform_generator.WaveformGenerator(1, 4096, + frequency_domain_source_model=gaussian_frequency_domain_strain) self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1, ra=1.375, dec=-1.2108, @@ -100,9 +138,124 @@ class TestSourceModelSetter(unittest.TestCase): del self.waveform_generator del self.simulation_parameters - def test_parameters_are_set_correctly(self): - self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())), - sorted(list(self.simulation_parameters.keys()))) + def test_parameter_conversion_is_called(self): + self.waveform_generator.parameter_conversion = MagicMock(side_effect=KeyError('test')) + with self.assertRaises(KeyError): + self.waveform_generator.frequency_domain_strain() + + def test_frequency_domain_source_model_call(self): + self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=3) + self.assertEqual(3, self.waveform_generator.frequency_domain_strain()) + + def test_time_domain_source_model_call_with_ndarray(self): + self.waveform_generator.frequency_domain_source_model = None + self.waveform_generator.time_domain_source_model = MagicMock(return_value=np.array([1, 2, 3])) + + def side_effect(value, value2): + return value + + with mock.patch('tupak.core.utils.nfft') as m: + m.side_effect = side_effect + self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.frequency_domain_strain())) + + def test_time_domain_source_model_call_with_dict(self): + self.waveform_generator.frequency_domain_source_model = None + self.waveform_generator.time_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2)) + + def side_effect(value, value2): + return value, value2 + + with mock.patch('tupak.core.utils.nfft') as m: + m.side_effect = side_effect + self.assertDictEqual(dict(plus=1, cross=2), self.waveform_generator.frequency_domain_strain()) + + def test_no_source_model_given(self): + self.waveform_generator.time_domain_source_model = None + self.waveform_generator.frequency_domain_source_model = None + with self.assertRaises(RuntimeError): + self.waveform_generator.frequency_domain_strain() + + def test_key_popping(self): + self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(amplitude=1e-21, mu=100, sigma=1, + ra=1.375, dec=-1.2108, + geocent_time=1126259642.413, + psi=2.659, c=None, d=None), + ['c', 'd'])) + try: + self.waveform_generator.frequency_domain_strain() + except RuntimeError: + pass + self.assertListEqual(sorted(self.waveform_generator.parameters.keys()), + sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi'])) + + +class TestTimeDomainStrainMethod(unittest.TestCase): + + def setUp(self): + self.waveform_generator = \ + tupak.gw.waveform_generator.WaveformGenerator(1, 4096, + time_domain_source_model=gaussian_time_domain_strain_2) + self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1, + ra=1.375, + dec=-1.2108, + geocent_time=1126259642.413, + psi=2.659) + + def tearDown(self): + del self.waveform_generator + del self.simulation_parameters + + def test_parameter_conversion_is_called(self): + self.waveform_generator.parameter_conversion = MagicMock(side_effect=KeyError('test')) + with self.assertRaises(KeyError): + self.waveform_generator.time_domain_strain() + + def test_time_domain_source_model_call(self): + self.waveform_generator.time_domain_source_model = MagicMock(return_value=3) + self.assertEqual(3, self.waveform_generator.time_domain_strain()) + + def test_frequency_domain_source_model_call_with_ndarray(self): + self.waveform_generator.time_domain_source_model = None + self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=np.array([1, 2, 3])) + + def side_effect(value, value2): + return value + + with mock.patch('tupak.core.utils.infft') as m: + m.side_effect = side_effect + self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.time_domain_strain())) + + def test_frequency_domain_source_model_call_with_dict(self): + self.waveform_generator.time_domain_source_model = None + self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2)) + + def side_effect(value, value2): + return value, value2 + + with mock.patch('tupak.core.utils.infft') as m: + m.side_effect = side_effect + self.assertDictEqual(dict(plus=(1, self.waveform_generator.sampling_frequency), + cross=(2, self.waveform_generator.sampling_frequency)), + self.waveform_generator.time_domain_strain()) + + def test_no_source_model_given(self): + self.waveform_generator.time_domain_source_model = None + self.waveform_generator.frequency_domain_source_model = None + with self.assertRaises(RuntimeError): + self.waveform_generator.time_domain_strain() + + def test_key_popping(self): + self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(a=1e-21, m=100, s=1, + ra=1.375, dec=-1.2108, + geocent_time=1126259642.413, + psi=2.659, c=None, d=None), + ['c', 'd'])) + try: + self.waveform_generator.time_domain_strain() + except RuntimeError: + pass + self.assertListEqual(sorted(self.waveform_generator.parameters.keys()), + sorted(['a', 'm', 's', 'ra', 'dec', 'geocent_time', 'psi'])) if __name__ == '__main__': diff --git a/tupak/gw/waveform_generator.py b/tupak/gw/waveform_generator.py index fdee84a20641452a1476f98e56f56e16cfeedfab..f0770b70d0e933075c7643bbccf21ecc62206b4e 100644 --- a/tupak/gw/waveform_generator.py +++ b/tupak/gw/waveform_generator.py @@ -253,9 +253,9 @@ class WaveformGenerator(object): @property def start_time(self): - return self.__starting_time + return self.__start_time @start_time.setter def start_time(self, starting_time): - self.__starting_time = starting_time + self.__start_time = starting_time self.__time_array_updated = False