diff --git a/bilby/hyper/likelihood.py b/bilby/hyper/likelihood.py
index 3325bca17c82e5d2629f52332db9b474e4d2025d..6b457ea9c2180d1c3026098859ddab9d33340ecd 100644
--- a/bilby/hyper/likelihood.py
+++ b/bilby/hyper/likelihood.py
@@ -6,6 +6,7 @@ import numpy as np
 
 from ..core.likelihood import Likelihood
 from .model import Model
+from ..core.prior import PriorDict
 
 
 class HyperparameterLikelihood(Likelihood):
@@ -32,12 +33,17 @@ class HyperparameterLikelihood(Likelihood):
 
     """
 
-    def __init__(self, posteriors, hyper_prior, sampling_prior,
+    def __init__(self, posteriors, hyper_prior, sampling_prior=None,
                  log_evidences=None, max_samples=1e100):
         if not isinstance(hyper_prior, Model):
             hyper_prior = Model([hyper_prior])
-        if not isinstance(sampling_prior, Model):
-            sampling_prior = Model([sampling_prior])
+        if sampling_prior is None:
+            if ('log_prior' not in posteriors[0].keys()) and ('prior' not in posteriors[0].keys()):
+                raise ValueError('Missing both sampling prior function and prior or log_prior '
+                                 'column in posterior dictionary. Must pass one or the other.')
+        else:
+            if not (isinstance(sampling_prior, Model) or isinstance(sampling_prior, PriorDict)):
+                sampling_prior = Model([sampling_prior])
         if log_evidences is not None:
             self.evidence_factor = np.sum(log_evidences)
         else:
@@ -57,7 +63,7 @@ class HyperparameterLikelihood(Likelihood):
     def log_likelihood_ratio(self):
         self.hyper_prior.parameters.update(self.parameters)
         log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data) /
-                       self.sampling_prior.prob(self.data), axis=-1)))
+                       self.data['prior'], axis=-1)))  # NOTE(review): presumably set per sample while resampling
         log_l += self.samples_factor
         return np.nan_to_num(log_l)
 
@@ -87,10 +93,18 @@ class HyperparameterLikelihood(Likelihood):
         for posterior in self.posteriors:
             self.max_samples = min(len(posterior), self.max_samples)
         data = {key: [] for key in self.posteriors[0]}
+        if 'log_prior' in data.keys():
+            data.pop('log_prior')
+        if 'prior' not in data.keys():
+            data['prior'] = []
         logging.debug('Downsampling to {} samples per posterior.'.format(
             self.max_samples))
         for posterior in self.posteriors:
             temp = posterior.sample(self.max_samples)
+            if self.sampling_prior is not None:
+                temp['prior'] = self.sampling_prior.prob(temp, axis=0)
+            elif 'log_prior' in temp.keys():
+                temp['prior'] = np.exp(temp['log_prior'])
             for key in data:
                 data[key].append(temp[key])
         for key in data:
diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py
index 82b0d8fa508617073e70dbe37a1abc91d369eaa4..e5c595349712807fd4dcd9af91627d3a5de835c3 100644
--- a/bilby/hyper/model.py
+++ b/bilby/hyper/model.py
@@ -23,7 +23,7 @@ class Model(object):
             for key in param_keys:
                 self.parameters[key] = None
 
-    def prob(self, data):
+    def prob(self, data, **kwargs):
         probability = 1.0
         for ii, function in enumerate(self.models):
             probability *= function(data, **self._get_function_parameters(function))