Commit 47e69bec authored by Sylvia Biscoveanu, committed by Colm Talbot

Use the `PriorSet` from the first step of PE as the `sampling_prior` for hyper-pe

parent 0b1aa313
@@ -6,6 +6,7 @@ import numpy as np
 
 from ..core.likelihood import Likelihood
 from .model import Model
+from ..core.prior import PriorDict
 
 
 class HyperparameterLikelihood(Likelihood):
@@ -32,12 +33,17 @@ class HyperparameterLikelihood(Likelihood):
     """
 
-    def __init__(self, posteriors, hyper_prior, sampling_prior,
+    def __init__(self, posteriors, hyper_prior, sampling_prior=None,
                  log_evidences=None, max_samples=1e100):
         if not isinstance(hyper_prior, Model):
             hyper_prior = Model([hyper_prior])
-        if not isinstance(sampling_prior, Model):
-            sampling_prior = Model([sampling_prior])
+        if sampling_prior is None:
+            if ('log_prior' not in posteriors[0].keys()) and ('prior' not in posteriors[0].keys()):
+                raise ValueError('Missing both sampling prior function and prior or log_prior '
+                                 'column in posterior dictionary. Must pass one or the other.')
+        else:
+            if not (isinstance(sampling_prior, Model) or isinstance(sampling_prior, PriorDict)):
+                sampling_prior = Model([sampling_prior])
         if log_evidences is not None:
             self.evidence_factor = np.sum(log_evidences)
         else:
@@ -57,7 +63,7 @@
     def log_likelihood_ratio(self):
         self.hyper_prior.parameters.update(self.parameters)
         log_l = np.sum(np.log(np.sum(self.hyper_prior.prob(self.data) /
-                                     self.sampling_prior.prob(self.data), axis=-1)))
+                                     self.data['prior'], axis=-1)))
         log_l += self.samples_factor
         return np.nan_to_num(log_l)
@@ -87,10 +93,18 @@
         for posterior in self.posteriors:
             self.max_samples = min(len(posterior), self.max_samples)
         data = {key: [] for key in self.posteriors[0]}
+        if 'log_prior' in data.keys():
+            data.pop('log_prior')
+        if 'prior' not in data.keys():
+            data['prior'] = []
         logging.debug('Downsampling to {} samples per posterior.'.format(
             self.max_samples))
         for posterior in self.posteriors:
             temp = posterior.sample(self.max_samples)
+            if self.sampling_prior is not None:
+                temp['prior'] = self.sampling_prior.prob(temp, axis=0)
+            elif 'log_prior' in temp.keys():
+                temp['prior'] = np.exp(temp['log_prior'])
             for key in data:
                 data[key].append(temp[key])
         for key in data:
@@ -23,7 +23,7 @@ class Model(object):
         for key in param_keys:
             self.parameters[key] = None
 
-    def prob(self, data):
+    def prob(self, data, **kwargs):
         probability = 1.0
         for ii, function in enumerate(self.models):
             probability *= function(data, **self._get_function_parameters(function))
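Taken together, these hunks let `HyperparameterLikelihood` accept the sampling prior in three forms: a `Model` or `PriorDict` object, a per-sample `prior` column, or a per-sample `log_prior` column. A minimal sketch of the first two, assuming pandas-DataFrame posteriors and the top-level `bilby.core.prior` / `bilby.hyper.likelihood` paths these relative imports suggest; the Gaussian population model, column names, and fake posteriors are illustrative, not part of the commit:

import numpy as np
import pandas as pd

from bilby.core.prior import PriorDict, Uniform
from bilby.hyper.likelihood import HyperparameterLikelihood


def gaussian_population(data, mu, sigma):
    """Illustrative hyper-prior: Gaussian population in `mass`."""
    return np.exp(-(data['mass'] - mu)**2 / (2 * sigma**2)) / (sigma * np.sqrt(2 * np.pi))


# Fake single-event posteriors standing in for the first step of PE.
posteriors = [pd.DataFrame({'mass': np.random.uniform(5, 50, 1000)})
              for _ in range(10)]

# Option 1: pass the PriorDict used in the first sampling step directly;
# resample_posteriors will evaluate it to fill the 'prior' column.
sampling_prior = PriorDict({'mass': Uniform(5, 50)})
likelihood = HyperparameterLikelihood(
    posteriors, hyper_prior=gaussian_population, sampling_prior=sampling_prior)

# Option 2: omit sampling_prior and store a 'log_prior' (or 'prior')
# column in each posterior instead; here log of the uniform density 1/45.
for posterior in posteriors:
    posterior['log_prior'] = -np.log(45.0)
likelihood = HyperparameterLikelihood(posteriors, hyper_prior=gaussian_population)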