Skip to content
Snippets Groups Projects
Commit de6ab27a authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner
Browse files

Improvements to the core prior class

Cherry picked results from !905:
- Add ignore errstate to avoid warning messages
- Add helper properties for the width, non_fixed_keys, fixed_keys, and constraint_keys
- Add constraints to the ConditionalPrior
parent fef914d0
No related branches found
No related tags found
1 merge request!944Improvements to the core prior class
......@@ -148,8 +148,11 @@ class PowerLaw(Prior):
normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) -
self.minimum ** (1 + self.alpha))
return (self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising)) + np.log(
1. * self.is_in_prior_range(val))
with np.errstate(divide='ignore', invalid='ignore'):
ln_in_range = np.log(1. * self.is_in_prior_range(val))
ln_p = self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising)
return ln_p + ln_in_range
def cdf(self, val):
if self.alpha == -1:
......
......@@ -185,7 +185,8 @@ class Prior(object):
np.nan
"""
return np.log(self.prob(val))
with np.errstate(divide='ignore'):
return np.log(self.prob(val))
def is_in_prior_range(self, val):
"""Returns True if val is in the prior boundaries, zero otherwise
......@@ -313,6 +314,10 @@ class Prior(object):
def maximum(self, maximum):
self._maximum = maximum
@property
def width(self):
return self.maximum - self.minimum
def get_instantiation_dict(self):
subclass_args = infer_args_from_method(self.__init__)
dict_with_properties = get_dict_with_properties(self)
......
......@@ -367,6 +367,28 @@ class PriorDict(dict):
logger.debug('{} not a known prior.'.format(key))
return samples
@property
def non_fixed_keys(self):
keys = self.keys()
keys = [k for k in keys if isinstance(self[k], Prior)]
keys = [k for k in keys if self[k].is_fixed is False]
keys = [k for k in keys if k not in self.constraint_keys]
return keys
@property
def fixed_keys(self):
return [
k for k, p in self.items()
if (p.is_fixed and k not in self.constraint_keys)
]
@property
def constraint_keys(self):
return [
k for k, p in self.items()
if isinstance(p, Constraint)
]
def sample_subset_constrained(self, keys=iter([]), size=None):
if size is None or size == 1:
while True:
......@@ -432,6 +454,9 @@ class PriorDict(dict):
prob = np.product([self[key].prob(sample[key])
for key in sample], **kwargs)
return self.check_prob(sample, prob)
def check_prob(self, sample, prob):
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
if np.all(prob == 0.):
return prob
......@@ -465,7 +490,9 @@ class PriorDict(dict):
"""
ln_prob = np.sum([self[key].ln_prob(sample[key])
for key in sample], axis=axis)
return self.check_ln_prob(sample, ln_prob)
def check_ln_prob(self, sample, ln_prob):
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
if np.all(np.isinf(ln_prob)):
return ln_prob
......@@ -648,7 +675,8 @@ class ConditionalPriorDict(PriorDict):
for key, value in sample.items():
self[key].least_recently_sampled = value
res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
return np.product(res, **kwargs)
prob = np.product(res, **kwargs)
return self.check_prob(sample, prob)
def ln_prob(self, sample, axis=None):
"""
......@@ -669,7 +697,8 @@ class ConditionalPriorDict(PriorDict):
for key, value in sample.items():
self[key].least_recently_sampled = value
res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
return np.sum(res, axis=axis)
ln_prob = np.sum(res, axis=axis)
return self.check_ln_prob(sample, ln_prob)
def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment