Skip to content
Snippets Groups Projects
Commit 502cc09c authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'Simplify_wg_redundant_code' into 'master'

Simplify wg redundant code

See merge request Monash/tupak!124
parents 36793a8c 0e1cc4e5
No related branches found
No related tags found
1 merge request!124Simplify wg redundant code
Pipeline #27342 passed with warnings
......@@ -6,21 +6,13 @@ import mock
from mock import MagicMock
def gaussian_frequency_domain_strain(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2),
'cross': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2)}
return ht
def gaussian_frequency_domain_strain_2(frequency_array, a, m, s, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2),
'cross': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2)}
return ht
def dummy_func_array_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
return amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi
def gaussian_time_domain_strain_2(time_array, a, m, s, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2),
'cross': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2)}
def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
ht = {'plus': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi,
'cross': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi}
return ht
......@@ -29,7 +21,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=gaussian_frequency_domain_strain)
frequency_domain_source_model=dummy_func_dict_return_value)
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
......@@ -47,7 +39,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
self.assertEqual(self.waveform_generator.sampling_frequency, 4096)
def test_source_model(self):
self.assertEqual(self.waveform_generator.frequency_domain_source_model, gaussian_frequency_domain_strain)
self.assertEqual(self.waveform_generator.frequency_domain_source_model, dummy_func_dict_return_value)
def test_frequency_array_type(self):
self.assertIsInstance(self.waveform_generator.frequency_array, np.ndarray)
......@@ -64,7 +56,7 @@ class TestWaveformArgumentsSetting(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=gaussian_frequency_domain_strain,
frequency_domain_source_model=dummy_func_dict_return_value,
waveform_arguments=dict(test='test', arguments='arguments'))
def tearDown(self):
......@@ -80,7 +72,7 @@ class TestSetters(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=gaussian_frequency_domain_strain)
frequency_domain_source_model=dummy_func_dict_return_value)
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
......@@ -112,12 +104,12 @@ class TestSetters(unittest.TestCase):
self.assertTrue(np.array_equal(new_time_array, self.waveform_generator.time_array))
def test_parameters_set_from_frequency_domain_source_model(self):
self.waveform_generator.frequency_domain_source_model = gaussian_frequency_domain_strain_2
self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value
self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())),
sorted(list(self.simulation_parameters.keys())))
def test_parameters_set_from_time_domain_source_model(self):
self.waveform_generator.time_domain_source_model = gaussian_time_domain_strain_2
self.waveform_generator.time_domain_source_model = dummy_func_dict_return_value
self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())),
sorted(list(self.simulation_parameters.keys())))
......@@ -126,9 +118,9 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
frequency_domain_source_model=gaussian_frequency_domain_strain)
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
tupak.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096,
frequency_domain_source_model=dummy_func_dict_return_value)
self.simulation_parameters = dict(amplitude=1e-2, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
geocent_time=1126259642.413,
......@@ -144,30 +136,47 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
self.waveform_generator.frequency_domain_strain()
def test_frequency_domain_source_model_call(self):
self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=3)
self.assertEqual(3, self.waveform_generator.frequency_domain_strain())
self.waveform_generator.parameters = self.simulation_parameters
expected = self.waveform_generator.frequency_domain_source_model(self.waveform_generator.frequency_array,
self.simulation_parameters['amplitude'],
self.simulation_parameters['mu'],
self.simulation_parameters['sigma'],
self.simulation_parameters['ra'],
self.simulation_parameters['dec'],
self.simulation_parameters['geocent_time'],
self.simulation_parameters['psi'])
actual = self.waveform_generator.frequency_domain_strain()
self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
def test_time_domain_source_model_call_with_ndarray(self):
self.waveform_generator.frequency_domain_source_model = None
self.waveform_generator.time_domain_source_model = MagicMock(return_value=np.array([1, 2, 3]))
self.waveform_generator.time_domain_source_model = dummy_func_array_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
return value
with mock.patch('tupak.core.utils.nfft') as m:
m.side_effect = side_effect
self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.frequency_domain_strain()))
expected = self.waveform_generator.time_domain_strain()
actual = self.waveform_generator.frequency_domain_strain()
self.assertTrue(np.array_equal(expected, actual))
def test_time_domain_source_model_call_with_dict(self):
self.waveform_generator.frequency_domain_source_model = None
self.waveform_generator.time_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2))
self.waveform_generator.time_domain_source_model = dummy_func_dict_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
return value, value2
return value, self.waveform_generator.frequency_array
with mock.patch('tupak.core.utils.nfft') as m:
m.side_effect = side_effect
self.assertDictEqual(dict(plus=1, cross=2), self.waveform_generator.frequency_domain_strain())
expected = self.waveform_generator.time_domain_strain()
actual = self.waveform_generator.frequency_domain_strain()
self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
def test_no_source_model_given(self):
self.waveform_generator.time_domain_source_model = None
......@@ -194,7 +203,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
def setUp(self):
self.waveform_generator = \
tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
time_domain_source_model=gaussian_time_domain_strain_2)
time_domain_source_model=dummy_func_dict_return_value)
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
......@@ -211,32 +220,47 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
self.waveform_generator.time_domain_strain()
def test_time_domain_source_model_call(self):
self.waveform_generator.time_domain_source_model = MagicMock(return_value=3)
self.assertEqual(3, self.waveform_generator.time_domain_strain())
self.waveform_generator.parameters = self.simulation_parameters
expected = self.waveform_generator.time_domain_source_model(self.waveform_generator.time_array,
self.simulation_parameters['amplitude'],
self.simulation_parameters['mu'],
self.simulation_parameters['sigma'],
self.simulation_parameters['ra'],
self.simulation_parameters['dec'],
self.simulation_parameters['geocent_time'],
self.simulation_parameters['psi'])
actual = self.waveform_generator.time_domain_strain()
self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
def test_frequency_domain_source_model_call_with_ndarray(self):
self.waveform_generator.time_domain_source_model = None
self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=np.array([1, 2, 3]))
self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
return value
with mock.patch('tupak.core.utils.infft') as m:
m.side_effect = side_effect
self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.time_domain_strain()))
expected = self.waveform_generator.frequency_domain_strain()
actual = self.waveform_generator.time_domain_strain()
self.assertTrue(np.array_equal(expected, actual))
def test_frequency_domain_source_model_call_with_dict(self):
self.waveform_generator.time_domain_source_model = None
self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2))
self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value
self.waveform_generator.parameters = self.simulation_parameters
def side_effect(value, value2):
return value, value2
return value
with mock.patch('tupak.core.utils.infft') as m:
m.side_effect = side_effect
self.assertDictEqual(dict(plus=(1, self.waveform_generator.sampling_frequency),
cross=(2, self.waveform_generator.sampling_frequency)),
self.waveform_generator.time_domain_strain())
expected = self.waveform_generator.frequency_domain_strain()
actual = self.waveform_generator.time_domain_strain()
self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
def test_no_source_model_given(self):
self.waveform_generator.time_domain_source_model = None
......@@ -245,7 +269,9 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
self.waveform_generator.time_domain_strain()
def test_key_popping(self):
self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(a=1e-21, m=100, s=1,
self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(amplitude=1e-2,
mu=100,
sigma=1,
ra=1.375, dec=-1.2108,
geocent_time=1126259642.413,
psi=2.659, c=None, d=None),
......@@ -255,7 +281,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
except RuntimeError:
pass
self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
sorted(['a', 'm', 's', 'ra', 'dec', 'geocent_time', 'psi']))
sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
if __name__ == '__main__':
......
......@@ -7,7 +7,8 @@ import numpy as np
class WaveformGenerator(object):
def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None,
time_domain_source_model=None, parameters=None, parameter_conversion=None,
time_domain_source_model=None, parameters=None,
parameter_conversion=lambda parameters, search_keys: (parameters, []),
non_standard_sampling_parameter_keys=None,
waveform_arguments=None):
""" A waveform generator
......@@ -32,7 +33,8 @@ class WaveformGenerator(object):
Initial values for the parameters
parameter_conversion: func, optional
Function to convert from sampled parameters to parameters of the
waveform generator
waveform generator. Default value is the identity, i.e. it leaves
the parameters unaffected.
non_standard_sampling_parameter_keys: list, optional
List of parameter name for *non-standard* sampling parameters.
waveform_arguments: dict, optional
......@@ -62,6 +64,8 @@ class WaveformGenerator(object):
self.__time_array_updated = False
self.__full_source_model_keyword_arguments = {}
self.__full_source_model_keyword_arguments.update(self.waveform_arguments)
self.__full_source_model_keyword_arguments.update(self.parameters)
self.__added_keys = []
def frequency_domain_strain(self):
""" Rapper to source_model.
......@@ -78,32 +82,11 @@ class WaveformGenerator(object):
RuntimeError: If no source model is given
"""
added_keys = []
if self.parameter_conversion is not None:
self.parameters, added_keys = self.parameter_conversion(self.parameters,
self.non_standard_sampling_parameter_keys)
if self.frequency_domain_source_model is not None:
self.__full_source_model_keyword_arguments.update(self.parameters)
model_frequency_strain = self.frequency_domain_source_model(
self.frequency_array,
**self.__full_source_model_keyword_arguments)
elif self.time_domain_source_model is not None:
model_frequency_strain = dict()
self.__full_source_model_keyword_arguments.update(self.parameters)
time_domain_strain = self.time_domain_source_model(
self.time_array, **self.__full_source_model_keyword_arguments)
if isinstance(time_domain_strain, np.ndarray):
return utils.nfft(time_domain_strain, self.sampling_frequency)
for key in time_domain_strain:
model_frequency_strain[key], self.frequency_array = utils.nfft(time_domain_strain[key],
self.sampling_frequency)
else:
raise RuntimeError("No source model given")
for key in added_keys:
self.parameters.pop(key)
return model_frequency_strain
return self._calculate_strain(model=self.frequency_domain_source_model,
model_data_points=self.frequency_array,
transformation_function=utils.nfft,
transformed_model=self.time_domain_source_model,
transformed_model_data_points=self.time_array)
def time_domain_strain(self):
""" Rapper to source_model.
......@@ -121,29 +104,51 @@ class WaveformGenerator(object):
RuntimeError: If no source model is given
"""
added_keys = []
if self.parameter_conversion is not None:
self.parameters, added_keys = self.parameter_conversion(self.parameters,
self.non_standard_sampling_parameter_keys)
if self.time_domain_source_model is not None:
self.__full_source_model_keyword_arguments.update(self.parameters)
model_time_series = self.time_domain_source_model(
self.time_array, **self.__full_source_model_keyword_arguments)
elif self.frequency_domain_source_model is not None:
model_time_series = dict()
self.__full_source_model_keyword_arguments.update(self.parameters)
frequency_domain_strain = self.frequency_domain_source_model(
self.frequency_array, **self.__full_source_model_keyword_arguments)
if isinstance(frequency_domain_strain, np.ndarray):
return utils.infft(frequency_domain_strain, self.sampling_frequency)
for key in frequency_domain_strain:
model_time_series[key] = utils.infft(frequency_domain_strain[key], self.sampling_frequency)
return self._calculate_strain(model=self.time_domain_source_model,
model_data_points=self.time_array,
transformation_function=utils.infft,
transformed_model=self.frequency_domain_source_model,
transformed_model_data_points=self.frequency_array)
def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model,
transformed_model_data_points):
self._apply_parameter_conversion()
if model is not None:
model_strain = self._strain_from_model(model_data_points, model)
elif transformed_model is not None:
model_strain = self._strain_from_transformed_model(transformed_model_data_points, transformed_model,
transformation_function)
else:
raise RuntimeError("No source model given")
for key in added_keys:
self._remove_added_keys()
return model_strain
def _apply_parameter_conversion(self):
self.parameters, self.__added_keys = self.parameter_conversion(self.parameters,
self.non_standard_sampling_parameter_keys)
self.__full_source_model_keyword_arguments.update(self.parameters)
def _strain_from_model(self, model_data_points, model):
return model(model_data_points, **self.__full_source_model_keyword_arguments)
def _strain_from_transformed_model(self, transformed_model_data_points, transformed_model, transformation_function):
transformed_model_strain = self._strain_from_model(transformed_model_data_points, transformed_model)
if isinstance(transformed_model_strain, np.ndarray):
return transformation_function(transformed_model_strain, self.sampling_frequency)
model_strain = dict()
for key in transformed_model_strain:
if transformation_function == utils.nfft:
model_strain[key], self.frequency_array = \
transformation_function(transformed_model_strain[key], self.sampling_frequency)
else:
model_strain[key] = transformation_function(transformed_model_strain[key], self.sampling_frequency)
return model_strain
def _remove_added_keys(self):
for key in self.__added_keys:
self.parameters.pop(key)
return model_time_series
@property
def frequency_array(self):
......@@ -155,8 +160,8 @@ class WaveformGenerator(object):
"""
if self.__frequency_array_updated is False:
self.frequency_array = utils.create_frequency_series(
self.sampling_frequency,
self.duration)
self.sampling_frequency,
self.duration)
return self.__frequency_array
@frequency_array.setter
......@@ -175,9 +180,9 @@ class WaveformGenerator(object):
if self.__time_array_updated is False:
self.__time_array = utils.create_time_series(
self.sampling_frequency,
self.duration,
self.start_time)
self.sampling_frequency,
self.duration,
self.start_time)
self.__time_array_updated = True
return self.__time_array
......@@ -191,8 +196,6 @@ class WaveformGenerator(object):
def parameters(self):
""" The dictionary of parameters for source model.
Does some introspection into the source_model to figure out the parameters if none are given.
Returns
-------
dict: The dictionary of parameter key-value pairs
......@@ -202,6 +205,8 @@ class WaveformGenerator(object):
@parameters.setter
def parameters(self, parameters):
""" Does some introspection into the source_model to figure out the parameters if none are given.
"""
self.__parameters_from_source_model()
if isinstance(parameters, dict):
for key in parameters.keys():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment