Commit 8aca08a6 authored by Gregory Ashton's avatar Gregory Ashton

Adds in some basic level of checking of the cached data

- If the sampler kwargs are the same
- If the search parameter keys are the same

MISSING: check the priors. This requires a more detailed look into how
to compare our prior distribtions.
parent 8ee7d041
Pipeline #19211 passed with stages
in 7 minutes and 51 seconds
......@@ -42,12 +42,15 @@ class Result(dict):
def __repr__(self):
"""Print a summary """
return ("nsamples: {:d}\n"
"noise_logz: {:6.3f}\n"
"logz: {:6.3f} +/- {:6.3f}\n"
"log_bayes_factor: {:6.3f} +/- {:6.3f}\n"
.format(len(self.samples), self.noise_logz, self.logz, self.logzerr, self.log_bayes_factor,
if hasattr(self, 'samples'):
return ("nsamples: {:d}\n"
"noise_logz: {:6.3f}\n"
"logz: {:6.3f} +/- {:6.3f}\n"
"log_bayes_factor: {:6.3f} +/- {:6.3f}\n"
.format(len(self.samples), self.noise_logz, self.logz,
self.logzerr, self.log_bayes_factor, self.logzerr))
return ''
def save_to_file(self, outdir, label):
file_name = result_file_name(outdir, label)
......@@ -211,3 +214,21 @@ class Result(dict):
self.posterior['chi_p'] = max(self.posterior.a_1 * np.sin(self.posterior.tilt_1),
(4 * self.posterior.q + 3) / (3 * self.posterior.q + 4) * self.posterior.q
* self.posterior.a_2 * np.sin(self.posterior.tilt_2))
def check_attribute_match_to_other_result(self, name, other_result):
""" Check attribute name exists in other_result and is the same """
A = getattr(self, name, False)
B = getattr(other_result, name, False)
logging.debug('Checking {} value: {}=={}'.format(name, A, B))
if (A is not False) and (B is not False):
typeA = type(A)
typeB = type(B)
if typeA == typeB:
if typeA in [str, float, int, dict, list]:
return A == B
elif typeA in [np.ndarray]:
return np.all(A == B)
return False
......@@ -46,11 +46,11 @@ class Sampler(object):
self.use_ratio = use_ratio
self.external_sampler = external_sampler
self.__search_parameter_keys = []
self.__fixed_parameter_keys = []
self.search_parameter_keys = []
self.fixed_parameter_keys = []
self.ndim = len(self.__search_parameter_keys)
self.ndim = len(self.search_parameter_keys)
self.kwargs = kwargs
self.result = result
......@@ -69,10 +69,10 @@ class Sampler(object):
def result(self, result):
if result is None:
self.__result = Result()
self.__result.search_parameter_keys = self.__search_parameter_keys
self.__result.search_parameter_keys = self.search_parameter_keys
self.__result.parameter_labels = [
self.priors[k].latex_label for k in
self.__result.label = self.label
self.__result.outdir = self.outdir
elif type(result) is Result:
......@@ -123,17 +123,17 @@ class Sampler(object):
for key in self.priors:
if isinstance(self.priors[key], Prior) is True \
and self.priors[key].is_fixed is False:
elif isinstance(self.priors[key], Prior) \
and self.priors[key].is_fixed is True:
self.likelihood.parameters[key] = \
self.fixed_parameter_keys.append(key)"Search parameters:")
for key in self.__search_parameter_keys:
for key in self.search_parameter_keys:' {} ~ {}'.format(key, self.priors[key]))
for key in self.__fixed_parameter_keys:
for key in self.fixed_parameter_keys:' {} = {}'.format(key, self.priors[key].peak))
def verify_parameters(self):
......@@ -144,15 +144,15 @@ class Sampler(object):
"Source model does not contain keys {}".format(unmatched_keys))
def prior_transform(self, theta):
return [self.priors[key].rescale(t) for key, t in zip(self.__search_parameter_keys, theta)]
return [self.priors[key].rescale(t) for key, t in zip(self.search_parameter_keys, theta)]
def log_prior(self, theta):
return np.sum(
[np.log(self.priors[key].prob(t)) for key, t in
zip(self.__search_parameter_keys, theta)])
zip(self.search_parameter_keys, theta)])
def log_likelihood(self, theta):
for i, k in enumerate(self.__search_parameter_keys):
for i, k in enumerate(self.search_parameter_keys):
self.likelihood.parameters[k] = theta[i]
if self.use_ratio:
return self.likelihood.log_likelihood_ratio()
......@@ -170,7 +170,7 @@ class Sampler(object):
draw = np.array([self.priors[key].sample()
for key in self.__search_parameter_keys])
for key in self.search_parameter_keys])
if np.isinf(self.log_likelihood(draw)):'Prior draw {} has inf likelihood'.format(draw))
if np.isinf(self.log_prior(draw)):
......@@ -181,7 +181,19 @@ class Sampler(object):
def check_cached_result(self):
logging.debug("Checking cached data")
self.cached_result = read_in_result(self.outdir, self.label)
if self.cached_result:
check_keys = ['search_parameter_keys', 'fixed_parameter_keys',
use_cache = True
for key in check_keys:
if self.cached_result.check_attribute_match_to_other_result(
key, self) is False:
logging.debug("Cached value {} is unmatched".format(key))
use_cache = False
if use_cache is False:
self.cached_result = None
def log_summary_for_sampler(self):"Using sampler {} with kwargs {}".format(
......@@ -362,7 +374,7 @@ class Ptemcee(Sampler):
def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler='nestle', use_ratio=True, injection_parameters=None,
The primary interface to easy parameter estimation
......@@ -387,7 +399,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
injection_parameters: dict
A dictionary of injection parameters used in creating the data (if
using simulated data). Appended to the result object and saved.
All kwargs are passed directly to the samplers `run` functino
......@@ -408,7 +420,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler_class = globals()[sampler.title()]
sampler = sampler_class(likelihood, priors, sampler, outdir=outdir,
label=label, use_ratio=use_ratio,
if sampler.cached_result:"Using cached result")
return sampler.cached_result
......@@ -422,7 +434,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
result.log_bayes_factor = result.logz - result.noise_logz
result.injection_parameters = injection_parameters
result.fixed_parameter_keys = [key for key in priors if isinstance(key, prior.DeltaFunction)]
# result.prior = prior # Removed as this breaks the saving of the data
result.priors = priors
result.kwargs = sampler.kwargs
result.save_to_file(outdir=outdir, label=label)
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment