From 47e69bec24d25d2454604c7b32c8ba6e5c10bfa6 Mon Sep 17 00:00:00 2001
From: Sylvia Biscoveanu <sylvia.biscoveanu@ligo.org>
Date: Thu, 18 Jul 2019 19:09:40 -0500
Subject: [PATCH] Use the `PriorSet` from the first step of PE as the
 `sampling_prior` for hyper-pe

---
 bilby/hyper/likelihood.py | 22 ++++++++++++++++++----
 bilby/hyper/model.py      |  2 +-
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/bilby/hyper/likelihood.py b/bilby/hyper/likelihood.py
index 3325bca17..6b457ea9c 100644
--- a/bilby/hyper/likelihood.py
+++ b/bilby/hyper/likelihood.py
@@ -6,6 +6,7 @@ import numpy as np
 
 from ..core.likelihood import Likelihood
 from .model import Model
+from ..core.prior import PriorDict
 
 
 class HyperparameterLikelihood(Likelihood):
@@ -32,12 +33,17 @@ class HyperparameterLikelihood(Likelihood):
 
     """
 
-    def __init__(self, posteriors, hyper_prior, sampling_prior,
+    def __init__(self, posteriors, hyper_prior, sampling_prior=None,
                  log_evidences=None, max_samples=1e100):
         if not isinstance(hyper_prior, Model):
             hyper_prior = Model([hyper_prior])
-        if not isinstance(sampling_prior, Model):
-            sampling_prior = Model([sampling_prior])
+        if sampling_prior is None:
+            if ('log_prior' not in posteriors[0].keys()) and ('prior' not in posteriors[0].keys()):
+                raise ValueError('Missing both sampling prior function and prior or log_prior '
+                                 'column in posterior dictionary. Must pass one or the other.')
+        else:
+            if not (isinstance(sampling_prior, Model) or isinstance(sampling_prior, PriorDict)):
+                sampling_prior = Model([sampling_prior])
         if log_evidences is not None:
             self.evidence_factor = np.sum(log_evidences)
         else:
@@ -57,7 +63,7 @@ class HyperparameterLikelihood(Likelihood):
     def log_likelihood_ratio(self):
         self.hyper_prior.parameters.update(self.parameters)
         log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data) /
-                       self.sampling_prior.prob(self.data), axis=-1)))
+                       self.data['prior'], axis=-1)))
         log_l += self.samples_factor
         return np.nan_to_num(log_l)
 
@@ -87,10 +93,18 @@ class HyperparameterLikelihood(Likelihood):
         for posterior in self.posteriors:
             self.max_samples = min(len(posterior), self.max_samples)
         data = {key: [] for key in self.posteriors[0]}
+        if 'log_prior' in data.keys():
+            data.pop('log_prior')
+        if 'prior' not in data.keys():
+            data['prior'] = []
         logging.debug('Downsampling to {} samples per posterior.'.format(
             self.max_samples))
         for posterior in self.posteriors:
             temp = posterior.sample(self.max_samples)
+            if self.sampling_prior is not None:
+                temp['prior'] = self.sampling_prior.prob(temp, axis=0)
+            elif 'log_prior' in temp.keys():
+                temp['prior'] = np.exp(temp['log_prior'])
             for key in data:
                 data[key].append(temp[key])
         for key in data:
diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py
index 82b0d8fa5..e5c595349 100644
--- a/bilby/hyper/model.py
+++ b/bilby/hyper/model.py
@@ -23,7 +23,7 @@ class Model(object):
             for key in param_keys:
                 self.parameters[key] = None
 
-    def prob(self, data):
+    def prob(self, data, **kwargs):
         probability = 1.0
         for ii, function in enumerate(self.models):
             probability *= function(data, **self._get_function_parameters(function))
-- 
GitLab