diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index b88118a5dbd483701c094c26e4faad98c9c27011..b3bd6bed6008dc3ef6808d28e61b8090e05c1628 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -187,9 +187,6 @@ class GravitationalWaveTransient(Likelihood):
                 [self.priors['luminosity_distance'].prob(distance)
                  for distance in self._distance_array])
             self._ref_dist = self.priors['luminosity_distance'].rescale(0.5)
-            if self.phase_marginalization:
-                max_bound = np.ceil(10 + np.log10(self._dist_multiplier))
-                self._setup_phase_marginalization(max_bound=max_bound)
             self._setup_distance_marginalization(
                 distance_marginalization_lookup_table)
             for key in ['redshift', 'comoving_distance']:
diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py
index adea14483de062a6c48c31194f366594cd9293da..b4862d79b5f2aee778ed970ed44f4df867a991bf 100644
--- a/test/gw/likelihood_test.py
+++ b/test/gw/likelihood_test.py
@@ -274,131 +274,6 @@ class TestGWTransient(unittest.TestCase):
         )
 
 
-class TestTimeMarginalization(unittest.TestCase):
-    def setUp(self):
-        np.random.seed(500)
-        self.duration = 4
-        self.sampling_frequency = 2048
-        self.parameters = dict(
-            mass_1=31.0,
-            mass_2=29.0,
-            a_1=0.4,
-            a_2=0.3,
-            tilt_1=0.0,
-            tilt_2=0.0,
-            phi_12=1.7,
-            phi_jl=0.3,
-            luminosity_distance=4000.0,
-            theta_jn=0.4,
-            psi=2.659,
-            phase=1.3,
-            geocent_time=1126259640,
-            ra=1.375,
-            dec=-1.2108,
-        )
-
-        self.interferometers = bilby.gw.detector.InterferometerList(["H1"])
-        self.interferometers.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=self.sampling_frequency,
-            duration=self.duration,
-            start_time=1126259640,
-        )
-
-        self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-            duration=self.duration,
-            sampling_frequency=self.sampling_frequency,
-            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-            start_time=1126259640,
-        )
-
-        self.priors = bilby.gw.prior.BBHPriorDict()
-
-        self.likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            priors=self.priors.copy(),
-        )
-
-        self.likelihood.parameters = self.parameters.copy()
-
-    def tearDown(self):
-        del self.duration
-        del self.sampling_frequency
-        del self.parameters
-        del self.interferometers
-        del self.waveform_generator
-        del self.priors
-        del self.likelihood
-
-    def test_time_marginalisation_full_segment(self):
-        """
-        Test time marginalised likelihood matches brute force version over the
-        whole segment.
-        """
-        likes = []
-        lls = []
-        self.priors["geocent_time"] = bilby.prior.Uniform(
-            minimum=self.waveform_generator.start_time,
-            maximum=self.waveform_generator.start_time + self.duration,
-        )
-        self.time = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            time_marginalization=True,
-            priors=self.priors,
-        )
-        times = (
-            self.waveform_generator.start_time
-            + np.linspace(0, self.duration, 4097)[:-1]
-        )
-        for time in times:
-            self.likelihood.parameters["geocent_time"] = time
-            lls.append(self.likelihood.log_likelihood_ratio())
-            likes.append(np.exp(lls[-1]))
-
-        marg_like = np.log(
-            np.trapz(likes * self.time.priors["geocent_time"].prob(times), times)
-        )
-        self.time.parameters = self.parameters.copy()
-        self.time.parameters["time_jitter"] = 0.0
-        self.time.parameters["geocent_time"] = self.waveform_generator.start_time
-        self.assertAlmostEqual(marg_like, self.time.log_likelihood_ratio(), delta=0.5)
-
-    def test_time_marginalisation_partial_segment(self):
-        """
-        Test time marginalised likelihood matches brute force version over the
-        whole segment.
-        """
-        likes = []
-        lls = []
-        self.priors["geocent_time"] = bilby.prior.Uniform(
-            minimum=self.parameters["geocent_time"] + 1 - 0.1,
-            maximum=self.parameters["geocent_time"] + 1 + 0.1,
-        )
-        self.time = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            time_marginalization=True,
-            priors=self.priors,
-        )
-        times = (
-            self.waveform_generator.start_time
-            + np.linspace(0, self.duration, 4097)[:-1]
-        )
-        for time in times:
-            self.likelihood.parameters["geocent_time"] = time
-            lls.append(self.likelihood.log_likelihood_ratio())
-            likes.append(np.exp(lls[-1]))
-
-        marg_like = np.log(
-            np.trapz(likes * self.time.priors["geocent_time"].prob(times), times)
-        )
-        self.time.parameters = self.parameters.copy()
-        self.time.parameters["time_jitter"] = 0.0
-        self.time.parameters["geocent_time"] = self.waveform_generator.start_time
-        self.assertAlmostEqual(marg_like, self.time.log_likelihood_ratio(), delta=0.5)
-
-
 class TestMarginalizedLikelihood(unittest.TestCase):
     def setUp(self):
         np.random.seed(500)
@@ -544,7 +419,14 @@ class TestMarginalizedLikelihood(unittest.TestCase):
             bilby.run_sampler(like, new_prior)
 
 
-class TestPhaseMarginalization(unittest.TestCase):
+class TestMarginalizations(unittest.TestCase):
+    """
+    Test all marginalised likelihoods matches brute force version.
+
+    For time, this is strongly dependent on the specific time grid used.
+    The `time_jitter` parameter makes this a weaker dependence during sampling.
+    """
+
     def setUp(self):
         np.random.seed(500)
         self.duration = 4
@@ -565,39 +447,31 @@ class TestPhaseMarginalization(unittest.TestCase):
             geocent_time=1126259642.413,
             ra=1.375,
             dec=-1.2108,
+            time_jitter=0,
         )
 
         self.interferometers = bilby.gw.detector.InterferometerList(["H1"])
         self.interferometers.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=self.sampling_frequency, duration=self.duration
+            sampling_frequency=self.sampling_frequency,
+            duration=self.duration,
+            start_time=1126259640,
         )
 
         self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
             duration=self.duration,
             sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
+            start_time=1126259640,
         )
-
-        self.prior = bilby.gw.prior.BBHPriorDict()
-        self.prior["geocent_time"] = bilby.prior.Uniform(
-            minimum=self.parameters["geocent_time"] - self.duration / 2,
-            maximum=self.parameters["geocent_time"] + self.duration / 2,
+        self.interferometers.inject_signal(
+            parameters=self.parameters, waveform_generator=self.waveform_generator
         )
 
-        self.likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            priors=self.prior.copy(),
-        )
-
-        self.phase = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            phase_marginalization=True,
-            priors=self.prior.copy(),
+        self.priors = bilby.gw.prior.BBHPriorDict()
+        self.priors["geocent_time"] = bilby.prior.Uniform(
+            minimum=self.interferometers.start_time,
+            maximum=self.interferometers.start_time + self.interferometers.duration,
         )
-        for like in [self.likelihood, self.phase]:
-            like.parameters = self.parameters.copy()
 
     def tearDown(self):
         del self.duration
@@ -605,147 +479,167 @@ class TestPhaseMarginalization(unittest.TestCase):
         del self.parameters
         del self.interferometers
         del self.waveform_generator
-        del self.prior
-        del self.likelihood
-        del self.phase
+        del self.priors
 
-    def test_phase_marginalisation(self):
-        """Test phase marginalised likelihood matches brute force version"""
-        like = []
-        phases = np.linspace(0, 2 * np.pi, 1000)
-        for phase in phases:
-            self.likelihood.parameters["phase"] = phase
-            like.append(np.exp(self.likelihood.log_likelihood_ratio()))
+    def get_likelihood(
+        self,
+        time_marginalization=False,
+        phase_marginalization=False,
+        distance_marginalization=False,
+        priors=None
+    ):
+        if priors is None:
+            priors = self.priors.copy()
+        if distance_marginalization and phase_marginalization:
+            lookup = "distance_lookup_phase.npz"
+        elif distance_marginalization:
+            lookup = "distance_lookup_no_phase.npz"
+        else:
+            lookup = None
+        like = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=self.interferometers,
+            waveform_generator=self.waveform_generator,
+            distance_marginalization=distance_marginalization,
+            phase_marginalization=phase_marginalization,
+            time_marginalization=time_marginalization,
+            distance_marginalization_lookup_table=lookup,
+            priors=priors,
+        )
+        like.parameters = self.parameters.copy()
+        if time_marginalization:
+            like.parameters["geocent_time"] = self.interferometers.start_time
+        return like
+
+    def _template(self, marginalized, non_marginalized, key, prior=None, values=None):
+        if prior is None:
+            prior = self.priors[key]
+        if values is None:
+            values = np.linspace(prior.minimum, prior.maximum, 1000)
+        prior_values = prior.prob(values)
+        ln_likes = np.empty(values.shape)
+        for ii, value in enumerate(values):
+            non_marginalized.parameters[key] = value
+            ln_likes[ii] = non_marginalized.log_likelihood_ratio()
+        like = np.exp(ln_likes - max(ln_likes))
+
+        marg_like = np.log(np.trapz(like * prior_values, values)) + max(ln_likes)
+        self.assertAlmostEqual(
+            marg_like, marginalized.log_likelihood_ratio(), delta=0.5
+        )
 
-        marg_like = np.log(np.trapz(like, phases) / (2 * np.pi))
-        self.phase.parameters = self.parameters.copy()
-        self.assertAlmostEqual(marg_like, self.phase.log_likelihood_ratio(), delta=0.5)
+    def test_distance_marginalisation(self):
+        self._template(
+            self.get_likelihood(distance_marginalization=True),
+            self.get_likelihood(),
+            key="luminosity_distance",
+        )
 
+    def test_distance_phase_marginalisation(self):
+        self._template(
+            self.get_likelihood(distance_marginalization=True, phase_marginalization=True),
+            self.get_likelihood(phase_marginalization=True),
+            key="luminosity_distance",
+        )
 
-class TestTimePhaseMarginalization(unittest.TestCase):
-    def setUp(self):
-        np.random.seed(500)
-        self.duration = 4
-        self.sampling_frequency = 2048
-        self.parameters = dict(
-            mass_1=31.0,
-            mass_2=29.0,
-            a_1=0.4,
-            a_2=0.3,
-            tilt_1=0.0,
-            tilt_2=0.0,
-            phi_12=1.7,
-            phi_jl=0.3,
-            luminosity_distance=4000.0,
-            theta_jn=0.4,
-            psi=2.659,
-            phase=1.3,
-            geocent_time=1126259642.413,
-            ra=1.375,
-            dec=-1.2108,
+    def test_distance_time_marginalisation(self):
+        self._template(
+            self.get_likelihood(distance_marginalization=True, time_marginalization=True),
+            self.get_likelihood(time_marginalization=True),
+            key="luminosity_distance",
         )
 
-        self.interferometers = bilby.gw.detector.InterferometerList(["H1"])
-        self.interferometers.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=self.sampling_frequency,
-            duration=self.duration,
-            start_time=1126259640,
+    def test_distance_phase_time_marginalisation(self):
+        """
+        Test phase marginalised likelihood matches brute force version when
+        also marginalising over time.
+        """
+        self._template(
+            self.get_likelihood(distance_marginalization=True, phase_marginalization=True, time_marginalization=True),
+            self.get_likelihood(phase_marginalization=True, time_marginalization=True),
+            key="luminosity_distance",
         )
 
-        self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-            duration=self.duration,
-            sampling_frequency=self.sampling_frequency,
-            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-            start_time=1126259640,
+    def test_phase_marginalisation(self):
+        self._template(
+            self.get_likelihood(phase_marginalization=True),
+            self.get_likelihood(),
+            key="phase",
         )
 
-        self.priors = bilby.gw.prior.BBHPriorDict()
-        self.priors["geocent_time"] = bilby.prior.Uniform(
-            minimum=self.parameters["geocent_time"] - self.duration / 2,
-            maximum=self.parameters["geocent_time"] + self.duration / 2,
+    def test_phase_distance_marginalisation(self):
+        self._template(
+            self.get_likelihood(distance_marginalization=True, phase_marginalization=True),
+            self.get_likelihood(distance_marginalization=True),
+            key="phase",
         )
 
-        self.likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            priors=self.priors.copy(),
+    def test_phase_time_marginalisation(self):
+        self._template(
+            self.get_likelihood(time_marginalization=True, phase_marginalization=True),
+            self.get_likelihood(time_marginalization=True),
+            key="phase",
         )
 
-        self.time = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            time_marginalization=True,
-            priors=self.priors.copy(),
+    def test_phase_distance_time_marginalisation(self):
+        self._template(
+            self.get_likelihood(time_marginalization=True, distance_marginalization=True, phase_marginalization=True),
+            self.get_likelihood(time_marginalization=True, distance_marginalization=True),
+            key="phase",
         )
 
-        self.phase = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            phase_marginalization=True,
-            priors=self.priors.copy(),
+    def test_time_marginalisation(self):
+        times = self.waveform_generator.time_array
+        self._template(
+            self.get_likelihood(time_marginalization=True),
+            self.get_likelihood(),
+            key="geocent_time",
+            values=times,
         )
 
-        self.time_phase = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            time_marginalization=True,
-            phase_marginalization=True,
-            priors=self.priors,
+    def test_time_distance_marginalisation(self):
+        times = self.waveform_generator.time_array
+        self._template(
+            self.get_likelihood(time_marginalization=True, distance_marginalization=True),
+            self.get_likelihood(distance_marginalization=True),
+            key="geocent_time",
+            values=times
         )
-        for like in [self.likelihood, self.time, self.phase, self.time_phase]:
-            like.parameters = self.parameters.copy()
 
-    def tearDown(self):
-        del self.duration
-        del self.sampling_frequency
-        del self.parameters
-        del self.interferometers
-        del self.waveform_generator
-        del self.priors
-        del self.likelihood
-        del self.time
-        del self.phase
-        del self.time_phase
+    def test_time_phase_marginalisation(self):
+        times = self.waveform_generator.time_array
+        self._template(
+            self.get_likelihood(time_marginalization=True, phase_marginalization=True),
+            self.get_likelihood(phase_marginalization=True),
+            key="geocent_time",
+            values=times
+        )
 
-    def test_time_marginalisation(self):
-        """
-        Test time marginalised likelihood matches brute force version when
-        also marginalising over phase.
-        """
-        like = []
-        times = np.linspace(
-            self.time_phase.priors["geocent_time"].minimum,
-            self.time_phase.priors["geocent_time"].maximum,
-            4097,
-        )[:-1]
-        for time in times:
-            self.phase.parameters["geocent_time"] = time
-            like.append(np.exp(self.phase.log_likelihood_ratio()))
-
-        marg_like = np.log(np.trapz(like, times) / self.waveform_generator.duration)
-        self.time_phase.parameters = self.parameters.copy()
-        self.time_phase.parameters["time_jitter"] = 0.0
-        self.assertAlmostEqual(
-            marg_like, self.time_phase.log_likelihood_ratio(), delta=0.5
+    def test_time_distance_phase_marginalisation(self):
+        times = self.waveform_generator.time_array
+        self._template(
+            self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True),
+            self.get_likelihood(phase_marginalization=True, distance_marginalization=True),
+            key="geocent_time",
+            values=times
         )
 
-    def test_phase_marginalisation(self):
+    def test_time_marginalisation_partial_segment(self):
         """
-        Test phase marginalised likelihood matches brute force version when
-        also marginalising over time.
+        Test time marginalised likelihood matches brute force version over
+        just part of a segment.
         """
-        like = []
-        phases = np.linspace(0, 2 * np.pi, 1000)
-        for phase in phases:
-            self.time.parameters["phase"] = phase
-            self.time.parameters["time_jitter"] = 0.0
-            like.append(np.exp(self.time.log_likelihood_ratio()))
-
-        marg_like = np.log(np.trapz(like, phases) / (2 * np.pi))
-        self.time_phase.parameters = self.parameters.copy()
-        self.time_phase.parameters["time_jitter"] = 0.0
-        self.assertAlmostEqual(
-            marg_like, self.time_phase.log_likelihood_ratio(), delta=0.5
+        priors = self.priors.copy()
+        prior = bilby.prior.Uniform(
+            minimum=self.parameters["geocent_time"] - 0.1,
+            maximum=self.parameters["geocent_time"] + 0.1,
+        )
+        priors["geocent_time"] = prior
+        self._template(
+            self.get_likelihood(time_marginalization=True, priors=priors.copy()),
+            self.get_likelihood(priors=priors.copy()),
+            key="geocent_time",
+            values=self.waveform_generator.time_array,
+            prior=prior,
         )