Skip to content
Snippets Groups Projects
Commit 8aebd8af authored by moritz's avatar moritz
Browse files

Moritz Huebner: Added time_domain_source_model

parent 4440481d
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -36,7 +36,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
self.assertEqual(self.waveform_generator.sampling_frequency, 4096)
def test_source_model(self):
self.assertEqual(self.waveform_generator.source_model, gaussian_frequency_domain_strain)
self.assertEqual(self.waveform_generator.frequency_domain_source_model, gaussian_frequency_domain_strain)
def test_frequency_array_type(self):
self.assertIsInstance(self.waveform_generator.frequency_array, np.ndarray)
......@@ -93,7 +93,7 @@ class TestSourceModelSetter(unittest.TestCase):
def setUp(self):
self.waveform_generator = tupak.waveform_generator.WaveformGenerator(
frequency_domain_source_model=gaussian_frequency_domain_strain)
self.waveform_generator.source_model = gaussian_frequency_domain_strain_2
self.waveform_generator.frequency_domain_source_model = gaussian_frequency_domain_strain_2
self.simulation_parameters = dict(a=1e-21, m=100, s=1,
ra=1.375,
dec=-1.2108,
......
......@@ -23,18 +23,22 @@ class WaveformGenerator(object):
"""
def __init__(self, frequency_domain_source_model, sampling_frequency=4096, time_duration=1,
def __init__(self, frequency_domain_source_model=None, time_domain_source_model=None, sampling_frequency=4096, time_duration=1,
parameters=None):
self.time_duration = time_duration
self.sampling_frequency = sampling_frequency
self.source_model = frequency_domain_source_model
self.frequency_domain_source_model = frequency_domain_source_model
self.time_domain_source_model = time_domain_source_model
self.parameters = parameters
self.__frequency_array_updated = False
self.__time_array_updated = False
def frequency_domain_strain(self):
""" Wrapper to source_model """
return self.source_model(self.frequency_array, **self.parameters)
return self.frequency_domain_source_model(self.frequency_array, **self.parameters)
def time_domain_strain(self):
return self.time_domain_source_model(self.time_array, **self.parameters)
@property
def frequency_array(self):
......@@ -62,7 +66,7 @@ class WaveformGenerator(object):
@parameters.setter
def parameters(self, parameters):
if parameters is None:
parameters = inspect.getargspec(self.source_model).args
parameters = inspect.getargspec(self.frequency_domain_source_model).args
parameters.pop(0)
self.__parameters = dict.fromkeys(parameters)
elif isinstance(parameters, list):
......@@ -80,11 +84,11 @@ class WaveformGenerator(object):
' a dictionary of key-value pairs.')
@property
def source_model(self):
def frequency_domain_source_model(self):
return self.__source_model
@source_model.setter
def source_model(self, source_model):
@frequency_domain_source_model.setter
def frequency_domain_source_model(self, source_model):
self.__source_model = source_model
self.parameters = inspect.getargspec(source_model).args
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment