Skip to content
Snippets Groups Projects
Commit 089643b2 authored by Bruce Edelman's avatar Bruce Edelman
Browse files

try out the caching non-analytical solution

parent c04966e3
No related branches found
No related tags found
1 merge request: !704 "Resolve #430 (Add normalisation flag to constrained prior)"
......@@ -438,12 +438,10 @@ class Prior(object):
class Constraint(Prior):
    """A pseudo-prior that flags whether a value lies inside (minimum, maximum).

    A ``Constraint`` is never sampled directly; it only reports, via
    :meth:`prob`, whether proposed values satisfy the bounds so that the
    containing prior dictionary can reject samples outside them.
    """

    def __init__(self, minimum, maximum, name=None, latex_label=None, unit=None):
        """
        Parameters
        ----------
        minimum: float
            Lower bound of the allowed region (exclusive).
        maximum: float
            Upper bound of the allowed region (exclusive).
        name: str, optional
            Name of the constrained quantity.
        latex_label: str, optional
            LaTeX label used for plotting.
        unit: str, optional
            Physical unit of the quantity.
        """
        super(Constraint, self).__init__(minimum=minimum, maximum=maximum, name=name,
                                         latex_label=latex_label, unit=unit)
        # Constraints carry no sampling distribution of their own, so they
        # are treated as fixed parameters by the rest of the machinery.
        self._is_fixed = True

    def prob(self, val):
        """Return whether *val* lies strictly between the bounds.

        Works element-wise on array input (``&`` rather than ``and``), so the
        result is a boolean scalar or boolean array matching ``val``.
        """
        return (val > self.minimum) & (val < self.maximum)
......
......@@ -3,6 +3,7 @@ from io import open as ioopen
import json
import numpy as np
import os
from functools import lru_cache
from future.utils import iteritems
from matplotlib.cbook import flatten
......@@ -378,6 +379,12 @@ class PriorDict(dict):
if not isinstance(self[key], Constraint)}
return all_samples
def normalize_constraint_factor(self, keys):
    """Estimate the factor by which constraints shrink the prior volume.

    Draws 1000 samples over *keys*, evaluates the constraints, and returns
    ``n_total / n_satisfying`` (>= 1), a Monte-Carlo estimate of the inverse
    acceptance fraction used to renormalise constrained probabilities.

    Results are memoised in a per-instance dict keyed on the sorted key
    names rather than with ``functools.lru_cache`` because (a) callers pass
    ``dict_keys`` views, which are unhashable and would make ``lru_cache``
    raise ``TypeError``, and (b) ``lru_cache`` on an instance method keys on
    ``self`` and keeps the instance alive for the cache's lifetime.

    Parameters
    ----------
    keys: iterable of str
        Names of the parameters to sample over.

    Returns
    -------
    float: ratio of total draws to constraint-satisfying draws.
    """
    try:
        cache = self._constraint_factor_cache
    except AttributeError:
        cache = self._constraint_factor_cache = {}
    # Sorted tuple makes the key hashable and order-insensitive.
    cache_key = tuple(sorted(keys))
    if cache_key not in cache:
        samples = self.sample_subset(keys=keys, size=1000)
        keep = np.array(self.evaluate_constraints(samples))
        # factor >= 1: total draws over draws satisfying the constraints.
        cache[cache_key] = len(keep) / np.count_nonzero(keep)
    return cache[cache_key]
def prob(self, sample, **kwargs):
"""
......@@ -396,6 +403,10 @@ class PriorDict(dict):
prob = np.product([self[key].prob(sample[key])
for key in sample], **kwargs)
ratio = 1
if np.any(isinstance([self[key] for key in sample], Constraint)):
ratio = self.normalize_constraint_factor(sample.keys())
if np.all(prob == 0.):
return prob
else:
......@@ -407,17 +418,7 @@ class PriorDict(dict):
else:
constrained_prob = np.zeros_like(prob)
keep = np.array(self.evaluate_constraints(sample), dtype=bool)
constrained_prob[keep] = prob[keep]
out_sample = self.conversion_function(sample)
norms_flagged = False
for key in self:
if isinstance(self[key], Constraint) and key in out_sample:
if self[key].normalisation is not None:
constrained_prob *= self[key].normalisation
if not norms_flagged:
norms_flagged = True
if not norms_flagged:
constrained_prob *= len(keep) / np.count_nonzero(keep)
constrained_prob[keep] = prob[keep]*ratio
return constrained_prob
def ln_prob(self, sample, axis=None):
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment