Skip to content
Snippets Groups Projects
Commit 9cab1ac3 authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'master' of git.ligo.org:Monash/tupak

parents 846f6cc1 d0925ada
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -14,7 +14,7 @@ time_duration = 4.
sampling_frequency = 2048.
outdir = 'outdir'
label = 'basic_tutorial'
tupak.utils.setup_logger(outdir=outdir, label=label, log_level="info")
tupak.utils.setup_logger(outdir=outdir, label=label)
np.random.seed(170809)
......
......@@ -6,7 +6,7 @@ from __future__ import division, print_function
import tupak
import numpy as np
tupak.utils.setup_logger(log_level="info")
tupak.utils.setup_logger()
time_duration = 4.
sampling_frequency = 2048.
......
......@@ -7,7 +7,7 @@ from __future__ import division, print_function
import tupak
import numpy as np
tupak.utils.setup_logger(log_level="info")
tupak.utils.setup_logger()
time_duration = 4.
sampling_frequency = 2048.
......
......@@ -9,7 +9,7 @@ import numpy as np
import matplotlib.pyplot as plt
# A few simple setup steps
tupak.utils.setup_logger(log_level="info")
tupak.utils.setup_logger()
label = 'test'
outdir = 'outdir'
......
......@@ -12,7 +12,25 @@ except ImportError:
"You do not have the optional module chainconsumer installed")
def result_file_name(outdir, label):
    """Return the standard path of the result file.

    The file lives in ``outdir`` and is named ``<label>_result.h5``.
    """
    return '%s/%s_result.h5' % (outdir, label)
def read_in_result(outdir, label):
    """Read in a saved .h5 data file.

    Returns a `Result` built from the stored dictionary, or None when no
    result file exists for this outdir/label pair.
    """
    filename = result_file_name(outdir, label)
    if not os.path.isfile(filename):
        return None
    return Result(deepdish.io.load(filename))
class Result(dict):
def __init__(self, dictionary=None):
    """Populate the Result's attributes from *dictionary*.

    Parameters
    ----------
    dictionary: dict, optional
        Each key becomes an attribute of the instance. ``None`` (the
        default) leaves the instance empty.
    """
    # isinstance (rather than `type(...) is dict`) also accepts dict
    # subclasses, e.g. ordered dicts returned by serialisation libraries.
    if isinstance(dictionary, dict):
        for key, value in dictionary.items():
            setattr(self, key, value)
def __getattr__(self, name):
try:
......@@ -25,15 +43,19 @@ class Result(dict):
def __repr__(self):
"""Print a summary """
return ("nsamples: {:d}\n"
"noise_logz: {:6.3f}\n"
"logz: {:6.3f} +/- {:6.3f}\n"
"log_bayes_factor: {:6.3f} +/- {:6.3f}\n"
.format(len(self.samples), self.noise_logz, self.logz, self.logzerr, self.log_bayes_factor,
self.logzerr))
if hasattr(self, 'samples'):
return ("nsamples: {:d}\n"
"noise_logz: {:6.3f}\n"
"logz: {:6.3f} +/- {:6.3f}\n"
"log_bayes_factor: {:6.3f} +/- {:6.3f}\n"
.format(len(self.samples), self.noise_logz, self.logz,
self.logzerr, self.log_bayes_factor, self.logzerr))
else:
return ''
def save_to_file(self, outdir, label):
file_name = '{}/{}_result.h5'.format(outdir, label)
""" Writes the Result to a deepdish h5 file """
file_name = result_file_name(outdir, label)
if os.path.isdir(outdir) is False:
os.makedirs(outdir)
if os.path.isfile(file_name):
......@@ -98,6 +120,10 @@ class Result(dict):
kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
kwargs['parameters'])
# Check all parameter_labels are a valid string
for i, label in enumerate(self.parameter_labels):
if label is None:
self.parameter_labels[i] = 'Unknown'
c = ChainConsumer()
c.add_chain(self.samples, parameters=self.parameter_labels,
name=self.label)
......@@ -194,3 +220,19 @@ class Result(dict):
self.posterior['chi_p'] = max(self.posterior.a_1 * np.sin(self.posterior.tilt_1),
(4 * self.posterior.q + 3) / (3 * self.posterior.q + 4) * self.posterior.q
* self.posterior.a_2 * np.sin(self.posterior.tilt_2))
def check_attribute_match_to_other_object(self, name, other_object):
    """Check attribute `name` exists on both objects and holds equal values.

    Parameters
    ----------
    name: str
        Attribute name to compare.
    other_object: object
        The object to compare against.

    Returns
    -------
    bool: True only when the attribute exists on both objects, has the
    same type, and compares equal (element-wise for numpy arrays);
    False in every other case, including unsupported types.
    """
    # A dedicated sentinel (rather than the old `False` default) so that
    # attributes whose genuine value is False are not treated as missing.
    _missing = object()
    A = getattr(self, name, _missing)
    B = getattr(other_object, name, _missing)
    logging.debug('Checking {} value: {}=={}'.format(name, A, B))
    if A is _missing or B is _missing:
        return False
    if type(A) != type(B):
        return False
    # bool added to the whitelist: previously bool-valued attributes could
    # never match, spuriously invalidating comparisons.
    if type(A) in (str, float, int, bool, dict, list):
        return A == B
    elif isinstance(A, np.ndarray):
        return np.all(A == B)
    return False
......@@ -7,7 +7,7 @@ import sys
import numpy as np
import matplotlib.pyplot as plt
from .result import Result
from .result import Result, read_in_result
from .prior import Prior, fill_priors
from . import utils
from . import prior
......@@ -54,6 +54,7 @@ class Sampler(object):
self.kwargs = kwargs
self.result = result
self.check_cached_result()
self.log_summary_for_sampler()
......@@ -79,6 +80,15 @@ class Sampler(object):
else:
raise TypeError('result must either be a Result or None')
@property
def search_parameter_keys(self):
    """Read-only access to the private (name-mangled) list of search
    parameter keys set up by the sampler."""
    return self.__search_parameter_keys
@property
def fixed_parameter_keys(self):
    """Read-only access to the private (name-mangled) list of fixed
    parameter keys set up by the sampler."""
    return self.__fixed_parameter_keys
@property
def external_sampler(self):
    """Read-only access to the private (name-mangled) external sampler
    module/object held by this instance."""
    return self.__external_sampler
......@@ -179,9 +189,35 @@ class Sampler(object):
def run_sampler(self):
    """Run the sampler.

    Placeholder no-op on the base class; concrete sampler subclasses
    (e.g. Nestle, Ptemcee) are expected to provide the implementation.
    """
    pass
def check_cached_result(self):
    """ Check if the cached data file exists and can be used """
    args = utils.command_line_args

    # --clean forces a rerun: discard any cache without even reading it.
    if args.clean:
        logging.debug("Command line argument clean given, forcing rerun")
        self.cached_result = None
        return

    self.cached_result = read_in_result(self.outdir, self.label)

    # --use-cached trusts whatever was read, skipping validation.
    if args.use_cached:
        logging.debug("Command line argument cached given, no cache check performed")
        return

    logging.debug("Checking cached data")
    if not self.cached_result:
        return

    # The cache is only usable when these attributes match the live sampler.
    matches = True
    for key in ('search_parameter_keys', 'fixed_parameter_keys', 'kwargs'):
        if self.cached_result.check_attribute_match_to_other_object(
                key, self) is False:
            logging.debug("Cached value {} is unmatched".format(key))
            matches = False
    if not matches:
        self.cached_result = None
def log_summary_for_sampler(self):
logging.info("Using sampler {} with kwargs {}".format(
self.__class__.__name__, self.kwargs))
if self.cached_result is None:
logging.info("Using sampler {} with kwargs {}".format(
self.__class__.__name__, self.kwargs))
class Nestle(Sampler):
......@@ -358,7 +394,7 @@ class Ptemcee(Sampler):
def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler='nestle', use_ratio=True, injection_parameters=None,
**sampler_kwargs):
**kwargs):
"""
The primary interface to easy parameter estimation
......@@ -383,7 +419,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
injection_parameters: dict
A dictionary of injection parameters used in creating the data (if
using simulated data). Appended to the result object and saved.
**sampler_kwargs:
**kwargs:
All kwargs are passed directly to the samplers `run` function
Returns
......@@ -404,7 +440,11 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler_class = globals()[sampler.title()]
sampler = sampler_class(likelihood, priors, sampler, outdir=outdir,
label=label, use_ratio=use_ratio,
**sampler_kwargs)
**kwargs)
if sampler.cached_result:
logging.info("Using cached result")
return sampler.cached_result
result = sampler.run_sampler()
result.noise_logz = likelihood.noise_log_likelihood()
if use_ratio:
......@@ -414,7 +454,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
result.log_bayes_factor = result.logz - result.noise_logz
result.injection_parameters = injection_parameters
result.fixed_parameter_keys = [key for key in priors if isinstance(key, prior.DeltaFunction)]
# result.prior = prior # Removed as this breaks the saving of the data
result.priors = priors
result.kwargs = sampler.kwargs
result.samples_to_data_frame()
result.save_to_file(outdir=outdir, label=label)
return result
......
......@@ -4,6 +4,7 @@ import os
import numpy as np
from math import fmod
from gwpy.timeseries import TimeSeries
import argparse
# Constants
speed_of_light = 299792458.0 # speed of light in m/s
......@@ -281,7 +282,7 @@ def get_vertex_position_geocentric(latitude, longitude, elevation):
return np.array([x_comp, y_comp, z_comp])
def setup_logger(outdir=None, label=None, log_level='info'):
def setup_logger(outdir=None, label=None, log_level=None):
""" Setup logging output: call at the start of the script to use
Parameters
......@@ -298,6 +299,8 @@ def setup_logger(outdir=None, label=None, log_level='info'):
LEVEL = getattr(logging, log_level.upper())
except AttributeError:
raise ValueError('log_level {} not understood'.format(log_level))
elif log_level is None:
LEVEL = command_line_args.log_level
else:
LEVEL = int(log_level)
......@@ -509,4 +512,33 @@ def get_open_strain_data(
return strain
def set_up_command_line_arguments(argv=None):
    """Parse the command line arguments shared by all tupak scripts.

    Parameters
    ----------
    argv: list of str, optional
        Arguments to parse. When None (the default) argparse falls back to
        ``sys.argv``, preserving the original behaviour.

    Returns
    -------
    argparse.Namespace
        Holds the boolean flags ``verbose``, ``quite``, ``clean`` and
        ``use_cached``, plus a derived ``log_level`` (a ``logging`` level:
        WARNING when quiet, DEBUG when verbose, INFO otherwise).
    """
    parser = argparse.ArgumentParser(
        description="Command line interface for tupak scripts")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help=("Increase output verbosity [logging.DEBUG]." +
                              " Overridden by script level settings"))
    # "--quite" is a historical typo kept for backward compatibility; the
    # correctly spelt "--quiet" is accepted as an alias (dest stays "quite").
    parser.add_argument("-q", "--quite", "--quiet", action="store_true",
                        help=("Decrease output verbosity [logging.WARNING]." +
                              " Overridden by script level settings"))
    parser.add_argument("-c", "--clean", action="store_true",
                        help="Force clean data, never use cached data")
    parser.add_argument("-u", "--use-cached", action="store_true",
                        help="Force cached data and do not check its validity")
    # parse_known_args so unrelated script arguments are ignored.
    args, _ = parser.parse_known_args(argv)
    if args.quite:
        args.log_level = logging.WARNING
    elif args.verbose:
        args.log_level = logging.DEBUG
    else:
        args.log_level = logging.INFO
    return args
command_line_args = set_up_command_line_arguments()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment