import logging
import os
import numpy as np
import deepdish
import pandas as pd
import tupak

try:
    from chainconsumer import ChainConsumer
except ImportError:
    # Fallback so this module still imports without chainconsumer; any
    # plotting method then raises ImportError when it calls ChainConsumer().
    def ChainConsumer():
        raise ImportError(
            "You do not have the optional module chainconsumer installed")


def result_file_name(outdir, label):
    """ Returns the standard filename used for a result file

    Parameters
    ----------
    outdir: str
        Directory containing the result file
    label: str
        Label used as the filename prefix

    Returns
    -------
    str: '{outdir}/{label}_result.h5'
    """
    return '{}/{}_result.h5'.format(outdir, label)


def read_in_result(outdir=None, label=None, filename=None):
    """ Read in a saved .h5 data file

    Parameters
    ----------
    outdir, label: str
        If given, use the default naming convention for saved results file
    filename: str
        If given, try to load from this filename (takes precedence over
        outdir/label)

    Returns
    -------
    result: tupak.result.Result instance

    Raises
    ------
    ValueError
        If neither `filename` nor both `outdir` and `label` are given, or if
        the resolved file does not exist.
    """
    if filename is None:
        if outdir is None or label is None:
            raise ValueError("No information given to load file")
        filename = result_file_name(outdir, label)
    if not os.path.isfile(filename):
        # ValueError (not IOError) is kept so existing callers still catch it
        raise ValueError("No result file found at {}".format(filename))
    return Result(deepdish.io.load(filename))


class Result(dict):
    """ A dict of sampler output whose keys can also be read and written as
    attributes (e.g. ``result.samples`` is ``result['samples']``). """

    def __init__(self, dictionary=None):
        # isinstance (rather than an exact type check) also accepts dict
        # subclasses, e.g. another Result
        if isinstance(dictionary, dict):
            for key, value in dictionary.items():
                setattr(self, key, value)

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails: fall back to the
        # dict entry, translating KeyError so hasattr/getattr defaults work.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __repr__(self):
        """Print a summary """
        if hasattr(self, 'samples'):
            return ("nsamples: {:d}\n"
                    "noise_logz: {:6.3f}\n"
                    "logz: {:6.3f} +/- {:6.3f}\n"
                    "log_bayes_factor: {:6.3f} +/- {:6.3f}\n"
                    .format(len(self.samples), self.noise_logz, self.logz,
                            self.logzerr, self.log_bayes_factor,
                            self.logzerr))
        else:
            return ''

    def get_result_dictionary(self):
        """ Return a plain-dict copy of this Result """
        return dict(self)

    def save_to_file(self, outdir, label):
        """ Writes the Result to a deepdish h5 file

        The output directory is created if needed; an existing file at the
        target path is first renamed with an '.old' suffix. Errors raised by
        deepdish while saving are logged, not propagated.

        Parameters
        ----------
        outdir, label: str
            Passed to `result_file_name` to build the output path
        """
        file_name = result_file_name(outdir, label)
        if not os.path.isdir(outdir):
            os.makedirs(outdir)
        if os.path.isfile(file_name):
            logging.info(
                'Renaming existing file {} to {}.old'.format(file_name,
                                                             file_name))
            os.rename(file_name, file_name + '.old')

        logging.info("Saving result to {}".format(file_name))
        try:
            deepdish.io.save(file_name, self.get_result_dictionary())
        except Exception as e:
            # Deliberately best-effort: report the failure and carry on
            logging.error(
                "\n\n Saving the data has failed with the following message:\n {} \n\n"
                .format(e))

    def get_latex_labels_from_parameter_keys(self, keys):
        """ Map parameter keys onto their LaTeX labels

        Parameters
        ----------
        keys: iterable of str
            Each entry is either a search-parameter key (translated via
            self.parameter_labels) or already a LaTeX label (kept as is)

        Returns
        -------
        list: the LaTeX labels, in the order of `keys`

        Raises
        ------
        ValueError: if an entry is neither a parameter key nor a label
        """
        return_list = []
        for k in keys:
            if k in self.search_parameter_keys:
                idx = self.search_parameter_keys.index(k)
                return_list.append(self.parameter_labels[idx])
            elif k in self.parameter_labels:
                return_list.append(k)
            else:
                raise ValueError('key {} not a parameter label or latex label'
                                 .format(k))
        return return_list

    def _ensure_parameter_labels_are_strings(self):
        """ Replace (in place) any None entry of self.parameter_labels with
        'Unknown', so every label handed to chain-consumer is a string """
        for i, label in enumerate(self.parameter_labels):
            if label is None:
                self.parameter_labels[i] = 'Unknown'

    def _set_default_plot_kwargs(self, kwargs, save, suffix, plot_name):
        """ Fill in the 'filename' and 'truth' plotting kwargs unless the
        caller already supplied them (kwargs is modified in place) """
        if save:
            kwargs.setdefault(
                'filename',
                '{}/{}_{}.png'.format(self.outdir, self.label, suffix))
            logging.info('Saving {} plot to {}'.format(plot_name,
                                                       kwargs['filename']))

        # getattr with a default: the key may be absent from this Result
        injection_parameters = getattr(self, 'injection_parameters', None)
        if 'truth' not in kwargs and injection_parameters is not None:
            kwargs['truth'] = [injection_parameters[key]
                               for key in self.search_parameter_keys]

    def plot_corner(self, save=True, **kwargs):
        """ Plot a corner-plot using chain-consumer

        Parameters
        ----------
        save: bool
            If true, save the image using the given label and outdir
        **kwargs:
            Passed to chain-consumer's `plotter.plot`. A `truth` dict and a
            `parameters` list may use parameter keys; they are translated
            to LaTeX labels.

        Returns
        -------
        fig:
            A matplotlib figure instance
        """
        # Set some defaults (unless already set)
        kwargs.setdefault('figsize', 'GROW')
        self._set_default_plot_kwargs(kwargs, save, 'corner', 'corner')

        # Re-key a truth dict by LaTeX label. Build a new dict: popping and
        # inserting while iterating a live keys() view is unreliable (raises
        # RuntimeError on Python 3) and would mutate the caller's dict.
        truth = kwargs.get('truth')
        if isinstance(truth, dict):
            old_keys = list(truth.keys())
            new_keys = self.get_latex_labels_from_parameter_keys(old_keys)
            kwargs['truth'] = {new: truth[old]
                               for old, new in zip(old_keys, new_keys)}

        if 'parameters' in kwargs:
            kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
                kwargs['parameters'])

        self._ensure_parameter_labels_are_strings()

        c = ChainConsumer()
        c.add_chain(self.samples, parameters=self.parameter_labels,
                    name=self.label)
        fig = c.plotter.plot(**kwargs)
        return fig

    def plot_walks(self, save=True, **kwargs):
        """ Plot the chain walks using chain-consumer

        Parameters
        ----------
        save: bool
            If true, save the image using the given label and outdir
        **kwargs:
            Passed to chain-consumer's `plotter.plot_walks`

        Returns
        -------
        fig:
            A matplotlib figure instance
        """
        # Set some defaults (unless already set)
        self._set_default_plot_kwargs(kwargs, save, 'walks', 'walker')
        self._ensure_parameter_labels_are_strings()

        c = ChainConsumer()
        c.add_chain(self.samples, parameters=self.parameter_labels)
        fig = c.plotter.plot_walks(**kwargs)
        return fig

    def plot_distributions(self, save=True, **kwargs):
        """ Plot the marginal distributions using chain-consumer

        Parameters
        ----------
        save: bool
            If true, save the image using the given label and outdir
        **kwargs:
            Passed to chain-consumer's `plotter.plot_distributions`

        Returns
        -------
        fig:
            A matplotlib figure instance
        """
        # Set some defaults (unless already set)
        self._set_default_plot_kwargs(kwargs, save, 'distributions',
                                      'distributions')
        self._ensure_parameter_labels_are_strings()

        c = ChainConsumer()
        c.add_chain(self.samples, parameters=self.parameter_labels)
        fig = c.plotter.plot_distributions(**kwargs)
        return fig

    def write_prior_to_file(self, outdir):
        """ Write the prior distribution to file, one 'key = prior' per line.

        Parameters
        ----------
        outdir: str
            NOTE(review): despite the name, this is used as a path prefix;
            the output file is `outdir + '.prior'`.
        """
        outfile = outdir + '.prior'
        with open(outfile, "w") as prior_file:
            for key, prior in self.prior.items():
                # format() stringifies prior objects (file.write needs str)
                prior_file.write('{} = {}\n'.format(key, prior))

    def samples_to_data_frame(self, likelihood=None, priors=None,
                              conversion_function=None):
        """
        Convert array of samples to data frame, stored as self.posterior.

        Parameters
        ----------
        likelihood: tupak.likelihood.GravitationalWaveTransient
            GravitationalWaveTransient used for sampling.
        priors: dict
            Dictionary of prior object, used to fill in delta function priors.
        conversion_function: function
            Function which adds in extra parameters to the data frame,
            should take the data_frame, likelihood and prior as arguments.
        """
        data_frame = pd.DataFrame(self.samples,
                                  columns=self.search_parameter_keys)
        if conversion_function is not None:
            conversion_function(data_frame, likelihood, priors)
        self.posterior = data_frame

    def construct_cbc_derived_parameters(self):
        """ Construct widely used derived parameters of CBCs

        Reads the mass_1, mass_2, a_1, a_2, tilt_1 and tilt_2 columns of
        self.posterior (tilts are passed to np.cos/np.sin, so presumably in
        radians) and adds mass_chirp, q (= mass_2 / mass_1), eta, chi_eff
        and chi_p columns.
        """
        posterior = self.posterior
        mass_1 = posterior.mass_1
        mass_2 = posterior.mass_2
        a_1 = posterior.a_1
        a_2 = posterior.a_2
        tilt_1 = posterior.tilt_1
        tilt_2 = posterior.tilt_2

        posterior['mass_chirp'] = ((mass_1 * mass_2) ** 0.6
                                   / (mass_1 + mass_2) ** 0.2)
        posterior['q'] = mass_2 / mass_1
        q = posterior['q']
        posterior['eta'] = (mass_1 * mass_2) / (mass_1 + mass_2) ** 2
        posterior['chi_eff'] = ((a_1 * np.cos(tilt_1)
                                 + q * a_2 * np.cos(tilt_2)) / (1 + q))
        # Element-wise maximum: the builtin max() on two Series raises
        # "truth value of a Series is ambiguous".
        posterior['chi_p'] = np.maximum(
            a_1 * np.sin(tilt_1),
            (4 * q + 3) / (3 * q + 4) * q * a_2 * np.sin(tilt_2))

    def check_attribute_match_to_other_object(self, name, other_object):
        """ Check attribute name exists in other_object and is the same

        Returns
        -------
        bool: True only if both objects have the attribute, the values share
        a type among str/float/int/dict/list/np.ndarray, and they are equal.
        Note that an attribute whose value is False is treated as missing.
        """
        A = getattr(self, name, False)
        B = getattr(other_object, name, False)
        logging.debug('Checking {} value: {}=={}'.format(name, A, B))
        if A is False or B is False:
            return False
        typeA = type(A)
        if typeA != type(B):
            return False
        if typeA in [str, float, int, dict, list]:
            return A == B
        if typeA is np.ndarray:
            # array_equal also copes with arrays of different shape
            return np.array_equal(A, B)
        return False