From 5b1a57782cbc30025e1b32779f411cfa725c8d4f Mon Sep 17 00:00:00 2001
From: Bruce Edelman <bruce.edelman@ligo.org>
Date: Thu, 19 Mar 2020 09:38:10 -0700
Subject: [PATCH] change from lru_cached to just a plain dictionary approach
 for ease of application

---
 bilby/core/prior/dict.py | 45 ++++++++++++++++++----------------------
 1 file changed, 20 insertions(+), 25 deletions(-)

diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index 6e57ce0bc..b715e12fc 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -3,7 +3,6 @@ from io import open as ioopen
 import json
 import numpy as np
 import os
-from functools import lru_cache
 
 from future.utils import iteritems
 from matplotlib.cbook import flatten
@@ -42,6 +41,7 @@ class PriorDict(dict):
             self.from_file(filename)
         elif dictionary is not None:
             raise ValueError("PriorDict input dictionary not understood")
+        self._cached_normalizations = {}
 
         self.convert_floats_to_delta_functions()
 
@@ -379,23 +379,27 @@ class PriorDict(dict):
                        if not isinstance(self[key], Constraint)}
         return all_samples
 
-    @lru_cache()
     def normalize_constraint_factor(self, keys):
-        min_accept = 1000
-        sampling_chunk = 5000
-        samples = self.sample_subset(keys=keys, size=sampling_chunk)
-        keep = np.atleast_1d(self.evaluate_constraints(samples))
-        if len(keep) == 1:
-            return 1
-        all_samples = {key: np.array([]) for key in keys}
-        _first_key = list(all_samples.keys())[0]
-        while np.count_nonzero(keep) < min_accept:
+        if repr(keys) in self._cached_normalizations.keys():
+            return self._cached_normalizations[repr(keys)]
+        else:
+            min_accept = 1000
+            sampling_chunk = 5000
             samples = self.sample_subset(keys=keys, size=sampling_chunk)
-            for key in samples:
-                all_samples[key] = np.hstack(
-                    [all_samples[key], samples[key].flatten()])
-            keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
-        return len(keep) / np.count_nonzero(keep)
+            keep = np.atleast_1d(self.evaluate_constraints(samples))
+            if len(keep) == 1:
+                return 1
+            all_samples = {key: np.array([]) for key in keys}
+            _first_key = list(all_samples.keys())[0]
+            while np.count_nonzero(keep) < min_accept:
+                samples = self.sample_subset(keys=keys, size=sampling_chunk)
+                for key in samples:
+                    all_samples[key] = np.hstack(
+                        [all_samples[key], samples[key].flatten()])
+                keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
+            factor = len(keep) / np.count_nonzero(keep)
+            self._cached_normalizations[repr(keys)] = factor
+            return factor
 
     def prob(self, sample, **kwargs):
         """
@@ -511,15 +515,6 @@ class PriorDict(dict):
         """
         return self.__class__(dictionary=dict(self))
 
-    def __key(self):
-        return tuple((k, self[k]) for k in sorted(self))
-
-    def __hash__(self):
-        return hash(self.__key())
-
-    def __eq__(self, other):
-        return self.__key() == other.__key()
-
 
 class PriorSet(PriorDict):
-- 
GitLab