Skip to content
Snippets Groups Projects
Commit 4e3e6892 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'core-prior-improvements' into 'master'

Improvements to the core prior class

See merge request !944
parents 3fbc60da de6ab27a
No related branches found
No related tags found
1 merge request!944Improvements to the core prior class
Pipeline #218629 passed with warnings
......@@ -148,8 +148,11 @@ class PowerLaw(Prior):
normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha) -
self.minimum ** (1 + self.alpha))
return (self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising)) + np.log(
1. * self.is_in_prior_range(val))
with np.errstate(divide='ignore', invalid='ignore'):
ln_in_range = np.log(1. * self.is_in_prior_range(val))
ln_p = self.alpha * np.nan_to_num(np.log(val)) + np.log(normalising)
return ln_p + ln_in_range
def cdf(self, val):
if self.alpha == -1:
......
......@@ -185,7 +185,8 @@ class Prior(object):
np.nan
"""
return np.log(self.prob(val))
with np.errstate(divide='ignore'):
return np.log(self.prob(val))
def is_in_prior_range(self, val):
"""Returns True if val is in the prior boundaries, zero otherwise
......@@ -313,6 +314,10 @@ class Prior(object):
def maximum(self, maximum):
self._maximum = maximum
@property
def width(self):
return self.maximum - self.minimum
def get_instantiation_dict(self):
subclass_args = infer_args_from_method(self.__init__)
dict_with_properties = get_dict_with_properties(self)
......
......@@ -377,6 +377,28 @@ class PriorDict(dict):
logger.debug('{} not a known prior.'.format(key))
return samples
@property
def non_fixed_keys(self):
keys = self.keys()
keys = [k for k in keys if isinstance(self[k], Prior)]
keys = [k for k in keys if self[k].is_fixed is False]
keys = [k for k in keys if k not in self.constraint_keys]
return keys
@property
def fixed_keys(self):
return [
k for k, p in self.items()
if (p.is_fixed and k not in self.constraint_keys)
]
@property
def constraint_keys(self):
return [
k for k, p in self.items()
if isinstance(p, Constraint)
]
def sample_subset_constrained(self, keys=iter([]), size=None):
if size is None or size == 1:
while True:
......@@ -442,6 +464,9 @@ class PriorDict(dict):
prob = np.product([self[key].prob(sample[key])
for key in sample], **kwargs)
return self.check_prob(sample, prob)
def check_prob(self, sample, prob):
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
if np.all(prob == 0.):
return prob
......@@ -475,7 +500,9 @@ class PriorDict(dict):
"""
ln_prob = np.sum([self[key].ln_prob(sample[key])
for key in sample], axis=axis)
return self.check_ln_prob(sample, ln_prob)
def check_ln_prob(self, sample, ln_prob):
ratio = self.normalize_constraint_factor(tuple(sample.keys()))
if np.all(np.isinf(ln_prob)):
return ln_prob
......@@ -658,7 +685,8 @@ class ConditionalPriorDict(PriorDict):
for key, value in sample.items():
self[key].least_recently_sampled = value
res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
return np.product(res, **kwargs)
prob = np.product(res, **kwargs)
return self.check_prob(sample, prob)
def ln_prob(self, sample, axis=None):
"""
......@@ -679,7 +707,8 @@ class ConditionalPriorDict(PriorDict):
for key, value in sample.items():
self[key].least_recently_sampled = value
res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
return np.sum(res, axis=axis)
ln_prob = np.sum(res, axis=axis)
return self.check_ln_prob(sample, ln_prob)
def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment