Skip to content
Snippets Groups Projects
Commit ce7c7491 authored by Sylvia Biscoveanu's avatar Sylvia Biscoveanu Committed by Moritz
Browse files

Replace all calls to sample individual keys with call to sample from correlated prior dict

parent 1055de12
No related branches found
No related tags found
1 merge request!332Resolve "Introduce conditional prior sets"
......@@ -5,7 +5,7 @@ import numpy as np
from pandas import DataFrame
from ..utils import logger, command_line_args, Counter
from ..prior import Prior, PriorDict, DeltaFunction, Constraint
from ..prior import Prior, PriorDict, CorrelatedPriorDict, DeltaFunction, Constraint
from ..result import Result, read_in_result
......@@ -251,13 +251,19 @@ class Sampler(object):
AttributeError
prior can't be sampled.
"""
for key in self.priors:
if isinstance(self.priors[key], Constraint):
continue
if isinstance(self.priors, CorrelatedPriorDict):
try:
self.likelihood.parameters[key] = self.priors[key].sample()
self.likelihood.parameters = self.priors.sample()
except AttributeError as e:
logger.warning('Cannot sample from {}, {}'.format(key, e))
logger.warning('Cannot sample from prior, {}'.format(e))
else:
for key in self.priors:
if isinstance(self.priors[key], Constraint):
continue
try:
self.likelihood.parameters[key] = self.priors[key].sample()
except AttributeError as e:
logger.warning('Cannot sample from {}, {}'.format(key, e))
def _verify_parameters(self):
""" Evaluate a set of parameters drawn from the prior
......@@ -276,9 +282,12 @@ class Sampler(object):
"Your sampling set contains redundant parameters.")
self._check_if_priors_can_be_sampled()
try:
if isinstance(self.priors, CorrelatedPriorDict):
theta = self.priors.sample()
else:
theta = [self.priors[key].sample()
for key in self._search_parameter_keys]
try:
self.log_likelihood(theta)
except TypeError as e:
raise TypeError(
......@@ -298,8 +307,11 @@ class Sampler(object):
t1 = datetime.datetime.now()
for _ in range(n_evaluations):
theta = [self.priors[key].sample()
for key in self._search_parameter_keys]
if isinstance(self.priors, CorrelatedPriorDict):
theta = self.priors.sample()
else:
theta = [self.priors[key].sample()
for key in self._search_parameter_keys]
self.log_likelihood(theta)
total_time = (datetime.datetime.now() - t1).total_seconds()
self._log_likelihood_eval_time = total_time / n_evaluations
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment