Skip to content
Snippets Groups Projects

Resolve "Introduce conditional prior sets"

Merged Moritz Huebner requested to merge 270-introduce-correlated-prior-sets into master
Compare and
8 files
+ 832
37
Compare changes
  • Side-by-side
  • Inline
Files
8
@@ -5,7 +5,7 @@ import numpy as np
from pandas import DataFrame
from ..utils import logger, command_line_args, Counter
from ..prior import Prior, PriorDict, DeltaFunction, Constraint
from ..prior import Prior, PriorDict, ConditionalPriorDict, DeltaFunction, Constraint
from ..result import Result, read_in_result
@@ -251,13 +251,19 @@ class Sampler(object):
AttributeError
prior can't be sampled.
"""
for key in self.priors:
if isinstance(self.priors[key], Constraint):
continue
if isinstance(self.priors, ConditionalPriorDict):
try:
self.likelihood.parameters[key] = self.priors[key].sample()
self.likelihood.parameters = self.priors.sample()
except AttributeError as e:
logger.warning('Cannot sample from {}, {}'.format(key, e))
logger.warning('Cannot sample from prior, {}'.format(e))
else:
for key in self.priors:
if isinstance(self.priors[key], Constraint):
continue
try:
self.likelihood.parameters[key] = self.priors[key].sample()
except AttributeError as e:
logger.warning('Cannot sample from {}, {}'.format(key, e))
def _verify_parameters(self):
""" Evaluate a set of parameters drawn from the prior
@@ -276,9 +282,13 @@ class Sampler(object):
"Your sampling set contains redundant parameters.")
self._check_if_priors_can_be_sampled()
try:
if isinstance(self.priors, ConditionalPriorDict):
theta = self.priors.sample()
theta = [theta[key] for key in self._search_parameter_keys]
else:
theta = [self.priors[key].sample()
for key in self._search_parameter_keys]
try:
self.log_likelihood(theta)
except TypeError as e:
raise TypeError(
@@ -298,8 +308,12 @@ class Sampler(object):
t1 = datetime.datetime.now()
for _ in range(n_evaluations):
theta = [self.priors[key].sample()
for key in self._search_parameter_keys]
if isinstance(self.priors, ConditionalPriorDict):
theta = self.priors.sample()
theta = list(theta.values())
else:
theta = [self.priors[key].sample()
for key in self._search_parameter_keys]
self.log_likelihood(theta)
total_time = (datetime.datetime.now() - t1).total_seconds()
self._log_likelihood_eval_time = total_time / n_evaluations
Loading