Skip to content
Snippets Groups Projects
Commit d9a8f1e8 authored by MoritzThomasHuebner's avatar MoritzThomasHuebner
Browse files

Reduced the number of test functions

parent 148eac39
No related branches found
No related tags found
1 merge request!124Simplify wg redundant code
Pipeline #
......@@ -6,40 +6,22 @@ import mock
from mock import MagicMock
def gaussian_frequency_domain_strain(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2),
'cross': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2)}
return ht
def dummy_function_with_array_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
def dummy_func_array_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
return amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi
def dummy_function_with_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi,
'cross': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi}
return ht
def gaussian_frequency_domain_strain_2(frequency_array, a, m, s, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2),
'cross': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2)}
return ht
def gaussian_time_domain_strain_2(time_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': amplitude * np.exp(-(mu - time_array) ** 2 / sigma ** 2 / 2),
'cross': amplitude * np.exp(-(mu - time_array) ** 2 / sigma ** 2 / 2)}
return ht
class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=gaussian_frequency_domain_strain)
frequency_domain_source_model=dummy_func_dict_return_value)
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
......@@ -57,7 +39,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
self.assertEqual(self.waveform_generator.sampling_frequency, 4096)
def test_source_model(self):
self.assertEqual(self.waveform_generator.frequency_domain_source_model, gaussian_frequency_domain_strain)
self.assertEqual(self.waveform_generator.frequency_domain_source_model, dummy_func_dict_return_value)
def test_frequency_array_type(self):
self.assertIsInstance(self.waveform_generator.frequency_array, np.ndarray)
......@@ -74,7 +56,7 @@ class TestWaveformArgumentsSetting(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=gaussian_frequency_domain_strain,
frequency_domain_source_model=dummy_func_dict_return_value,
waveform_arguments=dict(test='test', arguments='arguments'))
def tearDown(self):
......@@ -90,7 +72,7 @@ class TestSetters(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=gaussian_frequency_domain_strain)
frequency_domain_source_model=dummy_func_dict_return_value)
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
......@@ -122,12 +104,12 @@ class TestSetters(unittest.TestCase):
self.assertTrue(np.array_equal(new_time_array, self.waveform_generator.time_array))
def test_parameters_set_from_frequency_domain_source_model(self):
self.waveform_generator.frequency_domain_source_model = gaussian_frequency_domain_strain_2
self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value
self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())),
sorted(list(self.simulation_parameters.keys())))
def test_parameters_set_from_time_domain_source_model(self):
self.waveform_generator.time_domain_source_model = gaussian_time_domain_strain_2
self.waveform_generator.time_domain_source_model = dummy_func_dict_return_value
self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())),
sorted(list(self.simulation_parameters.keys())))
......@@ -137,7 +119,7 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096,
frequency_domain_source_model=gaussian_frequency_domain_strain)
frequency_domain_source_model=dummy_func_dict_return_value)
self.simulation_parameters = dict(amplitude=1e-2, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
......@@ -169,7 +151,7 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
def test_time_domain_source_model_call_with_ndarray(self):
self.waveform_generator.frequency_domain_source_model = None
self.waveform_generator.time_domain_source_model = dummy_function_with_array_return_value
self.waveform_generator.time_domain_source_model = dummy_func_array_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
......@@ -183,7 +165,7 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
def test_time_domain_source_model_call_with_dict(self):
self.waveform_generator.frequency_domain_source_model = None
self.waveform_generator.time_domain_source_model = dummy_function_with_dict_return_value
self.waveform_generator.time_domain_source_model = dummy_func_dict_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
......@@ -221,7 +203,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
time_domain_source_model=gaussian_time_domain_strain_2)
time_domain_source_model=dummy_func_dict_return_value)
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
......@@ -253,7 +235,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
def test_frequency_domain_source_model_call_with_ndarray(self):
self.waveform_generator.time_domain_source_model = None
self.waveform_generator.frequency_domain_source_model = dummy_function_with_array_return_value
self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
......@@ -267,7 +249,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
def test_frequency_domain_source_model_call_with_dict(self):
self.waveform_generator.time_domain_source_model = None
self.waveform_generator.frequency_domain_source_model = dummy_function_with_dict_return_value
self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment