From 47e69bec24d25d2454604c7b32c8ba6e5c10bfa6 Mon Sep 17 00:00:00 2001 From: Sylvia Biscoveanu <sylvia.biscoveanu@ligo.org> Date: Thu, 18 Jul 2019 19:09:40 -0500 Subject: [PATCH] Use the `PriorSet` from the first step of PE as the `sampling_prior` for hyper-pe --- bilby/hyper/likelihood.py | 22 ++++++++++++++++++---- bilby/hyper/model.py | 2 +- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/bilby/hyper/likelihood.py b/bilby/hyper/likelihood.py index 3325bca17..6b457ea9c 100644 --- a/bilby/hyper/likelihood.py +++ b/bilby/hyper/likelihood.py @@ -6,6 +6,7 @@ import numpy as np from ..core.likelihood import Likelihood from .model import Model +from ..core.prior import PriorDict class HyperparameterLikelihood(Likelihood): @@ -32,12 +33,17 @@ class HyperparameterLikelihood(Likelihood): """ - def __init__(self, posteriors, hyper_prior, sampling_prior, + def __init__(self, posteriors, hyper_prior, sampling_prior=None, log_evidences=None, max_samples=1e100): if not isinstance(hyper_prior, Model): hyper_prior = Model([hyper_prior]) - if not isinstance(sampling_prior, Model): - sampling_prior = Model([sampling_prior]) + if sampling_prior is None: + if ('log_prior' not in posteriors[0].keys()) and ('prior' not in posteriors[0].keys()): + raise ValueError('Missing both sampling prior function and prior or log_prior ' + 'column in posterior dictionary. Must pass one or the other.') + else: + if not (isinstance(sampling_prior, Model) or isinstance(sampling_prior, PriorDict)): + sampling_prior = Model([sampling_prior]) if log_evidences is not None: self.evidence_factor = np.sum(log_evidences) else: @@ -57,7 +63,7 @@ class HyperparameterLikelihood(Likelihood): def log_likelihood_ratio(self): self.hyper_prior.parameters.update(self.parameters) log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data) / - self.sampling_prior.prob(self.data), axis=-1))) + self.data['prior'], axis=-1))) log_l += self.samples_factor return np.nan_to_num(log_l) @@ -87,10 +93,18 @@ class HyperparameterLikelihood(Likelihood): for posterior in self.posteriors: self.max_samples = min(len(posterior), self.max_samples) data = {key: [] for key in self.posteriors[0]} + if 'log_prior' in data.keys(): + data.pop('log_prior') + if 'prior' not in data.keys(): + data['prior'] = [] logging.debug('Downsampling to {} samples per posterior.'.format( self.max_samples)) for posterior in self.posteriors: temp = posterior.sample(self.max_samples) + if self.sampling_prior is not None: + temp['prior'] = self.sampling_prior.prob(temp, axis=0) + elif 'log_prior' in temp.keys(): + temp['prior'] = np.exp(temp['log_prior']) for key in data: data[key].append(temp[key]) for key in data: diff --git a/bilby/hyper/model.py b/bilby/hyper/model.py index 82b0d8fa5..e5c595349 100644 --- a/bilby/hyper/model.py +++ b/bilby/hyper/model.py @@ -23,7 +23,7 @@ class Model(object): for key in param_keys: self.parameters[key] = None - def prob(self, data): + def prob(self, data, **kwargs): probability = 1.0 for ii, function in enumerate(self.models): probability *= function(data, **self._get_function_parameters(function)) -- GitLab