Skip to content
Snippets Groups Projects

Resolve "Introduce conditional prior sets"

Merged Moritz Huebner requested to merge 270-introduce-correlated-prior-sets into master
1 file
+ 36
17
Compare changes
  • Side-by-side
  • Inline
+ 36
17
@@ -501,9 +501,10 @@ class ConditionalPriorDict(PriorDict):
self._rescale_indexes = []
self._least_recently_rescaled_keys = []
super(ConditionalPriorDict, self).__init__(dictionary=dictionary, filename=filename)
self.resolved = False
self._resolve_conditions()
def _resolve_conditions(self):
def _resolve_conditions(self, disable_log=False):
""" Resolves how variables depend on each other and automatically sorts them into the right order """
conditioned_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')]
self._unconditional_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')]
@@ -518,9 +519,12 @@ class ConditionalPriorDict(PriorDict):
self._sorted_keys = self._unconditional_keys.copy()
self._sorted_keys.extend(self.conditional_keys)
self.resolved = True
if len(conditioned_keys_unsorted) != 0:
raise IllegalConditionsException('This set contains unresolvable conditions')
self.resolved = False
if not disable_log:
logger.warning('This set contains unresolvable conditions')
def _check_conditions_resolved(self, key, sampled_keys):
""" Checks if all required variables have already been sampled so we can sample this key """
@@ -532,14 +536,26 @@ class ConditionalPriorDict(PriorDict):
def sample_subset(self, keys=iter([]), size=None):
self.convert_floats_to_delta_functions()
sorted_keys = self.get_subset_sorted_keys(keys)
subset_dict = ConditionalPriorDict({key: self[key] for key in keys})
if not subset_dict.resolved:
raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
res = dict()
for key in sorted_keys:
res[key] = self[key].sample(size=size, **self._get_required_variables(key))
for key in subset_dict.sorted_keys:
res[key] = subset_dict[key].sample(size=size, **subset_dict.get_required_variables(key))
return res
def _get_required_variables(self, key):
""" Returns the required variables to sample a given conditional key."""
def get_required_variables(self, key):
""" Returns the required variables to sample a given conditional key.
Parameters
----------
key : str
Name of the key that we want to know the required variables for
Returns
----------
dict: key/value pairs of the required variables
"""
return {k: self[k].least_recently_sampled for k in getattr(self[key], 'required_variables', [])}
def prob(self, sample, **kwargs):
@@ -557,9 +573,10 @@ class ConditionalPriorDict(PriorDict):
float: Joint probability of all individual sample probabilities
"""
self._check_resolved()
for key, value in sample.items():
self[key].least_recently_sampled = value
res = [self[key].prob(sample[key], **self._get_required_variables(key)) for key in sample]
res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
return np.product(res, **kwargs)
def ln_prob(self, sample, axis=None):
@@ -577,9 +594,10 @@ class ConditionalPriorDict(PriorDict):
float: Joint log probability of all the individual sample probabilities
"""
self._check_resolved()
for key, value in sample.items():
self[key].least_recently_sampled = value
res = [self[key].ln_prob(sample[key], **self._get_required_variables(key)) for key in sample]
res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
return np.sum(res, axis=axis)
def rescale(self, keys, theta):
@@ -596,8 +614,9 @@ class ConditionalPriorDict(PriorDict):
-------
list: List of floats containing the rescaled sample
"""
result = OrderedDict(dict())
self._check_resolved()
self._update_rescale_keys(keys)
result = OrderedDict(dict())
for key, index in zip(self._rescale_keys, self._rescale_indexes):
required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])}
result[key] = self[key].rescale(theta[index], **required_variables)
@@ -614,10 +633,9 @@ class ConditionalPriorDict(PriorDict):
self._rescale_keys = np.append(unconditional_keys, conditional_keys)
self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs)
def get_subset_sorted_keys(self, subset_keys):
subset_unconditional_keys = [key for key in self.unconditional_keys if key in subset_keys]
subset_conditional_keys = [key for key in self.conditional_keys if key in subset_keys]
return subset_unconditional_keys + subset_conditional_keys
def _check_resolved(self):
if not self.resolved:
raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
@property
def conditional_keys(self):
@@ -633,7 +651,7 @@ class ConditionalPriorDict(PriorDict):
def __setitem__(self, key, value):
super(ConditionalPriorDict, self).__setitem__(key, value)
self._resolve_conditions()
self._resolve_conditions(disable_log=True)
def create_default_prior(name, default_priors_file=None):
@@ -3465,7 +3483,7 @@ def conditional_prior_factory(prior_class):
return ConditionalPrior
ConditionalPrior = conditional_prior_factory(Prior) # Only for testing purposes
ConditionalBasePrior = conditional_prior_factory(Prior) # Only for testing purposes
ConditionalUniform = conditional_prior_factory(Uniform)
ConditionalDeltaFunction = conditional_prior_factory(DeltaFunction)
ConditionalPowerLaw = conditional_prior_factory(PowerLaw)
@@ -3484,7 +3502,8 @@ ConditionalLogistic = conditional_prior_factory(Logistic)
ConditionalCauchy = conditional_prior_factory(Cauchy)
ConditionalGamma = conditional_prior_factory(Gamma)
ConditionalChiSquared = conditional_prior_factory(ChiSquared)
ConditionalConditionalInterped = conditional_prior_factory(Interped)
ConditionalFermiDirac = conditional_prior_factory(FermiDirac)
ConditionalInterped = conditional_prior_factory(Interped)
class PriorException(Exception):
Loading