result.py 9.03 KB
Newer Older
1 2
import logging
import os
3
import numpy as np
4
import deepdish
5
import pandas as pd
6

7 8 9 10 11 12 13
try:
    from chainconsumer import ChainConsumer
except ImportError:
    def ChainConsumer(*args, **kwargs):
        """Stand-in used when the optional chainconsumer package is missing.

        Accepts any arguments so that every call raises the informative
        ImportError below, rather than a TypeError about the signature.
        """
        raise ImportError(
            "You do not have the optional module chainconsumer installed")

14

15 16 17 18 19
def result_file_name(outdir, label):
    """ Build the conventional path of a saved result file

    Parameters
    ----------
    outdir: str
        Directory the result lives in (joined with a literal '/')
    label: str
        Label of the run; the file is named '<label>_result.h5'

    Returns
    -------
    str: the path '<outdir>/<label>_result.h5'
    """
    return '{directory}/{run}_result.h5'.format(directory=outdir, run=label)


Gregory Ashton's avatar
Gregory Ashton committed
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
def read_in_result(outdir=None, label=None, filename=None):
    """ Read in a saved .h5 data file

    Parameters
    ----------
    outdir, label: str
        If given, use the default naming convention for saved results file
    filename: str
        If given, try to load from this filename (takes precedence over
        outdir and label)

    Returns
    -------
    result: tupak.result.Result instance

    Raises
    ------
    ValueError:
        If neither `filename` nor both `outdir` and `label` are given, or if
        the resolved file does not exist.
    """
    if filename is None:
        # Without both parts the default name would be e.g. 'None/None_result.h5'.
        if outdir is None or label is None:
            raise ValueError("No information given to load file")
        filename = result_file_name(outdir, label)
    if not os.path.isfile(filename):
        raise ValueError("No result file found at {}".format(filename))
    return Result(deepdish.io.load(filename))
40 41


42
class Result(dict):
    """Container for sampler output, usable as a dict or via attributes.

    Every key is also readable as an attribute (``result.logz`` is
    ``result['logz']``), and attribute assignment/deletion operate directly
    on the dict entries.
    """

    def __init__(self, dictionary=None):
        """Create a Result, optionally populated from ``dictionary``.

        Parameters
        ----------
        dictionary: dict, optional
            Key/value pairs to copy into the result (for example the output
            of ``deepdish.io.load``). Dict subclasses, including another
            Result, are accepted; any non-dict value (such as the default
            None) gives an empty Result.
        """
        if isinstance(dictionary, dict):
            for key in dictionary:
                setattr(self, key, dictionary[key])

    def __getattr__(self, name):
        """Fall back to dict lookup for attributes not found normally.

        Raises AttributeError (not KeyError) for missing keys so that
        ``hasattr`` and ``getattr(obj, name, default)`` behave as usual.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    # Attribute assignment and deletion operate on the dict entries.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __repr__(self):
        """Print a summary

        Returns an empty string if there are no ``samples``; otherwise
        ``noise_logz``, ``logz``, ``logzerr`` and ``log_bayes_factor`` must
        also be set.
        """
        if hasattr(self, 'samples'):
            # logzerr is reused as the uncertainty on log_bayes_factor.
            return ("nsamples: {:d}\n"
                    "noise_logz: {:6.3f}\n"
                    "logz: {:6.3f} +/- {:6.3f}\n"
                    "log_bayes_factor: {:6.3f} +/- {:6.3f}\n"
                    .format(len(self.samples), self.noise_logz, self.logz,
                            self.logzerr, self.log_bayes_factor, self.logzerr))
        else:
            return ''

    def save_to_file(self, outdir, label):
        """ Writes the Result to a deepdish h5 file

        The path is ``result_file_name(outdir, label)``. ``outdir`` is created
        if it does not exist, and an existing file at the path is first
        renamed with an ``.old`` suffix. Errors raised while saving are
        logged, not propagated.

        Parameters
        ----------
        outdir, label: str
            Output directory and label used to build the file name
        """
        file_name = result_file_name(outdir, label)
        if not os.path.isdir(outdir):
            os.makedirs(outdir)
        if os.path.isfile(file_name):
            logging.info(
                'Renaming existing file {} to {}.old'.format(file_name,
                                                             file_name))
            os.rename(file_name, file_name + '.old')

        logging.info("Saving result to {}".format(file_name))
        try:
            deepdish.io.save(file_name, self)
        except Exception as e:
            # Deliberately best-effort: report the failure instead of raising.
            logging.error(
                "\n\n Saving the data has failed with the following message:\n {} \n\n"
                .format(e))

    def get_latex_labels_from_parameter_keys(self, keys):
        """ Map parameter keys (or latex labels) to their latex labels

        Parameters
        ----------
        keys: iterable of str
            Each entry may be a name in ``search_parameter_keys`` (mapped to
            the entry at the same index of ``parameter_labels``) or already
            one of the ``parameter_labels`` (returned unchanged).

        Returns
        -------
        list: the labels, in the same order as ``keys``

        Raises
        ------
        ValueError:
            If an entry is neither a search parameter key nor a label.
        """
        return_list = []
        for k in keys:
            if k in self.search_parameter_keys:
                idx = self.search_parameter_keys.index(k)
                return_list.append(self.parameter_labels[idx])
            elif k in self.parameter_labels:
                return_list.append(k)
            else:
                raise ValueError('key {} not a parameter label or latex label'
                                 .format(k))
        return return_list

    def _set_default_truth(self, kwargs):
        """Default ``kwargs['truth']`` to the injected search-parameter values.

        Leaves ``kwargs`` unchanged if ``truth`` was already passed or no
        ``injection_parameters`` are stored on the result.
        """
        if 'truth' in kwargs:
            return
        injection_parameters = getattr(self, 'injection_parameters', None)
        if injection_parameters is not None:
            kwargs['truth'] = [injection_parameters[key]
                               for key in self.search_parameter_keys]

    def _fill_missing_parameter_labels(self):
        """Replace None entries of ``parameter_labels`` with 'Unknown', in place."""
        for i, label in enumerate(self.parameter_labels):
            if label is None:
                self.parameter_labels[i] = 'Unknown'

    def plot_corner(self, save=True, **kwargs):
        """ Plot a corner-plot using chain-consumer

        Parameters
        ----------
        save: bool
            If true, save the image to '<outdir>/<label>_corner.png' unless a
            ``filename`` keyword is given
        **kwargs:
            Passed to ChainConsumer's ``plotter.plot``. A ``truth`` dict and
            a ``parameters`` list may use parameter keys or latex labels.

        Returns
        -------
        fig:
            A matplotlib figure instance
        """

        # Set some defaults (unless already set)
        kwargs.setdefault('figsize', 'GROW')
        if save:
            kwargs.setdefault(
                'filename', '{}/{}_corner.png'.format(self.outdir, self.label))
            logging.info('Saving corner plot to {}'.format(kwargs['filename']))
        self._set_default_truth(kwargs)

        truth = kwargs.get('truth')
        if isinstance(truth, dict):
            # Build a new dict keyed by latex label: renaming keys in place
            # while iterating the dict is unsafe and would also modify the
            # caller's dict.
            old_keys = list(truth)
            new_keys = self.get_latex_labels_from_parameter_keys(old_keys)
            kwargs['truth'] = {new: truth[old]
                               for old, new in zip(old_keys, new_keys)}
        if 'parameters' in kwargs:
            kwargs['parameters'] = self.get_latex_labels_from_parameter_keys(
                kwargs['parameters'])

        # Check all parameter_labels are a valid string
        self._fill_missing_parameter_labels()
        c = ChainConsumer()
        c.add_chain(self.samples, parameters=self.parameter_labels,
                    name=self.label)
        fig = c.plotter.plot(**kwargs)
        return fig

    def plot_walks(self, save=True, **kwargs):
        """ Plot the chain walks using chain-consumer

        Parameters
        ----------
        save: bool
            If true, save the image to '<outdir>/<label>_walks.png' unless a
            ``filename`` keyword is given
        **kwargs:
            Passed to ChainConsumer's ``plotter.plot_walks``

        Returns
        -------
        fig:
            A matplotlib figure instance
        """

        # Set some defaults (unless already set)
        if save:
            kwargs.setdefault(
                'filename', '{}/{}_walks.png'.format(self.outdir, self.label))
            logging.info('Saving walker plot to {}'.format(kwargs['filename']))
        self._set_default_truth(kwargs)
        self._fill_missing_parameter_labels()
        c = ChainConsumer()
        c.add_chain(self.samples, parameters=self.parameter_labels)
        fig = c.plotter.plot_walks(**kwargs)
        return fig

    def plot_distributions(self, save=True, **kwargs):
        """ Plot the marginal distributions using chain-consumer

        Parameters
        ----------
        save: bool
            If true, save the image to '<outdir>/<label>_distributions.png'
            unless a ``filename`` keyword is given
        **kwargs:
            Passed to ChainConsumer's ``plotter.plot_distributions``

        Returns
        -------
        fig:
            A matplotlib figure instance
        """

        # Set some defaults (unless already set)
        if save:
            kwargs.setdefault(
                'filename',
                '{}/{}_distributions.png'.format(self.outdir, self.label))
            logging.info('Saving distributions plot to {}'.format(kwargs['filename']))
        self._set_default_truth(kwargs)
        self._fill_missing_parameter_labels()
        c = ChainConsumer()
        c.add_chain(self.samples, parameters=self.parameter_labels)
        fig = c.plotter.plot_distributions(**kwargs)
        return fig

    def write_prior_to_file(self, outdir):
        """
        Write the prior distribution to file.

        One ``key = prior`` line is written per prior to the file named
        ``outdir + '.prior'`` (``outdir`` is used as a path prefix, not as a
        directory). Priors are read from ``self.prior`` when present,
        otherwise from ``self.priors``, the attribute samples_to_data_frame
        uses.
        """
        outfile = outdir + '.prior'
        # NOTE(review): both 'prior' and 'priors' appear in this class;
        # confirm which one the sampler actually sets.
        priors = self['prior'] if 'prior' in self else self.priors
        with open(outfile, "w") as prior_file:
            for key in priors:
                # file.write needs a str; one line per prior keeps entries
                # separable.
                prior_file.write('{} = {}\n'.format(key, priors[key]))

    def samples_to_data_frame(self):
        """
        Convert array of samples to data frame.

        Stores a pandas DataFrame in ``self.posterior`` with one column per
        entry of ``search_parameter_keys``. Each entry of
        ``fixed_parameter_keys`` gets a column filled by
        ``self.priors[key].sample(n)``, where n is the number of samples.
        """
        self.posterior = pd.DataFrame(self.samples,
                                      columns=self.search_parameter_keys)
        n_samples = len(self.posterior)
        for key in self.fixed_parameter_keys:
            self.posterior[key] = self.priors[key].sample(n_samples)

    def construct_cbc_derived_parameters(self):
        """
        Construct widely used derived parameters of CBCs

        Adds the columns ``mass_chirp``, ``q``, ``eta``, ``chi_eff`` and
        ``chi_p`` to ``self.posterior``. They are computed from its existing
        columns ``mass_1``, ``mass_2``, ``a_1``, ``a_2``, ``tilt_1`` and
        ``tilt_2``.
        """
        posterior = self.posterior
        mass_1 = posterior['mass_1']
        mass_2 = posterior['mass_2']
        total_mass = mass_1 + mass_2

        posterior['mass_chirp'] = (mass_1 * mass_2) ** 0.6 / total_mass ** 0.2
        q = mass_2 / mass_1
        posterior['q'] = q
        posterior['eta'] = (mass_1 * mass_2) / total_mass ** 2

        posterior['chi_eff'] = (
            posterior['a_1'] * np.cos(posterior['tilt_1'])
            + q * posterior['a_2'] * np.cos(posterior['tilt_2'])) / (1 + q)
        # np.maximum works element-wise; the builtin max() on two Series
        # raises "The truth value of a Series is ambiguous".
        posterior['chi_p'] = np.maximum(
            posterior['a_1'] * np.sin(posterior['tilt_1']),
            (4 * q + 3) / (3 * q + 4) * q
            * posterior['a_2'] * np.sin(posterior['tilt_2']))

    def check_attribute_match_to_other_object(self, name, other_object):
        """ Check attribute name exists in other_object and is the same

        Returns True only if both objects have the attribute, both values
        have exactly the same type, and that type is str, float, int, bool,
        dict, list or np.ndarray. In addition the values must be equal; arrays
        must have the same shape and elements. Returns False otherwise.
        """
        # A private sentinel, so an attribute whose value is False is not
        # mistaken for a missing one.
        missing = object()
        A = getattr(self, name, missing)
        B = getattr(other_object, name, missing)
        logging.debug('Checking {} value: {}=={}'.format(
            name, False if A is missing else A, False if B is missing else B))
        if A is missing or B is missing or type(A) is not type(B):
            return False
        if type(A) is np.ndarray:
            # array_equal checks shapes instead of broadcasting or raising.
            return bool(np.array_equal(A, B))
        if type(A) in (str, float, int, bool, dict, list):
            return A == B
        return False