diff --git a/examples/injection_examples/basic_tutorial.py b/examples/injection_examples/basic_tutorial.py index dd990be32f856d842a49386009d90852fd5bec04..83e1556f4c8b0315d83a4459139040c99516ef03 100644 --- a/examples/injection_examples/basic_tutorial.py +++ b/examples/injection_examples/basic_tutorial.py @@ -14,7 +14,7 @@ time_duration = 4. sampling_frequency = 2048. outdir = 'outdir' label = 'basic_tutorial' -tupak.utils.setup_logger(outdir=outdir, label=label, log_level="info") +tupak.utils.setup_logger(outdir=outdir, label=label) np.random.seed(170809) diff --git a/examples/injection_examples/how_to_specify_the_prior.py b/examples/injection_examples/how_to_specify_the_prior.py index 500cf91a9635a9c0034beb7e2e96db97570622bf..8829a71f7846b7cf1130d0d606b13f0b37ce2f8c 100644 --- a/examples/injection_examples/how_to_specify_the_prior.py +++ b/examples/injection_examples/how_to_specify_the_prior.py @@ -6,7 +6,7 @@ from __future__ import division, print_function import tupak import numpy as np -tupak.utils.setup_logger(log_level="info") +tupak.utils.setup_logger() time_duration = 4. sampling_frequency = 2048. diff --git a/examples/injection_examples/marginalized_likelihood.py b/examples/injection_examples/marginalized_likelihood.py index 7fcd0acee31f1c19b266007f6cc98d3e4e77d024..67ad222560cad941c2d23a3a6de6987a95a51ca5 100644 --- a/examples/injection_examples/marginalized_likelihood.py +++ b/examples/injection_examples/marginalized_likelihood.py @@ -7,7 +7,7 @@ from __future__ import division, print_function import tupak import numpy as np -tupak.utils.setup_logger(log_level="info") +tupak.utils.setup_logger() time_duration = 4. sampling_frequency = 2048. 
diff --git a/examples/other_examples/alternative_likelihoods.py b/examples/other_examples/alternative_likelihoods.py index dd735af89c49280262d861dd6e8d26550f9f3b7e..0d31bd09ddf4ee9c582917568bf52182f16d6284 100644 --- a/examples/other_examples/alternative_likelihoods.py +++ b/examples/other_examples/alternative_likelihoods.py @@ -9,7 +9,7 @@ import numpy as np import matplotlib.pyplot as plt # A few simple setup steps -tupak.utils.setup_logger(log_level="info") +tupak.utils.setup_logger() label = 'test' outdir = 'outdir' diff --git a/tupak/result.py b/tupak/result.py index 60f6165079755848eedb8d13b209931b2cfe0a13..d14d37561477bb9d6ce8784f43512d85ab94316f 100644 --- a/tupak/result.py +++ b/tupak/result.py @@ -12,7 +12,25 @@ except ImportError: "You do not have the optional module chainconsumer installed") +def result_file_name(outdir, label): + """ Returns the standard filename used for a result file """ + return '{}/{}_result.h5'.format(outdir, label) + + +def read_in_result(outdir, label): + """ Read in a saved .h5 data file """ + filename = result_file_name(outdir, label) + if os.path.isfile(filename): + return Result(deepdish.io.load(filename)) + else: + return None + + class Result(dict): + def __init__(self, dictionary=None): + if type(dictionary) is dict: + for key in dictionary: + setattr(self, key, dictionary[key]) def __getattr__(self, name): try: @@ -25,15 +43,19 @@ class Result(dict): def __repr__(self): """Print a summary """ - return ("nsamples: {:d}\n" - "noise_logz: {:6.3f}\n" - "logz: {:6.3f} +/- {:6.3f}\n" - "log_bayes_factor: {:6.3f} +/- {:6.3f}\n" - .format(len(self.samples), self.noise_logz, self.logz, self.logzerr, self.log_bayes_factor, - self.logzerr)) + if hasattr(self, 'samples'): + return ("nsamples: {:d}\n" + "noise_logz: {:6.3f}\n" + "logz: {:6.3f} +/- {:6.3f}\n" + "log_bayes_factor: {:6.3f} +/- {:6.3f}\n" + .format(len(self.samples), self.noise_logz, self.logz, + self.logzerr, self.log_bayes_factor, self.logzerr)) + else: + return '' def 
save_to_file(self, outdir, label): - file_name = '{}/{}_result.h5'.format(outdir, label) + """ Writes the Result to a deepdish h5 file """ + file_name = result_file_name(outdir, label) if os.path.isdir(outdir) is False: os.makedirs(outdir) if os.path.isfile(file_name): @@ -98,6 +120,10 @@ class Result(dict): kwargs['parameters'] = self.get_latex_labels_from_parameter_keys( kwargs['parameters']) + # Check all parameter_labels are a valid string + for i, label in enumerate(self.parameter_labels): + if label is None: + self.parameter_labels[i] = 'Unknown' c = ChainConsumer() c.add_chain(self.samples, parameters=self.parameter_labels, name=self.label) @@ -194,3 +220,19 @@ class Result(dict): self.posterior['chi_p'] = max(self.posterior.a_1 * np.sin(self.posterior.tilt_1), (4 * self.posterior.q + 3) / (3 * self.posterior.q + 4) * self.posterior.q * self.posterior.a_2 * np.sin(self.posterior.tilt_2)) + + def check_attribute_match_to_other_object(self, name, other_object): + """ Check attribute name exists in other_object and is the same """ + A = getattr(self, name, False) + B = getattr(other_object, name, False) + logging.debug('Checking {} value: {}=={}'.format(name, A, B)) + if (A is not False) and (B is not False): + typeA = type(A) + typeB = type(B) + if typeA == typeB: + if typeA in [str, float, int, dict, list]: + return A == B + elif typeA in [np.ndarray]: + return np.all(A == B) + return False + diff --git a/tupak/sampler.py b/tupak/sampler.py index 281da7b52a1493f73a6af30a6aa7bc180434f5c6..0f965bb4eee0d8214012bf02c39f206fd2817781 100644 --- a/tupak/sampler.py +++ b/tupak/sampler.py @@ -7,7 +7,7 @@ import sys import numpy as np import matplotlib.pyplot as plt -from .result import Result +from .result import Result, read_in_result from .prior import Prior, fill_priors from . import utils from . 
import prior @@ -54,6 +54,7 @@ class Sampler(object): self.kwargs = kwargs self.result = result + self.check_cached_result() self.log_summary_for_sampler() @@ -79,6 +80,15 @@ class Sampler(object): else: raise TypeError('result must either be a Result or None') + @property + def search_parameter_keys(self): + return self.__search_parameter_keys + + @property + def fixed_parameter_keys(self): + return self.__fixed_parameter_keys + + @property def external_sampler(self): return self.__external_sampler @@ -179,9 +189,35 @@ class Sampler(object): def run_sampler(self): pass + def check_cached_result(self): + """ Check if the cached data file exists and can be used """ + + if utils.command_line_args.clean: + logging.debug("Command line argument clean given, forcing rerun") + self.cached_result = None + return + self.cached_result = read_in_result(self.outdir, self.label) + if utils.command_line_args.use_cached: + logging.debug("Command line argument cached given, no cache check performed") + return + + logging.debug("Checking cached data") + if self.cached_result: + check_keys = ['search_parameter_keys', 'fixed_parameter_keys', + 'kwargs'] + use_cache = True + for key in check_keys: + if self.cached_result.check_attribute_match_to_other_object( + key, self) is False: + logging.debug("Cached value {} is unmatched".format(key)) + use_cache = False + if use_cache is False: + self.cached_result = None + def log_summary_for_sampler(self): - logging.info("Using sampler {} with kwargs {}".format( - self.__class__.__name__, self.kwargs)) + if self.cached_result is None: + logging.info("Using sampler {} with kwargs {}".format( + self.__class__.__name__, self.kwargs)) class Nestle(Sampler): @@ -358,7 +394,7 @@ class Ptemcee(Sampler): def run_sampler(likelihood, priors=None, label='label', outdir='outdir', sampler='nestle', use_ratio=True, injection_parameters=None, - **sampler_kwargs): + **kwargs): """ The primary interface to easy parameter estimation @@ -383,7 +419,7 @@ def 
run_sampler(likelihood, priors=None, label='label', outdir='outdir', injection_parameters: dict A dictionary of injection parameters used in creating the data (if using simulated data). Appended to the result object and saved. - **sampler_kwargs: + **kwargs: All kwargs are passed directly to the samplers `run` functino Returns @@ -404,7 +440,11 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', sampler_class = globals()[sampler.title()] sampler = sampler_class(likelihood, priors, sampler, outdir=outdir, label=label, use_ratio=use_ratio, - **sampler_kwargs) + **kwargs) + if sampler.cached_result: + logging.info("Using cached result") + return sampler.cached_result + result = sampler.run_sampler() result.noise_logz = likelihood.noise_log_likelihood() if use_ratio: @@ -414,7 +454,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', result.log_bayes_factor = result.logz - result.noise_logz result.injection_parameters = injection_parameters result.fixed_parameter_keys = [key for key in priors if isinstance(key, prior.DeltaFunction)] - # result.prior = prior # Removed as this breaks the saving of the data + result.priors = priors + result.kwargs = sampler.kwargs result.samples_to_data_frame() result.save_to_file(outdir=outdir, label=label) return result diff --git a/tupak/utils.py b/tupak/utils.py index a6405f5883a583d40967d01b102eabaad37cc3ae..71eb761fa4ad05a896e724cca3abc97df5ff8fa8 100644 --- a/tupak/utils.py +++ b/tupak/utils.py @@ -4,6 +4,7 @@ import os import numpy as np from math import fmod from gwpy.timeseries import TimeSeries +import argparse # Constants speed_of_light = 299792458.0 # speed of light in m/s @@ -281,7 +282,7 @@ def get_vertex_position_geocentric(latitude, longitude, elevation): return np.array([x_comp, y_comp, z_comp]) -def setup_logger(outdir=None, label=None, log_level='info'): +def setup_logger(outdir=None, label=None, log_level=None): """ Setup logging output: call at the start of the script 
to use Parameters @@ -298,6 +299,8 @@ def setup_logger(outdir=None, label=None, log_level='info'): LEVEL = getattr(logging, log_level.upper()) except AttributeError: raise ValueError('log_level {} not understood'.format(log_level)) + elif log_level is None: + LEVEL = command_line_args.log_level else: LEVEL = int(log_level) @@ -509,4 +512,33 @@ def get_open_strain_data( return strain +def set_up_command_line_arguments(): + parser = argparse.ArgumentParser( + description="Command line interface for tupak scripts") + parser.add_argument("-v", "--verbose", action="store_true", + help=("Increase output verbosity [logging.DEBUG]." + + " Overridden by script level settings")) + parser.add_argument("-q", "--quiet", action="store_true", + help=("Decrease output verbosity [logging.WARNING]." + + " Overridden by script level settings")) + parser.add_argument("-c", "--clean", action="store_true", + help="Force clean data, never use cached data") + parser.add_argument("-u", "--use-cached", action="store_true", + help="Force cached data and do not check its validity") + args, _ = parser.parse_known_args() + + if args.quiet: + args.log_level = logging.WARNING + elif args.verbose: + args.log_level = logging.DEBUG + else: + args.log_level = logging.INFO + + return args + + +command_line_args = set_up_command_line_arguments() + + +