diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py
index 24afd74cb80c5c725e102bcd6cb02c0771a687e2..c466667112e887b22a8617fc6d7bdec71e146457 100644
--- a/bilby/gw/detector/interferometer.py
+++ b/bilby/gw/detector/interferometer.py
@@ -287,20 +287,23 @@ class Interferometer(object):
 
         return signal_ifo
 
-    def inject_signal(self, parameters=None, injection_polarizations=None,
+    def inject_signal(self, parameters, injection_polarizations=None,
                       waveform_generator=None):
-        """ Inject a signal into noise
+        """ General signal injection method.
+        Provide the injection parameters and either the injection polarizations
+        or the waveform generator to inject a signal into the detector.
+        Defaults to the injection polarizations is both are given.
 
         Parameters
         ----------
         parameters: dict
             Parameters of the injection.
-        injection_polarizations: dict
+        injection_polarizations: dict, optional
            Polarizations of waveform to inject, output of
            `waveform_generator.frequency_domain_strain()`. If
            `waveform_generator` is also given, the injection_polarizations will
            be calculated directly and this argument can be ignored.
-        waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
+        waveform_generator: bilby.gw.waveform_generator.WaveformGenerator, optional
             A WaveformGenerator instance using the source model to inject. If
             `injection_polarizations` is given, this will be ignored.
 
@@ -313,40 +316,71 @@ class Interferometer(object):
         Returns
         -------
         injection_polarizations: dict
+            The injected polarizations. This is the same as the injection_polarizations parameters
+            if it was passed in. Otherwise it is the return value of waveform_generator.frequency_domain_strain().
+
+        """
+        if injection_polarizations is None and waveform_generator is None:
+            raise ValueError(
+                "inject_signal needs one of waveform_generator or "
+                "injection_polarizations.")
+        elif injection_polarizations is not None:
+            self.inject_signal_from_waveform_polarizations(parameters=parameters,
+                                                           injection_polarizations=injection_polarizations)
+        elif waveform_generator is not None:
+            injection_polarizations = self.inject_signal_from_waveform_generator(parameters=parameters,
+                                                                                 waveform_generator=waveform_generator)
+        return injection_polarizations
+
+    def inject_signal_from_waveform_generator(self, parameters, waveform_generator):
+        """ Inject a signal using a waveform generator and a set of parameters.
+        Alternative to `inject_signal` and `inject_signal_from_waveform_polarizations`
+
+        Parameters
+        ----------
+        parameters: dict
+            Parameters of the injection.
+        waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
+            A WaveformGenerator instance using the source model to inject.
+
+        Note
+        -------
+        if your signal takes a substantial amount of time to generate, or
+        you experience buggy behaviour. It is preferable to use the
+        inject_signal_from_waveform_polarizations() method.
+
+        Returns
+        -------
+        injection_polarizations: dict
+            The internally generated injection parameters
 
         """
+        injection_polarizations = \
+            waveform_generator.frequency_domain_strain(parameters)
+        self.inject_signal_from_waveform_polarizations(parameters=parameters,
+                                                       injection_polarizations=injection_polarizations)
+        return injection_polarizations
+
+    def inject_signal_from_waveform_polarizations(self, parameters, injection_polarizations):
+        """ Inject a signal into the detector from a dict of waveform polarizations.
+        Alternative to `inject_signal` and `inject_signal_from_waveform_generator`.
 
-        if injection_polarizations is None:
-            if waveform_generator is not None:
-                injection_polarizations = \
-                    waveform_generator.frequency_domain_strain(parameters)
-            else:
-                raise ValueError(
-                    "inject_signal needs one of waveform_generator or "
-                    "injection_polarizations.")
-
-            if injection_polarizations is None:
-                raise ValueError(
-                    'Trying to inject signal which is None. The most likely cause'
-                    ' is that waveform_generator.frequency_domain_strain returned'
-                    ' None. This can be caused if, e.g., mass_2 > mass_1.')
+        Parameters
+        ----------
+        parameters: dict
+            Parameters of the injection.
+        injection_polarizations: dict
+           Polarizations of waveform to inject, output of
+           `waveform_generator.frequency_domain_strain()`.
 
+        """
         if not self.strain_data.time_within_data(parameters['geocent_time']):
             logger.warning(
                 'Injecting signal outside segment, start_time={}, merger time={}.'
                 .format(self.strain_data.start_time, parameters['geocent_time']))
 
         signal_ifo = self.get_detector_response(injection_polarizations, parameters)
-        if np.shape(self.strain_data.frequency_domain_strain).__eq__(np.shape(signal_ifo)):
-            self.strain_data.frequency_domain_strain = \
-                signal_ifo + self.strain_data.frequency_domain_strain
-        else:
-            logger.info('Injecting into zero noise.')
-            self.set_strain_data_from_frequency_domain_strain(
-                signal_ifo,
-                sampling_frequency=self.strain_data.sampling_frequency,
-                duration=self.strain_data.duration,
-                start_time=self.strain_data.start_time)
+        self.strain_data.frequency_domain_strain += signal_ifo
 
         self.meta_data['optimal_SNR'] = (
             np.sqrt(self.optimal_snr_squared(signal=signal_ifo)).real)
@@ -360,8 +394,6 @@ class Interferometer(object):
         for key in parameters:
             logger.info('  {} = {}'.format(key, parameters[key]))
 
-        return injection_polarizations
-
     @property
     def amplitude_spectral_density_array(self):
         """ Returns the amplitude spectral density (ASD) given we know a power spectral denstiy (PSD)
diff --git a/test/detector_test.py b/test/detector_test.py
index f63ee5baa4cbdb9fe25998c016035f67a06da4cb..2f225c719ad4827bcea861eef7ea23769dbb0d85 100644
--- a/test/detector_test.py
+++ b/test/detector_test.py
@@ -1,6 +1,7 @@
 from __future__ import absolute_import
 
 import bilby
+import inspect
 import unittest
 import mock
 from mock import MagicMock
@@ -261,6 +262,17 @@ class TestInterferometer(unittest.TestCase):
         self.ifo.strain_data.set_from_frequency_domain_strain(
             np.linspace(0, 4096, 4097), sampling_frequency=4096, duration=2)
         self.outdir = 'outdir'
+
+        self.injection_polarizations = dict()
+        np.random.seed(42)
+        self.injection_polarizations['plus'] = np.random.random(4097)
+        self.injection_polarizations['cross'] = np.random.random(4097)
+
+        self.waveform_generator = MagicMock()
+        self.wg_polarizations = dict(plus=np.random.random(4097), cross=np.random.random(4097))
+        self.waveform_generator.frequency_domain_strain = lambda _: self.wg_polarizations
+        self.parameters = dict(ra=0., dec=0., geocent_time=0., psi=0.)
+
         bilby.core.utils.check_directory_exists_and_if_not_mkdir(self.outdir)
 
     def tearDown(self):
@@ -277,6 +289,10 @@ class TestInterferometer(unittest.TestCase):
         del self.xarm_tilt
         del self.yarm_tilt
         del self.ifo
+        del self.injection_polarizations
+        del self.wg_polarizations
+        del self.waveform_generator
+        del self.parameters
         rmtree(self.outdir)
 
     def test_name_setting(self):
@@ -342,7 +358,86 @@ class TestInterferometer(unittest.TestCase):
             parameters=dict(ra=0, dec=0, geocent_time=0, psi=0))
         self.assertTrue(np.array_equal(response, (plus + cross) * self.ifo.frequency_mask * np.exp(-0j)))
 
-    def test_inject_signal_no_waveform_polarizations(self):
+    def test_inject_signal_from_waveform_polarizations_correct_injection(self):
+        original_strain = self.ifo.strain_data.frequency_domain_strain
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters,
+                                                           injection_polarizations=self.injection_polarizations)
+        expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross'] + original_strain
+        self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain))
+
+    def test_inject_signal_from_waveform_polarizations_meta_data(self):
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters,
+                                                           injection_polarizations=self.injection_polarizations)
+        signal_ifo_expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross']
+        self.assertAlmostEqual(self.ifo.optimal_snr_squared(signal=signal_ifo_expected).real,
+                               self.ifo.meta_data['optimal_SNR']**2, 10)
+        self.assertAlmostEqual(self.ifo.matched_filter_snr(signal=signal_ifo_expected),
+                               self.ifo.meta_data['matched_filter_SNR'], 10)
+        self.assertDictEqual(self.parameters,
+                             self.ifo.meta_data['parameters'])
+
+    def test_inject_signal_from_waveform_polarizations_incorrect_length(self):
+        self.injection_polarizations['plus'] = np.random.random(1000)
+        self.injection_polarizations['cross'] = np.random.random(1000)
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        with self.assertRaises(ValueError):
+            self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters,
+                                                               injection_polarizations=self.injection_polarizations)
+
+    @patch.object(bilby.core.utils.logger, 'warning')
+    def test_inject_signal_outside_segment_logs_warning(self, m):
+        self.parameters['geocent_time'] = 24345.
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters,
+                                                           injection_polarizations=self.injection_polarizations)
+        self.assertTrue(m.called)
+
+    def test_inject_signal_from_waveform_generator_correct_return_value(self):
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        returned_polarizations = self.ifo.inject_signal_from_waveform_generator(parameters=self.parameters,
+                                                                                waveform_generator=self.waveform_generator)
+        self.assertTrue(np.array_equal(self.wg_polarizations['plus'], returned_polarizations['plus']))
+        self.assertTrue(np.array_equal(self.wg_polarizations['cross'], returned_polarizations['cross']))
+
+    @patch.object(bilby.gw.detector.Interferometer, 'inject_signal_from_waveform_generator')
+    def test_inject_signal_with_waveform_generator_correct_call(self, m):
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        _ = self.ifo.inject_signal(parameters=self.parameters,
+                                   waveform_generator=self.waveform_generator)
+        m.assert_called_with(parameters=self.parameters,
+                             waveform_generator=self.waveform_generator)
+
+    def test_inject_signal_from_waveform_generator_correct_injection(self):
+        original_strain = self.ifo.strain_data.frequency_domain_strain
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        injection_polarizations = self.ifo.inject_signal_from_waveform_generator(parameters=self.parameters,
+                                                                                 waveform_generator=self.waveform_generator)
+        expected = injection_polarizations['plus'] + injection_polarizations['cross'] + original_strain
+        self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain))
+
+    def test_inject_signal_with_injection_polarizations(self):
+        original_strain = self.ifo.strain_data.frequency_domain_strain
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        self.ifo.inject_signal(parameters=self.parameters,
+                               injection_polarizations=self.injection_polarizations)
+        expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross'] + original_strain
+        self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain))
+
+    @patch.object(bilby.gw.detector.Interferometer, 'inject_signal_from_waveform_polarizations')
+    def test_inject_signal_with_injection_polarizations_and_waveform_generator(self, m):
+        self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
+        _ = self.ifo.inject_signal(parameters=self.parameters,
+                                   waveform_generator=self.waveform_generator,
+                                   injection_polarizations=self.injection_polarizations)
+        m.assert_called_with(parameters=self.parameters,
+                             injection_polarizations=self.injection_polarizations)
+        with self.assertRaises(ValueError):
+            m.assert_called_with(parameters=self.parameters,
+                                 injection_polarizations=self.wg_polarizations)
+
+    def test_inject_signal_raises_value_error(self):
         with self.assertRaises(ValueError):
             self.ifo.inject_signal(injection_polarizations=None, parameters=None)
 
@@ -517,8 +612,8 @@ class TestInterferometerEquals(unittest.TestCase):
         self.assertNotEqual(self.ifo_1, self.ifo_2)
 
     def test_eq_false_different_ifo_strain_data(self):
-        self.strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2,
-                                                          duration=self.duration*2)
+        self.strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2,
+                                                          duration=self.duration * 2)
         self.ifo_1.set_strain_data_from_frequency_domain_strain(frequency_array=self.frequency_array,
                                                                 frequency_domain_strain=self.strain)
         self.assertNotEqual(self.ifo_1, self.ifo_2)
@@ -766,27 +861,27 @@ class TestInterferometerStrainDataEquals(unittest.TestCase):
         self.assertNotEqual(self.ifosd_1, self.ifosd_2)
 
     def test_eq_different_frequency_array(self):
-        new_frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2,
-                                                                  duration=self.duration*2)
+        new_frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2,
+                                                                  duration=self.duration * 2)
         self.ifosd_1.frequency_array = new_frequency_array
         self.assertNotEqual(self.ifosd_1, self.ifosd_2)
 
     def test_eq_different_frequency_domain_strain(self):
-        new_strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2,
-                                                         duration=self.duration*2)
+        new_strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2,
+                                                         duration=self.duration * 2)
         self.ifosd_1._frequency_domain_strain = new_strain
         self.assertNotEqual(self.ifosd_1, self.ifosd_2)
 
     def test_eq_different_time_array(self):
-        new_time_array = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency/2,
-                                                        duration=self.duration*2)
+        new_time_array = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency / 2,
+                                                        duration=self.duration * 2)
         self.ifosd_1.time_array = new_time_array
         self.assertNotEqual(self.ifosd_1, self.ifosd_2)
 
     def test_eq_different_time_domain_strain(self):
-        new_strain = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency/2,
-                                                    duration=self.duration*2)
-        self.ifosd_1._time_domain_strain= new_strain
+        new_strain = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency / 2,
+                                                    duration=self.duration * 2)
+        self.ifosd_1._time_domain_strain = new_strain
         self.assertNotEqual(self.ifosd_1, self.ifosd_2)
 
 
@@ -1250,9 +1345,9 @@ class TestPowerSpectralDensityEquals(unittest.TestCase):
         self.frequency_array = np.linspace(1, 100)
         self.psd_array = np.linspace(1, 100)
         self.psd_from_array_1 = bilby.gw.detector.PowerSpectralDensity. \
-            from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array= self.psd_array)
+            from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array=self.psd_array)
         self.psd_from_array_2 = bilby.gw.detector.PowerSpectralDensity. \
-            from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array= self.psd_array)
+            from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array=self.psd_array)
 
     def tearDown(self):
         del self.psd_from_file_1