Commit 8ee7d041 authored by Gregory Ashton's avatar Gregory Ashton

Basic version with no checking - if cached data exists use it

parent a3354742
......@@ -12,7 +12,24 @@ except ImportError:
"You do not have the optional module chainconsumer installed")
def result_file_name(outdir, label):
""" Returns the standard filename used for a result file """
return '{}/{}_result.h5'.format(outdir, label)
def read_in_result(outdir, label):
filename = result_file_name(outdir, label)
if os.path.isfile(filename):
return Result(deepdish.io.load(filename))
else:
return None
class Result(dict):
def __init__(self, dictionary=None):
if type(dictionary) is dict:
for key in dictionary:
setattr(self, key, dictionary[key])
def __getattr__(self, name):
try:
......@@ -33,7 +50,7 @@ class Result(dict):
self.logzerr))
def save_to_file(self, outdir, label):
file_name = '{}/{}_result.h5'.format(outdir, label)
file_name = result_file_name(outdir, label)
if os.path.isdir(outdir) is False:
os.makedirs(outdir)
if os.path.isfile(file_name):
......
......@@ -7,7 +7,7 @@ import sys
import numpy as np
import matplotlib.pyplot as plt
from .result import Result
from .result import Result, read_in_result
from .prior import Prior, fill_priors
from . import utils
from . import prior
......@@ -54,6 +54,7 @@ class Sampler(object):
self.kwargs = kwargs
self.result = result
self.check_cached_result()
self.log_summary_for_sampler()
......@@ -179,6 +180,9 @@ class Sampler(object):
def run_sampler(self):
pass
def check_cached_result(self):
self.cached_result = read_in_result(self.outdir, self.label)
def log_summary_for_sampler(self):
logging.info("Using sampler {} with kwargs {}".format(
self.__class__.__name__, self.kwargs))
......@@ -405,6 +409,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler = sampler_class(likelihood, priors, sampler, outdir=outdir,
label=label, use_ratio=use_ratio,
**sampler_kwargs)
if sampler.cached_result:
logging.info("Using cached result")
return sampler.cached_result
result = sampler.run_sampler()
result.noise_logz = likelihood.noise_log_likelihood()
if use_ratio:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment