From c724b30e38dfad7d0229b6f5a1e33287a8c73b70 Mon Sep 17 00:00:00 2001
From: Moritz Huebner <moritz.huebner@ligo.org>
Date: Thu, 24 Jan 2019 21:18:37 -0600
Subject: [PATCH] Resolve "`Result.plot_corner` should still work OOTB if I
 move the hdf5 file"

---
 CHANGELOG.md         |   2 +
 bilby/core/result.py | 125 +++++++++++++++++++++++++++----------------
 test/result_test.py  |  38 ++++++-------
 3 files changed, 99 insertions(+), 66 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d341b2c6..9809cb44 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,8 @@
 - Renamed "prior" to "priors" in bilby.gw.likelihood.GravtitationalWaveTransient
   for consistency with bilby.core. **WARNING**: This will break scripts which
   use marginalization.
+- Added `outdir` kwarg for plotting methods in `bilby.core.result.Result`. This makes plotting
+into custom destinations easier.
 - Fixed definition of matched_filter_snr, the interferometer method has become `ifo.inner_product`.
 
 ### Added
diff --git a/bilby/core/result.py b/bilby/core/result.py
index 26a45668..0b96f48a 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -11,6 +11,7 @@ import corner
 import scipy.stats
 import matplotlib
 import matplotlib.pyplot as plt
+from matplotlib import lines as mpllines
 
 from . import utils
 from .utils import (logger, infer_parameters_from_function,
@@ -119,9 +120,10 @@ class Result(object):
             Version information for software used to generate the result. Note,
             this information is generated when the result object is initialized
 
-        Note:
-            All sampling output parameters, e.g. the samples themselves are
-            typically not given at initialisation, but set at a later stage.
+        Note
+        ---------
+        All sampling output parameters, e.g. the samples themselves are
+        typically not given at initialisation, but set at a later stage.
 
         """
 
@@ -151,6 +153,9 @@ class Result(object):
         self.version = version
         self.max_autocorrelation_time = max_autocorrelation_time
 
+        self.prior_values = None
+        self._kde = None
+
     def __str__(self):
         """Print a summary """
         if getattr(self, 'posterior', None) is not None:
@@ -285,7 +290,7 @@ class Result(object):
                 pass
         return dictionary
 
-    def save_to_file(self, overwrite=False):
+    def save_to_file(self, overwrite=False, outdir=None):
         """
         Writes the Result to a deepdish h5 file
 
@@ -294,9 +299,12 @@ class Result(object):
         overwrite: bool, optional
             Whether or not to overwrite an existing result file.
             default=False
+        outdir: str, optional
+            Path to the outdir. Default is the one stored in the result object.
         """
-        file_name = result_file_name(self.outdir, self.label)
-        utils.check_directory_exists_and_if_not_mkdir(self.outdir)
+        outdir = self._safe_outdir_creation(outdir, self.save_to_file)
+        file_name = result_file_name(outdir, self.label)
+
         if os.path.isfile(file_name):
             if overwrite:
                 logger.debug('Removing existing file {}'.format(file_name))
@@ -326,10 +334,10 @@ class Result(object):
             logger.error("\n\n Saving the data has failed with the "
                          "following message:\n {} \n\n".format(e))
 
-    def save_posterior_samples(self):
+    def save_posterior_samples(self, outdir=None):
         """Saves posterior samples to a file"""
-        filename = '{}/{}_posterior_samples.txt'.format(self.outdir, self.label)
-        utils.check_directory_exists_and_if_not_mkdir(self.outdir)
+        outdir = self._safe_outdir_creation(outdir, self.save_posterior_samples)
+        filename = '{}/{}_posterior_samples.txt'.format(outdir, self.label)
         self.posterior.to_csv(filename, index=False, header=True)
 
     def get_latex_labels_from_parameter_keys(self, keys):
@@ -389,7 +397,7 @@ class Result(object):
         return self.posterior_volume / self.prior_volume(priors)
 
     def get_one_dimensional_median_and_error_bar(self, key, fmt='.2f',
-                                                 quantiles=[0.16, 0.84]):
+                                                 quantiles=(0.16, 0.84)):
         """ Calculate the median and error bar for a given key
 
         Parameters
@@ -398,8 +406,8 @@ class Result(object):
             The parameter key for which to calculate the median and error bar
         fmt: str, ('.2f')
             A format string
-        quantiles: list
-            A length-2 list of the lower and upper-quantiles to calculate
+        quantiles: list, tuple
+            A length-2 tuple of the lower and upper-quantiles to calculate
             the errors bars for.
 
         Returns
@@ -428,8 +436,8 @@ class Result(object):
     def plot_single_density(self, key, prior=None, cumulative=False,
                             title=None, truth=None, save=True,
                             file_base_name=None, bins=50, label_fontsize=16,
-                            title_fontsize=16, quantiles=[0.16, 0.84], dpi=300):
-        """ Plot a 1D marginal density, either probablility or cumulative.
+                            title_fontsize=16, quantiles=(0.16, 0.84), dpi=300):
+        """ Plot a 1D marginal density, either probability or cumulative.
 
         Parameters
         ----------
@@ -458,8 +466,8 @@ class Result(object):
             The number of histogram bins
         label_fontsize, title_fontsize: int
             The fontsizes for the labels and titles
-        quantiles: list
-            A length-2 list of the lower and upper-quantiles to calculate
+        quantiles: tuple
+            A length-2 tuple of the lower and upper-quantiles to calculate
             the errors bars for.
         dpi: int
             Dots per inch resolution of the plot
@@ -493,7 +501,7 @@ class Result(object):
 
         if isinstance(prior, Prior):
             theta = np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 300)
-            ax.plot(theta, Prior.prob(theta), color='C2')
+            ax.plot(theta, prior.prob(theta), color='C2')
 
         if save:
             fig.tight_layout()
@@ -508,7 +516,8 @@ class Result(object):
 
     def plot_marginals(self, parameters=None, priors=None, titles=True,
                        file_base_name=None, bins=50, label_fontsize=16,
-                       title_fontsize=16, quantiles=[0.16, 0.84], dpi=300):
+                       title_fontsize=16, quantiles=(0.16, 0.84), dpi=300,
+                       outdir=None):
         """ Plot 1D marginal distributions
 
         Parameters
@@ -531,12 +540,14 @@ class Result(object):
         bins: int
             The number of histogram bins
         label_fontsize, title_fontsize: int
-            The fontsizes for the labels and titles
-        quantiles: list
-            A length-2 list of the lower and upper-quantiles to calculate
+            The font sizes for the labels and titles
+        quantiles: tuple
+            A length-2 tuple of the lower and upper-quantiles to calculate
             the errors bars for.
         dpi: int
             Dots per inch resolution of the plot
+        outdir: str, optional
+            Path to the outdir. Default is the one store in the result object.
 
         Returns
         -------
@@ -558,7 +569,8 @@ class Result(object):
                 truths = self.injection_parameters
 
         if file_base_name is None:
-            file_base_name = '{}/{}_1d/'.format(self.outdir, self.label)
+            outdir = self._safe_outdir_creation(outdir, self.plot_marginals)
+            file_base_name = '{}/{}_1d/'.format(outdir, self.label)
             check_directory_exists_and_if_not_mkdir(file_base_name)
 
         if priors is True:
@@ -609,7 +621,8 @@ class Result(object):
         **kwargs:
             Other keyword arguments are passed to `corner.corner`. We set some
             defaults to improve the basic look and feel, but these can all be
-            overridden.
+            overridden. Also optional an 'outdir' argument which can be used
+            to override the outdir set by the absolute path of the result object.
 
         Notes
         -----
@@ -720,8 +733,8 @@ class Result(object):
 
         if save:
             if filename is None:
-                utils.check_directory_exists_and_if_not_mkdir(self.outdir)
-                filename = '{}/{}_corner.png'.format(self.outdir, self.label)
+                outdir = self._safe_outdir_creation(kwargs.get('outdir'), self.plot_corner)
+                filename = '{}/{}_corner.png'.format(outdir, self.label)
             logger.debug('Saving corner plot to {}'.format(filename))
             fig.savefig(filename, dpi=dpi)
             plt.close(fig)
@@ -752,16 +765,16 @@ class Result(object):
             ax.set_ylabel(self.parameter_labels[i])
 
         fig.tight_layout()
-        filename = '{}/{}_walkers.png'.format(self.outdir, self.label)
+        outdir = self._safe_outdir_creation(kwargs.get('outdir'), self.plot_walkers)
+        filename = '{}/{}_walkers.png'.format(outdir, self.label)
         logger.debug('Saving walkers plot to {}'.format('filename'))
-        utils.check_directory_exists_and_if_not_mkdir(self.outdir)
         fig.savefig(filename)
         plt.close(fig)
 
     def plot_with_data(self, model, x, y, ndraws=1000, npoints=1000,
                        xlabel=None, ylabel=None, data_label='data',
                        data_fmt='o', draws_label=None, filename=None,
-                       maxl_label='max likelihood', dpi=300):
+                       maxl_label='max likelihood', dpi=300, outdir=None):
         """ Generate a figure showing the data and fits to the data
 
         Parameters
@@ -787,6 +800,8 @@ class Result(object):
         filename: str
             If given, the filename to use. Otherwise, the filename is generated
             from the outdir and label attributes.
+        outdir: str, optional
+            Path to the outdir. Default is the one store in the result object.
 
         """
 
@@ -825,8 +840,8 @@ class Result(object):
         ax.legend(numpoints=3)
         fig.tight_layout()
         if filename is None:
-            utils.check_directory_exists_and_if_not_mkdir(self.outdir)
-            filename = '{}/{}_plot_with_data'.format(self.outdir, self.label)
+            outdir = self._safe_outdir_creation(outdir, self.plot_with_data)
+            filename = '{}/{}_plot_with_data'.format(outdir, self.label)
         fig.savefig(filename, dpi=dpi)
         plt.close(fig)
 
@@ -944,20 +959,20 @@ class Result(object):
         bool: True if attribute name matches with an attribute of other_object, False otherwise
 
         """
-        A = getattr(self, name, False)
-        B = getattr(other_object, name, False)
-        logger.debug('Checking {} value: {}=={}'.format(name, A, B))
-        if (A is not False) and (B is not False):
-            typeA = type(A)
-            typeB = type(B)
-            if typeA == typeB:
-                if typeA in [str, float, int, dict, list]:
+        a = getattr(self, name, False)
+        b = getattr(other_object, name, False)
+        logger.debug('Checking {} value: {}=={}'.format(name, a, b))
+        if (a is not False) and (b is not False):
+            type_a = type(a)
+            type_b = type(b)
+            if type_a == type_b:
+                if type_a in [str, float, int, dict, list]:
                     try:
-                        return A == B
+                        return a == b
                     except ValueError:
                         return False
-                elif typeA in [np.ndarray]:
-                    return np.all(A == B)
+                elif type_a in [np.ndarray]:
+                    return np.all(a == b)
         return False
 
     @property
@@ -966,9 +981,9 @@ class Result(object):
 
         Uses `scipy.stats.gaussian_kde` to generate the kernel density
         """
-        try:
+        if self._kde:
             return self._kde
-        except AttributeError:
+        else:
             self._kde = scipy.stats.gaussian_kde(
                 self.posterior[self.search_parameter_keys].values.T)
             return self._kde
@@ -998,6 +1013,18 @@ class Result(object):
                           for s in sample]
         return self.kde(ordered_sample)
 
+    def _safe_outdir_creation(self, outdir=None, caller_func=None):
+        if outdir is None:
+            outdir = self.outdir
+        try:
+            utils.check_directory_exists_and_if_not_mkdir(outdir)
+        except PermissionError:
+            raise FileMovedError("Can not write in the out directory.\n"
+                                 "Did you move the here file from another system?\n"
+                                 "Try calling " + caller_func.__name__ + " with the 'outdir' "
+                                 "keyword argument, e.g. " + caller_func.__name__ + "(outdir='.')")
+        return outdir
+
 
 def plot_multiple(results, filename=None, labels=None, colours=None,
                   save=True, evidences=False, **kwargs):
@@ -1050,7 +1077,7 @@ def plot_multiple(results, filename=None, labels=None, colours=None,
         hist_kwargs['color'] = c
         fig = result.plot_corner(fig=fig, save=False, color=c, **kwargs)
         default_filename += '_{}'.format(result.label)
-        lines.append(matplotlib.lines.Line2D([0], [0], color=c))
+        lines.append(mpllines.Line2D([0], [0], color=c))
         default_labels.append(result.label)
 
     # Rescale the axes
@@ -1100,7 +1127,7 @@ def make_pp_plot(results, filename=None, save=True, **kwargs):
     Returns
     -------
     fig:
-        Matplotlib figure
+        matplotlib figure
     """
     fig = plt.figure()
     credible_levels = pd.DataFrame()
@@ -1122,3 +1149,11 @@ def make_pp_plot(results, filename=None, save=True, **kwargs):
             filename = 'outdir/pp.png'
         plt.savefig(filename)
     return fig
+
+
+class ResultError(Exception):
+    """ Base exception for all Result related errors """
+
+
+class FileMovedError(ResultError):
+    """ Exceptions that occur when files have been moved """
diff --git a/test/result_test.py b/test/result_test.py
index 2308bcea..6dbc5a23 100644
--- a/test/result_test.py
+++ b/test/result_test.py
@@ -25,9 +25,9 @@ class TestResult(unittest.TestCase):
             injection_parameters=dict(x=0.5, y=0.5),
             meta_data=dict(test='test'))
 
-        N = 100
-        posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, N),
-                                      y=np.random.normal(0, 1, N)))
+        n = 100
+        posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, n),
+                                      y=np.random.normal(0, 1, n)))
         result.posterior = posterior
         result.log_evidence = 10
         result.log_evidence_err = 11
@@ -66,7 +66,7 @@ class TestResult(unittest.TestCase):
             injection_parameters=dict(x=0.5, y=0.5),
             meta_data=dict(test='test'))
         with self.assertRaises(ValueError):
-            result.priors
+            _ = result.priors
         self.assertEqual(result.parameter_labels, result.search_parameter_keys)
         self.assertEqual(result.parameter_labels_with_unit, result.search_parameter_keys)
 
@@ -102,14 +102,14 @@ class TestResult(unittest.TestCase):
     def test_unset_posterior(self):
         self.result.posterior = None
         with self.assertRaises(ValueError):
-            self.result.posterior
+            _ = self.result.posterior
 
     def test_save_and_load(self):
         self.result.save_to_file()
         loaded_result = bilby.core.result.read_in_result(
             outdir=self.result.outdir, label=self.result.label)
-        self.assertTrue(
-            all(self.result.posterior == loaded_result.posterior))
+        self.assertTrue(pd.DataFrame.equals
+                        (self.result.posterior, loaded_result.posterior))
         self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys)
         self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys)
         self.assertEqual(self.result.meta_data, loaded_result.meta_data)
@@ -146,31 +146,28 @@ class TestResult(unittest.TestCase):
         filename = '{}/{}_posterior_samples.txt'.format(self.result.outdir, self.result.label)
         self.assertTrue(os.path.isfile(filename))
         df = pd.read_csv(filename)
-        self.assertTrue(all(self.result.posterior == df))
+        self.assertTrue(np.allclose(self.result.posterior.values, df.values))
 
     def test_samples_to_posterior(self):
         self.result.posterior = None
         x = [1, 2, 3]
         y = [4, 6, 8]
-        log_likelihood = [6, 7, 8]
+        log_likelihood = np.array([6, 7, 8])
         self.result.samples = np.array([x, y]).T
         self.result.log_likelihood_evaluations = log_likelihood
         self.result.samples_to_posterior(priors=self.result.priors)
         self.assertTrue(all(self.result.posterior['x'] == x))
         self.assertTrue(all(self.result.posterior['y'] == y))
-        self.assertTrue(
-            all(self.result.posterior['log_likelihood'] == log_likelihood))
-        self.assertTrue(
-            all(self.result.posterior['c'] == self.result.priors['c'].peak))
-        self.assertTrue(
-            all(self.result.posterior['d'] == self.result.priors['d'].peak))
+        self.assertTrue(np.array_equal(self.result.posterior.log_likelihood.values, log_likelihood))
+        self.assertTrue(all(self.result.posterior.c.values == self.result.priors['c'].peak))
+        self.assertTrue(all(self.result.posterior.d.values == self.result.priors['d'].peak))
 
     def test_calculate_prior_values(self):
         self.result.calculate_prior_values(priors=self.result.priors)
         self.assertEqual(len(self.result.posterior), len(self.result.prior_values))
 
     def test_plot_multiple(self):
-        filename='multiple.png'.format(self.result.outdir)
+        filename = 'multiple.png'.format(self.result.outdir)
         bilby.core.result.plot_multiple([self.result, self.result],
                                         filename=filename)
         self.assertTrue(os.path.isfile(filename))
@@ -188,8 +185,8 @@ class TestResult(unittest.TestCase):
         x = np.linspace(0, 1, 10)
         y = np.linspace(0, 1, 10)
 
-        def model(x):
-            return x
+        def model(xx):
+            return xx
         self.result.plot_with_data(model, x, y, ndraws=10)
         self.assertTrue(
             os.path.isfile('{}/{}_plot_with_data.png'.format(
@@ -260,9 +257,8 @@ class TestResult(unittest.TestCase):
         sample = [dict(x=0, y=0.1), dict(x=0.8, y=0)]
         self.assertTrue(
             isinstance(self.result.posterior_probability(sample), np.ndarray))
-        self.assertTrue(
-            all(self.result.posterior_probability(sample)
-                == self.result.kde([[0, 0.1], [0.8, 0]])))
+        self.assertTrue(np.array_equal(self.result.posterior_probability(sample),
+                                       self.result.kde([[0, 0.1], [0.8, 0]])))
 
 
 if __name__ == '__main__':
-- 
GitLab