diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index 72966d35f7380973beca2276fe6ca7f7ad9e2eab..43835e9357da88abe991394b1f3bff4e6e97f632 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -551,13 +551,14 @@ class GravitationalWaveTransient(likelihood.Likelihood):
     def load_lookup_table(self, filename):
         if os.path.exists(filename):
             loaded_file = dict(np.load(filename))
-            if self._test_cached_lookup_table(loaded_file):
+            match, failure = self._test_cached_lookup_table(loaded_file)
+            if match:
                 logger.info('Loaded distance marginalisation lookup table from '
                             '{}.'.format(filename))
                 return loaded_file
             else:
                 logger.info('Loaded distance marginalisation lookup table does '
-                            'not match prior')
+                            'not match for {}.'.format(failure))
                 return None
         elif isinstance(filename, str):
             logger.info('Distance marginalisation file {} does not '
@@ -571,13 +572,22 @@ class GravitationalWaveTransient(likelihood.Likelihood):
                  distance_array=self._distance_array,
                  prior_array=self.distance_prior_array,
                  lookup_table=self._dist_margd_loglikelihood_array,
-                 reference_distance=self._ref_dist)
+                 reference_distance=self._ref_dist,
+                 phase_marginalization=self.phase_marginalization)
 
     def _test_cached_lookup_table(self, loaded_file):
-        cond_a = np.all(self._distance_array == loaded_file['distance_array'])
-        cond_b = np.all(self.distance_prior_array == loaded_file['prior_array'])
-        cond_c = self._ref_dist == loaded_file['reference_distance']
-        return all([cond_a, cond_b, cond_c])
+        pairs = dict(
+            distance_array=self._distance_array,
+            prior_array=self.distance_prior_array,
+            reference_distance=self._ref_dist,
+            phase_marginalization=self.phase_marginalization)
+        for key in pairs:
+            if key not in loaded_file:
+                return False, key
+            elif not np.array_equal(np.atleast_1d(loaded_file[key]),
+                                    np.atleast_1d(pairs[key])):
+                return False, key
+        return True, None
 
     def _create_lookup_table(self):
         """ Make the lookup table """