diff --git a/test/waveform_generator_tests.py b/test/waveform_generator_tests.py index 16c13c5dbb0113b38fa6f9df34f2ae05d10d84e5..4d67cc47d16967dc167f5f6ede0f02bde7a9f9f7 100644 --- a/test/waveform_generator_tests.py +++ b/test/waveform_generator_tests.py @@ -6,21 +6,13 @@ import mock from mock import MagicMock -def gaussian_frequency_domain_strain(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): - ht = {'plus': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2), - 'cross': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2)} - return ht - - -def gaussian_frequency_domain_strain_2(frequency_array, a, m, s, ra, dec, geocent_time, psi, **kwargs): - ht = {'plus': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2), - 'cross': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2)} - return ht +def dummy_func_array_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): + return amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi -def gaussian_time_domain_strain_2(time_array, a, m, s, ra, dec, geocent_time, psi, **kwargs): - ht = {'plus': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2), - 'cross': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2)} +def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs): + ht = {'plus': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi, + 'cross': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi} return ht @@ -29,7 +21,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC def setUp(self): self.waveform_generator = \ tupak.gw.waveform_generator.WaveformGenerator(1, 4096, - frequency_domain_source_model=gaussian_frequency_domain_strain) + frequency_domain_source_model=dummy_func_dict_return_value) self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1, ra=1.375, dec=-1.2108, @@ -47,7 +39,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC self.assertEqual(self.waveform_generator.sampling_frequency, 4096) def test_source_model(self): - self.assertEqual(self.waveform_generator.frequency_domain_source_model, gaussian_frequency_domain_strain) + self.assertEqual(self.waveform_generator.frequency_domain_source_model, dummy_func_dict_return_value) def test_frequency_array_type(self): self.assertIsInstance(self.waveform_generator.frequency_array, np.ndarray) @@ -64,7 +56,7 @@ class TestWaveformArgumentsSetting(unittest.TestCase): def setUp(self): self.waveform_generator = \ tupak.gw.waveform_generator.WaveformGenerator(1, 4096, - frequency_domain_source_model=gaussian_frequency_domain_strain, + frequency_domain_source_model=dummy_func_dict_return_value, waveform_arguments=dict(test='test', arguments='arguments')) def tearDown(self): @@ -80,7 +72,7 @@ class TestSetters(unittest.TestCase): def setUp(self): self.waveform_generator = \ tupak.gw.waveform_generator.WaveformGenerator(1, 4096, - frequency_domain_source_model=gaussian_frequency_domain_strain) + frequency_domain_source_model=dummy_func_dict_return_value) self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1, ra=1.375, dec=-1.2108, @@ -112,12 +104,12 @@ class TestSetters(unittest.TestCase): self.assertTrue(np.array_equal(new_time_array, self.waveform_generator.time_array)) def test_parameters_set_from_frequency_domain_source_model(self): - self.waveform_generator.frequency_domain_source_model = gaussian_frequency_domain_strain_2 + self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())), sorted(list(self.simulation_parameters.keys()))) def test_parameters_set_from_time_domain_source_model(self): - self.waveform_generator.time_domain_source_model = gaussian_time_domain_strain_2 + self.waveform_generator.time_domain_source_model = dummy_func_dict_return_value self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())), sorted(list(self.simulation_parameters.keys()))) @@ -126,9 +118,9 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = \ - tupak.gw.waveform_generator.WaveformGenerator(1, 4096, - frequency_domain_source_model=gaussian_frequency_domain_strain) - self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1, + tupak.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096, + frequency_domain_source_model=dummy_func_dict_return_value) + self.simulation_parameters = dict(amplitude=1e-2, mu=100, sigma=1, ra=1.375, dec=-1.2108, geocent_time=1126259642.413, @@ -144,30 +136,47 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase): self.waveform_generator.frequency_domain_strain() def test_frequency_domain_source_model_call(self): - self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=3) - self.assertEqual(3, self.waveform_generator.frequency_domain_strain()) + self.waveform_generator.parameters = self.simulation_parameters + expected = self.waveform_generator.frequency_domain_source_model(self.waveform_generator.frequency_array, + self.simulation_parameters['amplitude'], + self.simulation_parameters['mu'], + self.simulation_parameters['sigma'], + self.simulation_parameters['ra'], + self.simulation_parameters['dec'], + self.simulation_parameters['geocent_time'], + self.simulation_parameters['psi']) + actual = self.waveform_generator.frequency_domain_strain() + self.assertTrue(np.array_equal(expected['plus'], actual['plus'])) + self.assertTrue(np.array_equal(expected['cross'], actual['cross'])) def test_time_domain_source_model_call_with_ndarray(self): self.waveform_generator.frequency_domain_source_model = None - self.waveform_generator.time_domain_source_model = MagicMock(return_value=np.array([1, 2, 3])) + self.waveform_generator.time_domain_source_model = dummy_func_array_return_value + self.waveform_generator.parameters = self.simulation_parameters def side_effect(value, value2): return value with mock.patch('tupak.core.utils.nfft') as m: m.side_effect = side_effect - self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.frequency_domain_strain())) + expected = self.waveform_generator.time_domain_strain() + actual = self.waveform_generator.frequency_domain_strain() + self.assertTrue(np.array_equal(expected, actual)) def test_time_domain_source_model_call_with_dict(self): self.waveform_generator.frequency_domain_source_model = None - self.waveform_generator.time_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2)) + self.waveform_generator.time_domain_source_model = dummy_func_dict_return_value + self.waveform_generator.parameters = self.simulation_parameters def side_effect(value, value2): - return value, value2 + return value, self.waveform_generator.frequency_array with mock.patch('tupak.core.utils.nfft') as m: m.side_effect = side_effect - self.assertDictEqual(dict(plus=1, cross=2), self.waveform_generator.frequency_domain_strain()) + expected = self.waveform_generator.time_domain_strain() + actual = self.waveform_generator.frequency_domain_strain() + self.assertTrue(np.array_equal(expected['plus'], actual['plus'])) + self.assertTrue(np.array_equal(expected['cross'], actual['cross'])) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None @@ -194,7 +203,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase): def setUp(self): self.waveform_generator = \ tupak.gw.waveform_generator.WaveformGenerator(1, 4096, - time_domain_source_model=gaussian_time_domain_strain_2) + time_domain_source_model=dummy_func_dict_return_value) self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1, ra=1.375, dec=-1.2108, @@ -211,32 +220,47 @@ class TestTimeDomainStrainMethod(unittest.TestCase): self.waveform_generator.time_domain_strain() def test_time_domain_source_model_call(self): - self.waveform_generator.time_domain_source_model = MagicMock(return_value=3) - self.assertEqual(3, self.waveform_generator.time_domain_strain()) + self.waveform_generator.parameters = self.simulation_parameters + expected = self.waveform_generator.time_domain_source_model(self.waveform_generator.time_array, + self.simulation_parameters['amplitude'], + self.simulation_parameters['mu'], + self.simulation_parameters['sigma'], + self.simulation_parameters['ra'], + self.simulation_parameters['dec'], + self.simulation_parameters['geocent_time'], + self.simulation_parameters['psi']) + actual = self.waveform_generator.time_domain_strain() + self.assertTrue(np.array_equal(expected['plus'], actual['plus'])) + self.assertTrue(np.array_equal(expected['cross'], actual['cross'])) def test_frequency_domain_source_model_call_with_ndarray(self): self.waveform_generator.time_domain_source_model = None - self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=np.array([1, 2, 3])) + self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value + self.waveform_generator.parameters = self.simulation_parameters def side_effect(value, value2): return value with mock.patch('tupak.core.utils.infft') as m: m.side_effect = side_effect - self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.time_domain_strain())) + expected = self.waveform_generator.frequency_domain_strain() + actual = self.waveform_generator.time_domain_strain() + self.assertTrue(np.array_equal(expected, actual)) def test_frequency_domain_source_model_call_with_dict(self): self.waveform_generator.time_domain_source_model = None - self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2)) + self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value + self.waveform_generator.parameters = self.simulation_parameters def side_effect(value, value2): - return value, value2 + return value with mock.patch('tupak.core.utils.infft') as m: m.side_effect = side_effect - self.assertDictEqual(dict(plus=(1, self.waveform_generator.sampling_frequency), - cross=(2, self.waveform_generator.sampling_frequency)), - self.waveform_generator.time_domain_strain()) + expected = self.waveform_generator.frequency_domain_strain() + actual = self.waveform_generator.time_domain_strain() + self.assertTrue(np.array_equal(expected['plus'], actual['plus'])) + self.assertTrue(np.array_equal(expected['cross'], actual['cross'])) def test_no_source_model_given(self): self.waveform_generator.time_domain_source_model = None @@ -245,7 +269,9 @@ class TestTimeDomainStrainMethod(unittest.TestCase): self.waveform_generator.time_domain_strain() def test_key_popping(self): - self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(a=1e-21, m=100, s=1, + self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(amplitude=1e-2, + mu=100, + sigma=1, ra=1.375, dec=-1.2108, geocent_time=1126259642.413, psi=2.659, c=None, d=None), @@ -255,7 +281,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase): except RuntimeError: pass self.assertListEqual(sorted(self.waveform_generator.parameters.keys()), - sorted(['a', 'm', 's', 'ra', 'dec', 'geocent_time', 'psi'])) + sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi'])) if __name__ == '__main__': diff --git a/tupak/gw/waveform_generator.py b/tupak/gw/waveform_generator.py index f0770b70d0e933075c7643bbccf21ecc62206b4e..fe87dc8acd2cf8a54916b0296050a623b0848cea 100644 --- a/tupak/gw/waveform_generator.py +++ b/tupak/gw/waveform_generator.py @@ -7,7 +7,8 @@ import numpy as np class WaveformGenerator(object): def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None, - time_domain_source_model=None, parameters=None, parameter_conversion=None, + time_domain_source_model=None, parameters=None, + parameter_conversion=lambda parameters, search_keys: (parameters, []), non_standard_sampling_parameter_keys=None, waveform_arguments=None): """ A waveform generator @@ -32,7 +33,8 @@ class WaveformGenerator(object): Initial values for the parameters parameter_conversion: func, optional Function to convert from sampled parameters to parameters of the - waveform generator + waveform generator. Default value is the identity, i.e. it leaves + the parameters unaffected. non_standard_sampling_parameter_keys: list, optional List of parameter name for *non-standard* sampling parameters. waveform_arguments: dict, optional @@ -62,6 +64,8 @@ class WaveformGenerator(object): self.__time_array_updated = False self.__full_source_model_keyword_arguments = {} self.__full_source_model_keyword_arguments.update(self.waveform_arguments) + self.__full_source_model_keyword_arguments.update(self.parameters) + self.__added_keys = [] def frequency_domain_strain(self): """ Rapper to source_model. @@ -78,32 +82,11 @@ class WaveformGenerator(object): RuntimeError: If no source model is given """ - added_keys = [] - if self.parameter_conversion is not None: - self.parameters, added_keys = self.parameter_conversion(self.parameters, - self.non_standard_sampling_parameter_keys) - - if self.frequency_domain_source_model is not None: - self.__full_source_model_keyword_arguments.update(self.parameters) - model_frequency_strain = self.frequency_domain_source_model( - self.frequency_array, - **self.__full_source_model_keyword_arguments) - elif self.time_domain_source_model is not None: - model_frequency_strain = dict() - self.__full_source_model_keyword_arguments.update(self.parameters) - time_domain_strain = self.time_domain_source_model( - self.time_array, **self.__full_source_model_keyword_arguments) - if isinstance(time_domain_strain, np.ndarray): - return utils.nfft(time_domain_strain, self.sampling_frequency) - for key in time_domain_strain: - model_frequency_strain[key], self.frequency_array = utils.nfft(time_domain_strain[key], - self.sampling_frequency) - else: - raise RuntimeError("No source model given") - - for key in added_keys: - self.parameters.pop(key) - return model_frequency_strain + return self._calculate_strain(model=self.frequency_domain_source_model, + model_data_points=self.frequency_array, + transformation_function=utils.nfft, + transformed_model=self.time_domain_source_model, + transformed_model_data_points=self.time_array) def time_domain_strain(self): """ Rapper to source_model. @@ -121,29 +104,51 @@ class WaveformGenerator(object): RuntimeError: If no source model is given """ - added_keys = [] - if self.parameter_conversion is not None: - self.parameters, added_keys = self.parameter_conversion(self.parameters, - self.non_standard_sampling_parameter_keys) - if self.time_domain_source_model is not None: - self.__full_source_model_keyword_arguments.update(self.parameters) - model_time_series = self.time_domain_source_model( - self.time_array, **self.__full_source_model_keyword_arguments) - elif self.frequency_domain_source_model is not None: - model_time_series = dict() - self.__full_source_model_keyword_arguments.update(self.parameters) - frequency_domain_strain = self.frequency_domain_source_model( - self.frequency_array, **self.__full_source_model_keyword_arguments) - if isinstance(frequency_domain_strain, np.ndarray): - return utils.infft(frequency_domain_strain, self.sampling_frequency) - for key in frequency_domain_strain: - model_time_series[key] = utils.infft(frequency_domain_strain[key], self.sampling_frequency) + return self._calculate_strain(model=self.time_domain_source_model, + model_data_points=self.time_array, + transformation_function=utils.infft, + transformed_model=self.frequency_domain_source_model, + transformed_model_data_points=self.frequency_array) + + def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model, + transformed_model_data_points): + self._apply_parameter_conversion() + if model is not None: + model_strain = self._strain_from_model(model_data_points, model) + elif transformed_model is not None: + model_strain = self._strain_from_transformed_model(transformed_model_data_points, transformed_model, + transformation_function) else: raise RuntimeError("No source model given") - - for key in added_keys: + self._remove_added_keys() + return model_strain + + def _apply_parameter_conversion(self): + self.parameters, self.__added_keys = self.parameter_conversion(self.parameters, + self.non_standard_sampling_parameter_keys) + self.__full_source_model_keyword_arguments.update(self.parameters) + + def _strain_from_model(self, model_data_points, model): + return model(model_data_points, **self.__full_source_model_keyword_arguments) + + def _strain_from_transformed_model(self, transformed_model_data_points, transformed_model, transformation_function): + transformed_model_strain = self._strain_from_model(transformed_model_data_points, transformed_model) + + if isinstance(transformed_model_strain, np.ndarray): + return transformation_function(transformed_model_strain, self.sampling_frequency) + + model_strain = dict() + for key in transformed_model_strain: + if transformation_function == utils.nfft: + model_strain[key], self.frequency_array = \ + transformation_function(transformed_model_strain[key], self.sampling_frequency) + else: + model_strain[key] = transformation_function(transformed_model_strain[key], self.sampling_frequency) + return model_strain + + def _remove_added_keys(self): + for key in self.__added_keys: self.parameters.pop(key) - return model_time_series @property def frequency_array(self): @@ -155,8 +160,8 @@ class WaveformGenerator(object): """ if self.__frequency_array_updated is False: self.frequency_array = utils.create_frequency_series( - self.sampling_frequency, - self.duration) + self.sampling_frequency, + self.duration) return self.__frequency_array @frequency_array.setter @@ -175,9 +180,9 @@ class WaveformGenerator(object): if self.__time_array_updated is False: self.__time_array = utils.create_time_series( - self.sampling_frequency, - self.duration, - self.start_time) + self.sampling_frequency, + self.duration, + self.start_time) self.__time_array_updated = True return self.__time_array @@ -191,8 +196,6 @@ class WaveformGenerator(object): def parameters(self): """ The dictionary of parameters for source model. - Does some introspection into the source_model to figure out the parameters if none are given. - Returns ------- dict: The dictionary of parameter key-value pairs @@ -202,6 +205,8 @@ class WaveformGenerator(object): @parameters.setter def parameters(self, parameters): + """ Does some introspection into the source_model to figure out the parameters if none are given. + """ self.__parameters_from_source_model() if isinstance(parameters, dict): for key in parameters.keys():