Commit 8ee7d041 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Basic version with no checking - if cached data exists use it

parent a3354742
...@@ -12,7 +12,24 @@ except ImportError: ...@@ -12,7 +12,24 @@ except ImportError:
"You do not have the optional module chainconsumer installed") "You do not have the optional module chainconsumer installed")
def result_file_name(outdir, label):
""" Returns the standard filename used for a result file """
return '{}/{}_result.h5'.format(outdir, label)
def read_in_result(outdir, label):
filename = result_file_name(outdir, label)
if os.path.isfile(filename):
return Result(deepdish.io.load(filename))
else:
return None
class Result(dict): class Result(dict):
def __init__(self, dictionary=None):
if type(dictionary) is dict:
for key in dictionary:
setattr(self, key, dictionary[key])
def __getattr__(self, name): def __getattr__(self, name):
try: try:
...@@ -33,7 +50,7 @@ class Result(dict): ...@@ -33,7 +50,7 @@ class Result(dict):
self.logzerr)) self.logzerr))
def save_to_file(self, outdir, label): def save_to_file(self, outdir, label):
file_name = '{}/{}_result.h5'.format(outdir, label) file_name = result_file_name(outdir, label)
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
os.makedirs(outdir) os.makedirs(outdir)
if os.path.isfile(file_name): if os.path.isfile(file_name):
......
...@@ -7,7 +7,7 @@ import sys ...@@ -7,7 +7,7 @@ import sys
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from .result import Result from .result import Result, read_in_result
from .prior import Prior, fill_priors from .prior import Prior, fill_priors
from . import utils from . import utils
from . import prior from . import prior
...@@ -54,6 +54,7 @@ class Sampler(object): ...@@ -54,6 +54,7 @@ class Sampler(object):
self.kwargs = kwargs self.kwargs = kwargs
self.result = result self.result = result
self.check_cached_result()
self.log_summary_for_sampler() self.log_summary_for_sampler()
...@@ -179,6 +180,9 @@ class Sampler(object): ...@@ -179,6 +180,9 @@ class Sampler(object):
def run_sampler(self): def run_sampler(self):
pass pass
def check_cached_result(self):
self.cached_result = read_in_result(self.outdir, self.label)
def log_summary_for_sampler(self): def log_summary_for_sampler(self):
logging.info("Using sampler {} with kwargs {}".format( logging.info("Using sampler {} with kwargs {}".format(
self.__class__.__name__, self.kwargs)) self.__class__.__name__, self.kwargs))
...@@ -405,6 +409,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', ...@@ -405,6 +409,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler = sampler_class(likelihood, priors, sampler, outdir=outdir, sampler = sampler_class(likelihood, priors, sampler, outdir=outdir,
label=label, use_ratio=use_ratio, label=label, use_ratio=use_ratio,
**sampler_kwargs) **sampler_kwargs)
if sampler.cached_result:
logging.info("Using cached result")
return sampler.cached_result
result = sampler.run_sampler() result = sampler.run_sampler()
result.noise_logz = likelihood.noise_log_likelihood() result.noise_logz = likelihood.noise_log_likelihood()
if use_ratio: if use_ratio:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment