diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index 6da8c8363c1ea56b92ff3561a1afd3eab40f6dfe..c5d2b46bae88404318398944788497c1697fe5a2 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -53,6 +53,14 @@ class GravitationalWaveTransient(likelihood.Likelihood):
         This is done analytically using a Bessel function.
     priors: dict, optional
         If given, used in the distance and phase marginalization.
+    distance_marginalization_lookup_table: (dict, str), optional
+        If a dict, dictionary containing the lookup_table, distance_array,
+        (distance) prior_array, and reference_distance used to construct
+        the table.
+        If a string the name of a file containing these quantities.
+        The lookup table is stored after construction in either the
+        provided string or a default location:
+        '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
 
     Returns
     -------
@@ -62,8 +70,10 @@ class GravitationalWaveTransient(likelihood.Likelihood):
 
     """
 
-    def __init__(self, interferometers, waveform_generator, time_marginalization=False, distance_marginalization=False,
-                 phase_marginalization=False, priors=None):
+    def __init__(self, interferometers, waveform_generator,
+                 time_marginalization=False, distance_marginalization=False,
+                 phase_marginalization=False, priors=None,
+                 distance_marginalization_lookup_table=None):
 
         self.waveform_generator = waveform_generator
         likelihood.Likelihood.__init__(self, dict())
@@ -79,7 +89,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
             phase_marginalization=self.phase_marginalization,
             distance_marginalization=self.distance_marginalization,
             waveform_arguments=waveform_generator.waveform_arguments,
-            frequency_domain_source_model=str(waveform_generator.frequency_domain_source_model))
+            frequency_domain_source_model=str(
+                waveform_generator.frequency_domain_source_model))
 
         if self.time_marginalization:
             self._check_prior_is_set(key='geocent_time')
@@ -93,10 +104,16 @@ class GravitationalWaveTransient(likelihood.Likelihood):
             priors['phase'] = float(0)
 
         if self.distance_marginalization:
+            self._lookup_table_filename = None
             self._check_prior_is_set(key='luminosity_distance')
-            self._distance_array = np.linspace(self.priors['luminosity_distance'].minimum,
-                                               self.priors['luminosity_distance'].maximum, int(1e4))
-            self._setup_distance_marginalization()
+            self._distance_array = np.linspace(
+                self.priors['luminosity_distance'].minimum,
+                self.priors['luminosity_distance'].maximum, int(1e4))
+            self.distance_prior_array = np.array(
+                [self.priors['luminosity_distance'].prob(distance)
+                 for distance in self._distance_array])
+            self._setup_distance_marginalization(
+                distance_marginalization_lookup_table)
             priors['luminosity_distance'] = float(self._ref_dist)
 
     def __repr__(self):
@@ -292,68 +309,74 @@ class GravitationalWaveTransient(likelihood.Likelihood):
             return np.hstack((-np.logspace(3, -3, self._dist_margd_loglikelihood_array.shape[1] / 2),
                               np.logspace(-3, 10, self._dist_margd_loglikelihood_array.shape[1] / 2)))
 
-    def _setup_distance_marginalization(self):
-        self._create_lookup_table()
+    def _setup_distance_marginalization(self, lookup_table=None):
+        if isinstance(lookup_table, str) or lookup_table is None:
+            self.cached_lookup_table_filename = lookup_table
+            lookup_table = self.load_lookup_table(
+                self.cached_lookup_table_filename)
+        if isinstance(lookup_table, dict):
+            if self._test_cached_lookup_table(lookup_table):
+                self._dist_margd_loglikelihood_array = lookup_table[
+                    'lookup_table']
+            else:
+                self._create_lookup_table()
+        else:
+            self._create_lookup_table()
         self._interp_dist_margd_loglikelihood = UnsortedInterp2d(
             self._rho_mf_ref_array, self._rho_opt_ref_array,
             self._dist_margd_loglikelihood_array)
 
     @property
     def cached_lookup_table_filename(self):
-        dmin = self._distance_array[0]
-        dmax = self._distance_array[-1]
-        n = len(self._distance_array)
-        cached_lookup_table_filename = (
-            '.distance_marginalization_lookup_dmin{}_dmax{}_n{}_v1.npy'
-            .format(dmin, dmax, n))
-        return cached_lookup_table_filename
-
-    @property
-    def cached_lookup_table(self):
-        """ Reads in the cached lookup table
-
-        Returns
-        -------
-        cached_lookup_table: np.ndarray
-            Columns are _distance_array, distance_prior_array,
-            dist_marged_log_l_tc_array. This is only returned if the file
-            exists and the first two columns match the equivalent values
-            stored on disk.
-
-        """
-
-        if os.path.exists(self.cached_lookup_table_filename):
-            loaded_file = np.load(self.cached_lookup_table_filename)
+        if self._lookup_table_filename is None:
+            dmin = self._distance_array[0]
+            dmax = self._distance_array[-1]
+            n = len(self._distance_array)
+            self._lookup_table_filename = (
+                '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
+                .format(dmin, dmax, n))
+        return self._lookup_table_filename
+
+    @cached_lookup_table_filename.setter
+    def cached_lookup_table_filename(self, filename):
+        if isinstance(filename, str):
+            if filename[-4:] != '.npz':
+                filename += '.npz'
+        self._lookup_table_filename = filename
+
+    def load_lookup_table(self, filename):
+        if os.path.exists(filename):
+            loaded_file = dict(np.load(filename))
             if self._test_cached_lookup_table(loaded_file):
+                logger.info('Loaded distance marginalisation lookup table from '
+                            '{}.'.format(filename))
                 return loaded_file
+            else:
+                logger.info('Loaded distance marginalisation lookup table does '
+                            'not match prior')
+                return None
+        elif isinstance(filename, str):
+            logger.info('Distance marginalisation file {} does not '
+                        'exist'.format(filename))
+            return None
         else:
             return None
 
-    @cached_lookup_table.setter
-    def cached_lookup_table(self, lookup_table):
-        np.save(self.cached_lookup_table_filename, lookup_table)
+    def cache_lookup_table(self):
+        np.savez(self.cached_lookup_table_filename,
+                 distance_array=self._distance_array,
+                 prior_array=self.distance_prior_array,
+                 lookup_table=self._dist_margd_loglikelihood_array,
+                 reference_distance=self._ref_dist)
 
-    def _test_cached_lookup_table(self, lookup_table):
-        cond_a = np.all(self._distance_array == lookup_table[0])
-        cond_b = np.all(self.distance_prior_array == lookup_table[1])
-        if cond_a and cond_b:
-            return True
+    def _test_cached_lookup_table(self, loaded_file):
+        cond_a = np.all(self._distance_array == loaded_file['distance_array'])
+        cond_b = np.all(self.distance_prior_array == loaded_file['prior_array'])
+        cond_c = self._ref_dist == loaded_file['reference_distance']
+        return all([cond_a, cond_b, cond_c])
 
     def _create_lookup_table(self):
         """ Make the lookup table """
-
-        self.distance_prior_array = np.array(
-            [self.priors['luminosity_distance'].prob(distance)
-             for distance in self._distance_array])
-
-        # Check if a cached lookup table exists in file
-        cached_lookup_table = self.cached_lookup_table
-        if cached_lookup_table is not None:
-            self._dist_margd_loglikelihood_array = cached_lookup_table[-1]
-            logger.info("Using the cached lookup table {}"
-                        .format(os.path.abspath(self.cached_lookup_table_filename)))
-            return
-
         logger.info('Building lookup table for distance marginalisation.')
 
         self._dist_margd_loglikelihood_array = np.zeros((400, 800))
@@ -370,10 +393,7 @@ class GravitationalWaveTransient(likelihood.Likelihood):
         log_norm = logsumexp(0. / self._distance_array - 0. / self._distance_array ** 2.,
                              b=self.distance_prior_array * self._delta_distance)
         self._dist_margd_loglikelihood_array -= log_norm
-        self.cached_lookup_table = np.array([
-            self._distance_array,
-            self.distance_prior_array,
-            self._dist_margd_loglikelihood_array])
+        self.cache_lookup_table()
 
     def _setup_phase_marginalization(self):
         self._bessel_function_interped = interp1d(
@@ -513,16 +533,26 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
         array, or the array itself.
     priors: dict, bilby.prior.PriorDict
         A dictionary of priors containing at least the geocent_time prior
+    distance_marginalization_lookup_table: (dict, str), optional
+        If a dict, dictionary containing the lookup_table, distance_array,
+        (distance) prior_array, and reference_distance used to construct
+        the table.
+        If a string the name of a file containing these quantities.
+        The lookup table is stored after construction in either the
+        provided string or a default location:
+        '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
 
     """
     def __init__(self, interferometers, waveform_generator, priors,
                  weights=None, linear_matrix=None, quadratic_matrix=None,
-                 distance_marginalization=False, phase_marginalization=False):
+                 distance_marginalization=False, phase_marginalization=False,
+                 distance_marginalization_lookup_table=None):
         GravitationalWaveTransient.__init__(
             self, interferometers=interferometers,
             waveform_generator=waveform_generator, priors=priors,
             distance_marginalization=distance_marginalization,
-            phase_marginalization=phase_marginalization)
+            phase_marginalization=phase_marginalization,
+            distance_marginalization_lookup_table=distance_marginalization_lookup_table)
 
         self.time_samples = np.arange(
             self.priors['geocent_time'].minimum - 0.045,