diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py
index 00e24bde621c5b989ba3daca27d4f4ebf4ed2267..98d06914695af20473745e88d3916a0367832e2b 100644
--- a/bilby/gw/waveform_generator.py
+++ b/bilby/gw/waveform_generator.py
@@ -62,6 +62,7 @@ class WaveformGenerator(object):
             self.waveform_arguments = dict()
         if isinstance(parameters, dict):
             self.parameters = parameters
+        self._cache = dict(parameters=None, waveform=None)
 
     def __repr__(self):
         if self.frequency_domain_source_model is not None:
@@ -147,6 +148,8 @@ class WaveformGenerator(object):
                           transformed_model_data_points, parameters):
         if parameters is not None:
             self.parameters = parameters
+        if self.parameters == self._cache['parameters']:
+            return self._cache['waveform']
         if model is not None:
             model_strain = self._strain_from_model(model_data_points, model)
         elif transformed_model is not None:
@@ -154,6 +157,8 @@ class WaveformGenerator(object):
                                                                transformation_function)
         else:
             raise RuntimeError("No source model given")
+        self._cache['waveform'] = model_strain
+        self._cache['parameters'] = self.parameters.copy()
         return model_strain
 
     def _strain_from_model(self, model_data_points, model):
diff --git a/test/waveform_generator_test.py b/test/waveform_generator_test.py
index 8eba859ac03f0a1d73ade2a74f0785aff1162953..15aa65e467f2ecccb379effab951417904bed6b7 100644
--- a/test/waveform_generator_test.py
+++ b/test/waveform_generator_test.py
@@ -263,6 +263,19 @@ class TestFrequencyDomainStrainMethod(unittest.TestCase):
         self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
                              sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
 
+    def test_caching_with_parameters(self):
+        original_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        new_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        self.assertDictEqual(original_waveform, new_waveform)
+
+    def test_caching_without_parameters(self):
+        original_waveform = self.waveform_generator.frequency_domain_strain(
+            parameters=self.simulation_parameters)
+        new_waveform = self.waveform_generator.frequency_domain_strain()
+        self.assertDictEqual(original_waveform, new_waveform)
+
 
 class TestTimeDomainStrainMethod(unittest.TestCase):
 
@@ -354,6 +367,20 @@ class TestTimeDomainStrainMethod(unittest.TestCase):
         self.assertListEqual(sorted(self.waveform_generator.parameters.keys()),
                              sorted(['amplitude', 'mu', 'sigma', 'ra', 'dec', 'geocent_time', 'psi']))
 
+    def test_caching_with_parameters(self):
+        original_waveform = self.waveform_generator.time_domain_strain(
+            parameters=self.simulation_parameters)
+        new_waveform = self.waveform_generator.time_domain_strain(
+            parameters=self.simulation_parameters)
+        self.assertDictEqual(original_waveform, new_waveform)
+
+    def test_caching_without_parameters(self):
+        original_waveform = self.waveform_generator.time_domain_strain(
+            parameters=self.simulation_parameters)
+        new_waveform = self.waveform_generator.time_domain_strain()
+        self.assertDictEqual(original_waveform, new_waveform)
+
+
 
 if __name__ == '__main__':
     unittest.main()