Skip to content
Snippets Groups Projects
Commit 2d7e3ae5 authored by MoritzThomasHuebner's avatar MoritzThomasHuebner
Browse files

Made some overall improvements and threw out some shoddy stuff that should not...

Made some overall improvements and threw out some shoddy stuff that should not have been in there in the first place
parent 375814aa
No related branches found
No related tags found
1 merge request!124Simplify wg redundant code
Pipeline #
......@@ -12,15 +12,25 @@ def gaussian_frequency_domain_strain(frequency_array, amplitude, mu, sigma, ra,
return ht
def dummy_function_with_array_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
return amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi
def dummy_function_with_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi,
'cross': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi}
return ht
def gaussian_frequency_domain_strain_2(frequency_array, a, m, s, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2),
'cross': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2)}
return ht
def gaussian_time_domain_strain_2(time_array, a, m, s, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2),
'cross': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2)}
def gaussian_time_domain_strain_2(time_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': amplitude * np.exp(-(mu - time_array) ** 2 / sigma ** 2 / 2),
'cross': amplitude * np.exp(-(mu - time_array) ** 2 / sigma ** 2 / 2)}
return ht
......@@ -126,9 +136,9 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
tupak.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096,
frequency_domain_source_model=gaussian_frequency_domain_strain)
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
self.simulation_parameters = dict(amplitude=1e-2, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
geocent_time=1126259642.413,
......@@ -144,30 +154,47 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
self.waveform_generator.frequency_domain_strain()
def test_frequency_domain_source_model_call(self):
self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=3)
self.assertEqual(3, self.waveform_generator.frequency_domain_strain())
self.waveform_generator.parameters = self.simulation_parameters
expected = self.waveform_generator.frequency_domain_source_model(self.waveform_generator.frequency_array,
self.simulation_parameters['amplitude'],
self.simulation_parameters['mu'],
self.simulation_parameters['sigma'],
self.simulation_parameters['ra'],
self.simulation_parameters['dec'],
self.simulation_parameters['geocent_time'],
self.simulation_parameters['psi'])
actual = self.waveform_generator.frequency_domain_strain()
self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
def test_time_domain_source_model_call_with_ndarray(self):
self.waveform_generator.frequency_domain_source_model = None
self.waveform_generator.time_domain_source_model = MagicMock(return_value=np.array([1, 2, 3]))
self.waveform_generator.time_domain_source_model = dummy_function_with_array_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
return value
with mock.patch('tupak.core.utils.nfft') as m:
m.side_effect = side_effect
self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.frequency_domain_strain()))
expected = self.waveform_generator.time_domain_strain()
actual = self.waveform_generator.frequency_domain_strain()
self.assertTrue(np.array_equal(expected, actual))
def test_time_domain_source_model_call_with_dict(self):
self.waveform_generator.frequency_domain_source_model = None
self.waveform_generator.time_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2))
self.waveform_generator.time_domain_source_model = dummy_function_with_dict_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
return value, value2
return value, self.waveform_generator.frequency_array
with mock.patch('tupak.core.utils.nfft') as m:
m.side_effect = side_effect
self.assertDictEqual(dict(plus=1, cross=2), self.waveform_generator.frequency_domain_strain())
expected = self.waveform_generator.time_domain_strain()
actual = self.waveform_generator.frequency_domain_strain()
self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
def test_no_source_model_given(self):
self.waveform_generator.time_domain_source_model = None
......@@ -211,32 +238,47 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
self.waveform_generator.time_domain_strain()
def test_time_domain_source_model_call(self):
self.waveform_generator.time_domain_source_model = MagicMock(return_value=3)
self.assertEqual(3, self.waveform_generator.time_domain_strain())
self.waveform_generator.parameters = self.simulation_parameters
expected = self.waveform_generator.time_domain_source_model(self.waveform_generator.time_array,
self.simulation_parameters['amplitude'],
self.simulation_parameters['mu'],
self.simulation_parameters['sigma'],
self.simulation_parameters['ra'],
self.simulation_parameters['dec'],
self.simulation_parameters['geocent_time'],
self.simulation_parameters['psi'])
actual = self.waveform_generator.time_domain_strain()
self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
def test_frequency_domain_source_model_call_with_ndarray(self):
self.waveform_generator.time_domain_source_model = None
self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=np.array([1, 2, 3]))
self.waveform_generator.frequency_domain_source_model = dummy_function_with_array_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
return value
with mock.patch('tupak.core.utils.infft') as m:
m.side_effect = side_effect
self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.time_domain_strain()))
expected = self.waveform_generator.frequency_domain_strain()
actual = self.waveform_generator.time_domain_strain()
self.assertTrue(np.array_equal(expected, actual))
def test_frequency_domain_source_model_call_with_dict(self):
self.waveform_generator.time_domain_source_model = None
self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2))
self.waveform_generator.frequency_domain_source_model = dummy_function_with_dict_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
return value, value2
return value
with mock.patch('tupak.core.utils.infft') as m:
m.side_effect = side_effect
self.assertDictEqual(dict(plus=(1, self.waveform_generator.sampling_frequency),
cross=(2, self.waveform_generator.sampling_frequency)),
self.waveform_generator.time_domain_strain())
expected = self.waveform_generator.frequency_domain_strain()
actual = self.waveform_generator.time_domain_strain()
self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
def test_no_source_model_given(self):
self.waveform_generator.time_domain_source_model = None
......@@ -245,7 +287,9 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
self.waveform_generator.time_domain_strain()
def test_key_popping(self):
self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(a=1e-21, m=100, s=1,
self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(amplitude=1e-2,
mu=100,
sigma=1,
ra=1.375, dec=-1.2108,
geocent_time=1126259642.413,
psi=2.659, c=None, d=None),
......@@ -255,7 +299,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
except RuntimeError:
pass
self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
sorted(['a', 'm', 's', 'ra', 'dec', 'geocent_time', 'psi']))
sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment