diff --git a/test/waveform_generator_tests.py b/test/waveform_generator_tests.py
index 16c13c5dbb0113b38fa6f9df34f2ae05d10d84e5..4d67cc47d16967dc167f5f6ede0f02bde7a9f9f7 100644
--- a/test/waveform_generator_tests.py
+++ b/test/waveform_generator_tests.py
@@ -6,21 +6,13 @@ import mock
 from mock import MagicMock
 
 
-def gaussian_frequency_domain_strain(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
-    ht = {'plus': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2),
-          'cross': amplitude * np.exp(-(mu - frequency_array) ** 2 / sigma ** 2 / 2)}
-    return ht
-
-
-def gaussian_frequency_domain_strain_2(frequency_array, a, m, s, ra, dec, geocent_time, psi, **kwargs):
-    ht = {'plus': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2),
-          'cross': a * np.exp(-(m - frequency_array) ** 2 / s ** 2 / 2)}
-    return ht
+def dummy_func_array_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
+    return amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi
 
 
-def gaussian_time_domain_strain_2(time_array, a, m, s, ra, dec, geocent_time, psi, **kwargs):
-    ht = {'plus': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2),
-          'cross': a * np.exp(-(m - time_array) ** 2 / s ** 2 / 2)}
+def dummy_func_dict_return_value(frequency_array, amplitude, mu, sigma, ra, dec, geocent_time, psi, **kwargs):
+    ht = {'plus': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi,
+          'cross': amplitude + mu + frequency_array + sigma + ra + dec + geocent_time + psi}
     return ht
 
 
@@ -29,7 +21,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
     def setUp(self):
         self.waveform_generator = \
             tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
-                                                          frequency_domain_source_model=gaussian_frequency_domain_strain)
+                                                          frequency_domain_source_model=dummy_func_dict_return_value)
         self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
                                           ra=1.375,
                                           dec=-1.2108,
@@ -47,7 +39,7 @@ class TestWaveformGeneratorInstantiationWithoutOptionalParameters(unittest.TestC
         self.assertEqual(self.waveform_generator.sampling_frequency, 4096)
 
     def test_source_model(self):
-        self.assertEqual(self.waveform_generator.frequency_domain_source_model, gaussian_frequency_domain_strain)
+        self.assertEqual(self.waveform_generator.frequency_domain_source_model, dummy_func_dict_return_value)
 
     def test_frequency_array_type(self):
         self.assertIsInstance(self.waveform_generator.frequency_array, np.ndarray)
@@ -64,7 +56,7 @@ class TestWaveformArgumentsSetting(unittest.TestCase):
     def setUp(self):
         self.waveform_generator = \
             tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
-                                                          frequency_domain_source_model=gaussian_frequency_domain_strain,
+                                                          frequency_domain_source_model=dummy_func_dict_return_value,
                                                           waveform_arguments=dict(test='test', arguments='arguments'))
 
     def tearDown(self):
@@ -80,7 +72,7 @@ class TestSetters(unittest.TestCase):
     def setUp(self):
         self.waveform_generator = \
             tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
-                                                          frequency_domain_source_model=gaussian_frequency_domain_strain)
+                                                          frequency_domain_source_model=dummy_func_dict_return_value)
         self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
                                           ra=1.375,
                                           dec=-1.2108,
@@ -112,12 +104,12 @@ class TestSetters(unittest.TestCase):
         self.assertTrue(np.array_equal(new_time_array, self.waveform_generator.time_array))
 
     def test_parameters_set_from_frequency_domain_source_model(self):
-        self.waveform_generator.frequency_domain_source_model = gaussian_frequency_domain_strain_2
+        self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value
         self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())),
                              sorted(list(self.simulation_parameters.keys())))
 
     def test_parameters_set_from_time_domain_source_model(self):
-        self.waveform_generator.time_domain_source_model = gaussian_time_domain_strain_2
+        self.waveform_generator.time_domain_source_model = dummy_func_dict_return_value
         self.assertListEqual(sorted(list(self.waveform_generator.parameters.keys())),
                              sorted(list(self.simulation_parameters.keys())))
 
@@ -126,9 +118,9 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
 
     def setUp(self):
         self.waveform_generator = \
-            tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
-                                                          frequency_domain_source_model=gaussian_frequency_domain_strain)
-        self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
+            tupak.gw.waveform_generator.WaveformGenerator(duration=1, sampling_frequency=4096,
+                                                          frequency_domain_source_model=dummy_func_dict_return_value)
+        self.simulation_parameters = dict(amplitude=1e-2, mu=100, sigma=1,
                                           ra=1.375,
                                           dec=-1.2108,
                                           geocent_time=1126259642.413,
@@ -144,30 +136,47 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
             self.waveform_generator.frequency_domain_strain()
 
     def test_frequency_domain_source_model_call(self):
-        self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=3)
-        self.assertEqual(3, self.waveform_generator.frequency_domain_strain())
+        self.waveform_generator.parameters = self.simulation_parameters
+        expected = self.waveform_generator.frequency_domain_source_model(self.waveform_generator.frequency_array,
+                                                                         self.simulation_parameters['amplitude'],
+                                                                         self.simulation_parameters['mu'],
+                                                                         self.simulation_parameters['sigma'],
+                                                                         self.simulation_parameters['ra'],
+                                                                         self.simulation_parameters['dec'],
+                                                                         self.simulation_parameters['geocent_time'],
+                                                                         self.simulation_parameters['psi'])
+        actual = self.waveform_generator.frequency_domain_strain()
+        self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
+        self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
 
     def test_time_domain_source_model_call_with_ndarray(self):
         self.waveform_generator.frequency_domain_source_model = None
-        self.waveform_generator.time_domain_source_model = MagicMock(return_value=np.array([1, 2, 3]))
+        self.waveform_generator.time_domain_source_model = dummy_func_array_return_value
+        self.waveform_generator.parameters = self.simulation_parameters
 
         def side_effect(value, value2):
             return value
 
         with mock.patch('tupak.core.utils.nfft') as m:
             m.side_effect = side_effect
-            self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.frequency_domain_strain()))
+            expected = self.waveform_generator.time_domain_strain()
+            actual = self.waveform_generator.frequency_domain_strain()
+            self.assertTrue(np.array_equal(expected, actual))
 
     def test_time_domain_source_model_call_with_dict(self):
         self.waveform_generator.frequency_domain_source_model = None
-        self.waveform_generator.time_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2))
+        self.waveform_generator.time_domain_source_model = dummy_func_dict_return_value
+        self.waveform_generator.parameters = self.simulation_parameters
 
         def side_effect(value, value2):
-            return value, value2
+            return value, self.waveform_generator.frequency_array
 
         with mock.patch('tupak.core.utils.nfft') as m:
             m.side_effect = side_effect
-            self.assertDictEqual(dict(plus=1, cross=2), self.waveform_generator.frequency_domain_strain())
+            expected = self.waveform_generator.time_domain_strain()
+            actual = self.waveform_generator.frequency_domain_strain()
+            self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
+            self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
 
     def test_no_source_model_given(self):
         self.waveform_generator.time_domain_source_model = None
@@ -194,7 +203,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
     def setUp(self):
         self.waveform_generator = \
             tupak.gw.waveform_generator.WaveformGenerator(1, 4096,
-                                                          time_domain_source_model=gaussian_time_domain_strain_2)
+                                                          time_domain_source_model=dummy_func_dict_return_value)
         self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
                                           ra=1.375,
                                           dec=-1.2108,
@@ -211,32 +220,47 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
             self.waveform_generator.time_domain_strain()
 
     def test_time_domain_source_model_call(self):
-        self.waveform_generator.time_domain_source_model = MagicMock(return_value=3)
-        self.assertEqual(3, self.waveform_generator.time_domain_strain())
+        self.waveform_generator.parameters = self.simulation_parameters
+        expected = self.waveform_generator.time_domain_source_model(self.waveform_generator.time_array,
+                                                                    self.simulation_parameters['amplitude'],
+                                                                    self.simulation_parameters['mu'],
+                                                                    self.simulation_parameters['sigma'],
+                                                                    self.simulation_parameters['ra'],
+                                                                    self.simulation_parameters['dec'],
+                                                                    self.simulation_parameters['geocent_time'],
+                                                                    self.simulation_parameters['psi'])
+        actual = self.waveform_generator.time_domain_strain()
+        self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
+        self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
 
     def test_frequency_domain_source_model_call_with_ndarray(self):
         self.waveform_generator.time_domain_source_model = None
-        self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=np.array([1, 2, 3]))
+        self.waveform_generator.frequency_domain_source_model = dummy_func_array_return_value
+        self.waveform_generator.parameters = self.simulation_parameters
 
         def side_effect(value, value2):
             return value
 
         with mock.patch('tupak.core.utils.infft') as m:
             m.side_effect = side_effect
-            self.assertTrue(np.array_equal(np.array([1, 2, 3]), self.waveform_generator.time_domain_strain()))
+            expected = self.waveform_generator.frequency_domain_strain()
+            actual = self.waveform_generator.time_domain_strain()
+            self.assertTrue(np.array_equal(expected, actual))
 
     def test_frequency_domain_source_model_call_with_dict(self):
         self.waveform_generator.time_domain_source_model = None
-        self.waveform_generator.frequency_domain_source_model = MagicMock(return_value=dict(plus=1, cross=2))
+        self.waveform_generator.frequency_domain_source_model = dummy_func_dict_return_value
+        self.waveform_generator.parameters = self.simulation_parameters
 
         def side_effect(value, value2):
-            return value, value2
+            return value
 
         with mock.patch('tupak.core.utils.infft') as m:
             m.side_effect = side_effect
-            self.assertDictEqual(dict(plus=(1, self.waveform_generator.sampling_frequency),
-                                      cross=(2, self.waveform_generator.sampling_frequency)),
-                                 self.waveform_generator.time_domain_strain())
+            expected = self.waveform_generator.frequency_domain_strain()
+            actual = self.waveform_generator.time_domain_strain()
+            self.assertTrue(np.array_equal(expected['plus'], actual['plus']))
+            self.assertTrue(np.array_equal(expected['cross'], actual['cross']))
 
     def test_no_source_model_given(self):
         self.waveform_generator.time_domain_source_model = None
@@ -245,7 +269,9 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
             self.waveform_generator.time_domain_strain()
 
     def test_key_popping(self):
-        self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(a=1e-21, m=100, s=1,
+        self.waveform_generator.parameter_conversion = MagicMock(return_value=(dict(amplitude=1e-2,
+                                                                                    mu=100,
+                                                                                    sigma=1,
                                                                                     ra=1.375, dec=-1.2108,
                                                                                     geocent_time=1126259642.413,
                                                                                     psi=2.659, c=None, d=None),
@@ -255,7 +281,7 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
         except RuntimeError:
             pass
         self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
-                             sorted(['a', 'm', 's', 'ra', 'dec', 'geocent_time', 'psi']))
+                             sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
 
 
 if __name__ == '__main__':
diff --git a/tupak/gw/waveform_generator.py b/tupak/gw/waveform_generator.py
index f0770b70d0e933075c7643bbccf21ecc62206b4e..fe87dc8acd2cf8a54916b0296050a623b0848cea 100644
--- a/tupak/gw/waveform_generator.py
+++ b/tupak/gw/waveform_generator.py
@@ -7,7 +7,8 @@ import numpy as np
 class WaveformGenerator(object):
 
     def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None,
-                 time_domain_source_model=None, parameters=None, parameter_conversion=None,
+                 time_domain_source_model=None, parameters=None,
+                 parameter_conversion=lambda parameters, search_keys: (parameters, []),
                  non_standard_sampling_parameter_keys=None,
                  waveform_arguments=None):
         """ A waveform generator
@@ -32,7 +33,8 @@ class WaveformGenerator(object):
         Initial values for the parameters
     parameter_conversion: func, optional
         Function to convert from sampled parameters to parameters of the
-        waveform generator
+        waveform generator. Default value is the identity, i.e. it leaves
+        the parameters unaffected.
     non_standard_sampling_parameter_keys: list, optional
         List of parameter name for *non-standard* sampling parameters.
     waveform_arguments: dict, optional
@@ -62,6 +64,8 @@ class WaveformGenerator(object):
         self.__time_array_updated = False
         self.__full_source_model_keyword_arguments = {}
         self.__full_source_model_keyword_arguments.update(self.waveform_arguments)
+        self.__full_source_model_keyword_arguments.update(self.parameters)
+        self.__added_keys = []
 
     def frequency_domain_strain(self):
         """ Rapper to source_model.
@@ -78,32 +82,11 @@ class WaveformGenerator(object):
         RuntimeError: If no source model is given
 
         """
-        added_keys = []
-        if self.parameter_conversion is not None:
-            self.parameters, added_keys = self.parameter_conversion(self.parameters,
-                                                                    self.non_standard_sampling_parameter_keys)
-
-        if self.frequency_domain_source_model is not None:
-            self.__full_source_model_keyword_arguments.update(self.parameters)
-            model_frequency_strain = self.frequency_domain_source_model(
-                self.frequency_array,
-                **self.__full_source_model_keyword_arguments)
-        elif self.time_domain_source_model is not None:
-            model_frequency_strain = dict()
-            self.__full_source_model_keyword_arguments.update(self.parameters)
-            time_domain_strain = self.time_domain_source_model(
-                self.time_array, **self.__full_source_model_keyword_arguments)
-            if isinstance(time_domain_strain, np.ndarray):
-                return utils.nfft(time_domain_strain, self.sampling_frequency)
-            for key in time_domain_strain:
-                model_frequency_strain[key], self.frequency_array = utils.nfft(time_domain_strain[key],
-                                                                               self.sampling_frequency)
-        else:
-            raise RuntimeError("No source model given")
-
-        for key in added_keys:
-            self.parameters.pop(key)
-        return model_frequency_strain
+        return self._calculate_strain(model=self.frequency_domain_source_model,
+                                      model_data_points=self.frequency_array,
+                                      transformation_function=utils.nfft,
+                                      transformed_model=self.time_domain_source_model,
+                                      transformed_model_data_points=self.time_array)
 
     def time_domain_strain(self):
         """ Rapper to source_model.
@@ -121,29 +104,51 @@ class WaveformGenerator(object):
         RuntimeError: If no source model is given
 
         """
-        added_keys = []
-        if self.parameter_conversion is not None:
-            self.parameters, added_keys = self.parameter_conversion(self.parameters,
-                                                                    self.non_standard_sampling_parameter_keys)
-        if self.time_domain_source_model is not None:
-            self.__full_source_model_keyword_arguments.update(self.parameters)
-            model_time_series = self.time_domain_source_model(
-                self.time_array, **self.__full_source_model_keyword_arguments)
-        elif self.frequency_domain_source_model is not None:
-            model_time_series = dict()
-            self.__full_source_model_keyword_arguments.update(self.parameters)
-            frequency_domain_strain = self.frequency_domain_source_model(
-                self.frequency_array, **self.__full_source_model_keyword_arguments)
-            if isinstance(frequency_domain_strain, np.ndarray):
-                return utils.infft(frequency_domain_strain, self.sampling_frequency)
-            for key in frequency_domain_strain:
-                model_time_series[key] = utils.infft(frequency_domain_strain[key], self.sampling_frequency)
+        return self._calculate_strain(model=self.time_domain_source_model,
+                                      model_data_points=self.time_array,
+                                      transformation_function=utils.infft,
+                                      transformed_model=self.frequency_domain_source_model,
+                                      transformed_model_data_points=self.frequency_array)
+
+    def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model,
+                          transformed_model_data_points):
+        self._apply_parameter_conversion()
+        if model is not None:
+            model_strain = self._strain_from_model(model_data_points, model)
+        elif transformed_model is not None:
+            model_strain = self._strain_from_transformed_model(transformed_model_data_points, transformed_model,
+                                                               transformation_function)
         else:
             raise RuntimeError("No source model given")
-
-        for key in added_keys:
+        self._remove_added_keys()
+        return model_strain
+
+    def _apply_parameter_conversion(self):
+        self.parameters, self.__added_keys = self.parameter_conversion(self.parameters,
+                                                                       self.non_standard_sampling_parameter_keys)
+        self.__full_source_model_keyword_arguments.update(self.parameters)
+
+    def _strain_from_model(self, model_data_points, model):
+        return model(model_data_points, **self.__full_source_model_keyword_arguments)
+
+    def _strain_from_transformed_model(self, transformed_model_data_points, transformed_model, transformation_function):
+        transformed_model_strain = self._strain_from_model(transformed_model_data_points, transformed_model)
+
+        if isinstance(transformed_model_strain, np.ndarray):
+            return transformation_function(transformed_model_strain, self.sampling_frequency)
+
+        model_strain = dict()
+        for key in transformed_model_strain:
+            if transformation_function == utils.nfft:
+                model_strain[key], self.frequency_array = \
+                    transformation_function(transformed_model_strain[key], self.sampling_frequency)
+            else:
+                model_strain[key] = transformation_function(transformed_model_strain[key], self.sampling_frequency)
+        return model_strain
+
+    def _remove_added_keys(self):
+        for key in self.__added_keys:
             self.parameters.pop(key)
-        return model_time_series
 
     @property
     def frequency_array(self):
@@ -155,8 +160,8 @@ class WaveformGenerator(object):
         """
         if self.__frequency_array_updated is False:
             self.frequency_array = utils.create_frequency_series(
-                                        self.sampling_frequency,
-                                        self.duration)
+                self.sampling_frequency,
+                self.duration)
         return self.__frequency_array
 
     @frequency_array.setter
@@ -175,9 +180,9 @@ class WaveformGenerator(object):
 
         if self.__time_array_updated is False:
             self.__time_array = utils.create_time_series(
-                                        self.sampling_frequency,
-                                        self.duration,
-                                        self.start_time)
+                self.sampling_frequency,
+                self.duration,
+                self.start_time)
 
             self.__time_array_updated = True
         return self.__time_array
@@ -191,8 +196,6 @@ class WaveformGenerator(object):
     def parameters(self):
         """ The dictionary of parameters for source model.
 
-        Does some introspection into the source_model to figure out the parameters if none are given.
-
         Returns
         -------
         dict: The dictionary of parameter key-value pairs
@@ -202,6 +205,8 @@ class WaveformGenerator(object):
 
     @parameters.setter
     def parameters(self, parameters):
+        """ Does some introspection into the source_model to figure out the parameters if none are given.
+        """
         self.__parameters_from_source_model()
         if isinstance(parameters, dict):
             for key in parameters.keys():