From 130324e260832b0c5cb42510d64701407d46b221 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Sun, 28 Oct 2018 21:23:20 -0500
Subject: [PATCH] Overhaul of Result() implementation

---
 CHANGELOG.md                       |  12 ++
 bilby/core/prior.py                |   4 +-
 bilby/core/result.py               | 314 +++++++++++++++++++----------
 bilby/core/sampler/__init__.py     |  14 +-
 bilby/core/sampler/base_sampler.py |  36 ++--
 bilby/core/sampler/dynesty.py      |  10 +-
 bilby/core/sampler/pymc3.py        |  18 +-
 test/result_test.py                | 175 +++++++++++++++-
 test/sampler_test.py               |  11 -
 9 files changed, 427 insertions(+), 167 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3c8e52e2..4344928d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,18 @@ Changes currently on master, but not under a tag.
 
 - Fixed a bug which caused `Interferometer.detector_tensor` not to update when `latitude`, `longitude`, `xarm_azimuth`, `yarm_azimuth`, `xarm_tilt`, `yarm_tilt` were updated.
 
+### Changes
+- Switched the ordering of the keyword arguments to `result.read_in_result` so
+  that `filename` comes first. This allows users to quickly read in results by
+  filename
+- The Result object is no longer a child of `dict`. Additionally, the list of
+  attributes and saved attributes is standardised
+- The above changes affect the saving of posteriors. Files written in
+  python 2(3) may no longer open in python 3(2); the overhead of maintaining
+  cross-version compatibility was judged too great. Users working exclusively
+  in either python 2 or python 3 should not encounter any issues
+- Intermediate data products (samples, nested_samples) are now stored in the h5 file
+
 ## [0.3.1] 2018-11-06
 
 ### Changes
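
For quick reference, the new calling convention described above looks like the
following (a minimal sketch; the `outdir`/`label`/path values are illustrative):

    import bilby

    # `filename` is now the first keyword argument
    result = bilby.core.result.read_in_result(filename='outdir/label_result.h5')

    # The default naming convention can still be used via outdir and label
    result = bilby.core.result.read_in_result(outdir='outdir', label='label')
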
diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index eb92472f..34b7fae9 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -37,6 +37,8 @@ class PriorDict(OrderedDict):
         elif dictionary is not None:
             raise ValueError("PriorDict input dictionary not understood")
 
+        self.convert_floats_to_delta_functions()
+
     def to_file(self, outdir, label):
         """ Write the prior distribution to file.
 
@@ -81,7 +83,7 @@ class PriorDict(OrderedDict):
             if isinstance(val, str):
                 try:
                     prior = eval(val)
-                    if isinstance(prior, Prior):
+                    if isinstance(prior, (Prior, float, int, str)):
                         val = prior
                 except (NameError, SyntaxError, TypeError):
                     logger.debug(
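
As a consequence of calling `convert_floats_to_delta_functions()` at
construction, fixed (float or int) parameters can be passed directly in the
prior dictionary. A small sketch of the expected behaviour (the values are
illustrative):

    import bilby

    priors = bilby.core.prior.PriorDict(dict(
        x=bilby.core.prior.Uniform(0, 1, 'x'),
        c=1.0))

    # The float entry is promoted to a DeltaFunction prior at construction time
    print(type(priors['c']).__name__)  # DeltaFunction
    print(priors['c'].peak)            # 1.0
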
diff --git a/bilby/core/result.py b/bilby/core/result.py
index 01e37b05..df4374b5 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -30,15 +30,15 @@ def result_file_name(outdir, label):
     return '{}/{}_result.h5'.format(outdir, label)
 
 
-def read_in_result(outdir=None, label=None, filename=None):
+def read_in_result(filename=None, outdir=None, label=None):
     """ Read in a saved .h5 data file
 
     Parameters
     ----------
-    outdir, label: str
-        If given, use the default naming convention for saved results file
     filename: str
         If given, try to load from this filename
+    outdir, label: str
+        If given, use the default naming convention for saved results file
 
     Returns
     -------
@@ -51,73 +51,96 @@ def read_in_result(outdir=None, label=None, filename=None):
 
     """
     if filename is None:
-        filename = result_file_name(outdir, label)
-    elif (outdir is None or label is None) and filename is None:
-        raise ValueError("No information given to load file")
+        if (outdir is None) and (label is None):
+            raise ValueError("No information given to load file")
+        else:
+            filename = result_file_name(outdir, label)
     if os.path.isfile(filename):
-        return Result(deepdish.io.load(filename))
+        return Result(**deepdish.io.load(filename))
     else:
-        raise ValueError("No result found")
-
-
-class Result(dict):
-    def __init__(self, dictionary=None):
-        """ A class to save the results of the sampling run.
+        raise IOError("No result '{}' found".format(filename))
+
+
+class Result(object):
+    def __init__(self, label='no_label', outdir='.', sampler=None,
+                 search_parameter_keys=None, fixed_parameter_keys=None,
+                 priors=None, sampler_kwargs=None, injection_parameters=None,
+                 meta_data=None, posterior=None, samples=None,
+                 nested_samples=None, log_evidence=np.nan,
+                 log_evidence_err=np.nan, log_noise_evidence=np.nan,
+                 log_bayes_factor=np.nan, log_likelihood_evaluations=None,
+                 sampling_time=None, nburn=None, walkers=None,
+                 max_autocorrelation_time=None, parameter_labels=None,
+                 parameter_labels_with_unit=None):
+        """ A class to store the results of the sampling run
 
         Parameters
         ----------
-        dictionary: dict
-            A dictionary containing values to be set in this instance
-        """
-
-        # Set some defaults
-        self.outdir = '.'
-        self.label = 'no_name'
-
-        dict.__init__(self)
-        if type(dictionary) is dict:
-            for key in dictionary:
-                val = self._standardise_a_string(dictionary[key])
-                setattr(self, key, val)
-
-        if getattr(self, 'priors', None) is not None:
-            self.priors = PriorDict(self.priors)
-
-    def __add__(self, other):
-        matches = ['sampler', 'search_parameter_keys']
-        for match in matches:
-            # The 1 and 0 here ensure that if either doesn't have a match for
-            # some reason, a error will be thrown.
-            if getattr(other, match, 1) != getattr(self, match, 0):
-                raise ValueError(
-                    "Unable to add results generated with different {}".format(match))
-
-        self.samples = np.concatenate([self.samples, other.samples])
-        self.posterior = pd.concat([self.posterior, other.posterior])
-        return self
-
-    def __dir__(self):
-        """ Adds tab completion in ipython
-
-        See: http://ipython.org/ipython-doc/dev/config/integrating.html
+        label, outdir, sampler: str
+            The label, output directory, and sampler used
+        search_parameter_keys, fixed_parameter_keys: list
+            Lists of the search and fixed parameter keys. Elements of the
+            list should be of type `str` and match the keys of the `priors`
+        priors: dict, bilby.core.prior.PriorDict
+            A dictionary of the priors used in the run
+        sampler_kwargs: dict
+            Keyword arguments passed to the sampler
+        injection_parameters: dict
+            A dictionary of the injection parameters
+        meta_data: dict
+            A dictionary of meta data to store about the run
+        posterior: pandas.DataFrame
+            A pandas data frame of the posterior
+        samples, nested_samples: array_like
+            Arrays of the output posterior samples and the unweighted samples
+        log_evidence, log_evidence_err, log_noise_evidence, log_bayes_factor: float
+            Natural log evidences
+        log_likelihood_evaluations: array_like
+            The evaluations of the likelihood for each sample point
+        sampling_time: float
+            The time taken to complete the sampling
+        nburn: int
+            The number of burn-in steps discarded for MCMC samplers
+        walkers: array_like
+            The samples taken by an ensemble MCMC sampler
+        max_autocorrelation_time: float
+            The estimated maximum autocorrelation time for MCMC samplers
+        parameter_labels, parameter_labels_with_unit: list
+            Lists of the latex-formatted parameter labels
+
+        Note:
+            All sampling output parameters, e.g. the samples themselves, are
+            typically not given at initialisation, but set at a later stage.
 
         """
-        methods = ['plot_corner', 'save_to_file', 'save_posterior_samples']
-        return self.keys() + methods
-
-    def __getattr__(self, name):
-        try:
-            return self[name]
-        except KeyError:
-            raise AttributeError(name)
-
-    __setattr__ = dict.__setitem__
-    __delattr__ = dict.__delitem__
 
-    def __repr__(self):
+        self.label = label
+        self.outdir = os.path.abspath(outdir)
+        self.sampler = sampler
+        self.search_parameter_keys = search_parameter_keys
+        self.fixed_parameter_keys = fixed_parameter_keys
+        self.parameter_labels = parameter_labels
+        self.parameter_labels_with_unit = parameter_labels_with_unit
+        self.priors = priors
+        self.sampler_kwargs = sampler_kwargs
+        self.meta_data = meta_data
+        self.injection_parameters = injection_parameters
+        self.posterior = posterior
+        self.samples = samples
+        self.nested_samples = nested_samples
+        self.walkers = walkers
+        self.nburn = nburn
+        self.log_evidence = log_evidence
+        self.log_evidence_err = log_evidence_err
+        self.log_noise_evidence = log_noise_evidence
+        self.log_bayes_factor = log_bayes_factor
+        self.log_likelihood_evaluations = log_likelihood_evaluations
+        self.sampling_time = sampling_time
+
+    def __str__(self):
         """Print a summary """
-        if hasattr(self, 'posterior'):
-            if hasattr(self, 'log_noise_evidence'):
+        if self._posterior is not None:
+            if getattr(self, 'log_noise_evidence', None) is not None:
                 return ("nsamples: {:d}\n"
                         "log_noise_evidence: {:6.3f}\n"
                         "log_evidence: {:6.3f} +/- {:6.3f}\n"
@@ -132,40 +155,109 @@ class Result(dict):
         else:
             return ''
 
-    @staticmethod
-    def _standardise_a_string(item):
-        """ When reading in data, ensure all strings are decoded correctly
+    @property
+    def priors(self):
+        if self._priors is not None:
+            return self._priors
+        else:
+            raise ValueError('Result object has no priors')
 
-        Parameters
-        ----------
-        item: str
+    @priors.setter
+    def priors(self, priors):
+        if isinstance(priors, dict):
+            self._priors = PriorDict(priors)
+            if self.parameter_labels is None:
+                self.parameter_labels = [self.priors[k].latex_label for k in
+                                         self.search_parameter_keys]
+            if self.parameter_labels_with_unit is None:
+                self.parameter_labels_with_unit = [
+                    self.priors[k].latex_label_with_unit for k in
+                    self.search_parameter_keys]
+
+        elif priors is None:
+            self._priors = priors
+            self.parameter_labels = self.search_parameter_keys
+            self.parameter_labels_with_unit = self.search_parameter_keys
+        else:
+            raise ValueError("Input priors not understood")
 
-        Returns
-        -------
-        str: decoded string
-        """
-        if type(item) in [bytes]:
-            return item.decode()
+    @property
+    def samples(self):
+        """ An array of samples """
+        if self._samples is not None:
+            return self._samples
         else:
-            return item
+            raise ValueError("Result object has no stored samples")
 
-    @staticmethod
-    def _standardise_strings(item):
-        """
+    @samples.setter
+    def samples(self, samples):
+        self._samples = samples
 
-        Parameters
-        ----------
-        item: list
-            List of strings to be decoded
+    @property
+    def nested_samples(self):
+        """" An array of unweighted samples """
+        if self._nested_samples is not None:
+            return self._nested_samples
+        else:
+            raise ValueError("Result object has no stored nested samples")
 
-        Returns
-        -------
-        list: list of decoded strings in item
+    @nested_samples.setter
+    def nested_samples(self, nested_samples):
+        self._nested_samples = nested_samples
 
-        """
-        if type(item) in [list]:
-            item = [Result._standardise_a_string(i) for i in item]
-        return item
+    @property
+    def walkers(self):
+        """" An array of the ensemble walkers """
+        if self._walkers is not None:
+            return self._walkers
+        else:
+            raise ValueError("Result object has no stored walkers")
+
+    @walkers.setter
+    def walkers(self, walkers):
+        self._walkers = walkers
+
+    @property
+    def nburn(self):
+        """" An array of the ensemble walkers """
+        if self._nburn is not None:
+            return self._nburn
+        else:
+            raise ValueError("Result object has no stored nburn")
+
+    @nburn.setter
+    def nburn(self, nburn):
+        self._nburn = nburn
+
+    @property
+    def posterior(self):
+        """ A pandas data frame of the posterior """
+        if self._posterior is not None:
+            return self._posterior
+        else:
+            raise ValueError("Result object has no stored posterior")
+
+    @posterior.setter
+    def posterior(self, posterior):
+        self._posterior = posterior
+
+    def _get_save_data_dictionary(self):
+        save_attrs = [
+            'label', 'outdir', 'sampler', 'log_evidence', 'log_evidence_err',
+            'log_noise_evidence', 'log_bayes_factor', 'priors', 'posterior',
+            'injection_parameters', 'meta_data', 'search_parameter_keys',
+            'fixed_parameter_keys', 'sampling_time', 'sampler_kwargs',
+            'log_likelihood_evaluations', 'samples', 'nested_samples',
+            'walkers', 'nburn', 'parameter_labels',
+            'parameter_labels_with_unit']
+        dictionary = OrderedDict()
+        for attr in save_attrs:
+            try:
+                dictionary[attr] = getattr(self, attr)
+            except ValueError as e:
+                logger.debug("Unable to save {}, message: {}".format(attr, e))
+                pass
+        return dictionary
 
     def save_to_file(self, overwrite=False):
         """
@@ -192,15 +284,15 @@ class Result(dict):
         logger.debug("Saving result to {}".format(file_name))
 
         # Convert the prior to a string representation for saving on disk
-        dictionary = dict(self)
+        dictionary = self._get_save_data_dictionary()
         if dictionary.get('priors', False):
             dictionary['priors'] = {key: str(self.priors[key]) for key in self.priors}
 
-        # Convert callable kwargs to strings to avoid pickling issues
-        if hasattr(self, 'kwargs'):
-            for key in self.kwargs:
-                if hasattr(self.kwargs[key], '__call__'):
-                    self.kwargs[key] = str(self.kwargs[key])
+        # Convert callable sampler_kwargs to strings to avoid pickling issues
+        if dictionary.get('sampler_kwargs', None) is not None:
+            for key in dictionary['sampler_kwargs']:
+                if hasattr(dictionary['sampler_kwargs'][key], '__call__'):
+                    dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs'][key])
 
         try:
             deepdish.io.save(file_name, dictionary)
@@ -211,6 +303,7 @@ class Result(dict):
     def save_posterior_samples(self):
         """Saves posterior samples to a file"""
         filename = '{}/{}_posterior_samples.txt'.format(self.outdir, self.label)
+        utils.check_directory_exists_and_if_not_mkdir(self.outdir)
         self.posterior.to_csv(filename, index=False, header=True)
 
     def get_latex_labels_from_parameter_keys(self, keys):
@@ -316,9 +409,9 @@ class Result(dict):
         parameters: (list, dict), optional
             If given, either a list of the parameter names to include, or a
             dictionary of parameter names and their "true" values to plot.
-        priors: {bool (False), bilby.core.prior.PriorSet}
+        priors: {bool (False), bilby.core.prior.PriorDict}
             If true, add the stored prior probability density functions to the
-            one-dimensional marginal distributions. If instead a PriorSet
+            one-dimensional marginal distributions. If instead a PriorDict
             is provided, this will be plotted.
         titles: bool
             If true, add 1D titles of the median and (by default 1-sigma)
@@ -606,11 +699,15 @@ class Result(dict):
             s = model_posterior.sample().to_dict('records')[0]
             ax.plot(xsmooth, model(xsmooth, **s), alpha=0.25, lw=0.1, color='r',
                     label=draws_label)
-        if all(~np.isnan(self.posterior.log_likelihood)):
-            logger.info('Plotting maximum likelihood')
-            s = model_posterior.ix[self.posterior.log_likelihood.idxmax()]
-            ax.plot(xsmooth, model(xsmooth, **s), lw=1, color='k',
-                    label=maxl_label)
+        try:
+            if all(~np.isnan(self.posterior.log_likelihood)):
+                logger.info('Plotting maximum likelihood')
+                s = model_posterior.ix[self.posterior.log_likelihood.idxmax()]
+                ax.plot(xsmooth, model(xsmooth, **s), lw=1, color='k',
+                        label=maxl_label)
+        except AttributeError:
+            logger.debug(
+                "No log likelihood values stored, unable to plot max")
 
         ax.plot(x, y, data_fmt, markersize=2, label=data_label)
 
@@ -625,13 +722,16 @@ class Result(dict):
         ax.legend(numpoints=3)
         fig.tight_layout()
         if filename is None:
+            utils.check_directory_exists_and_if_not_mkdir(self.outdir)
             filename = '{}/{}_plot_with_data'.format(self.outdir, self.label)
         fig.savefig(filename, dpi=dpi)
 
     def samples_to_posterior(self, likelihood=None, priors=None,
                              conversion_function=None):
         """
-        Convert array of samples to posterior (a Pandas data frame).
+        Convert array of samples to posterior (a Pandas data frame)
+
+        Also applies the conversion function to any stored posterior
 
         Parameters
         ----------
@@ -643,7 +743,9 @@ class Result(dict):
             Function which adds in extra parameters to the data frame,
             should take the data_frame, likelihood and prior as arguments.
         """
-        if hasattr(self, 'posterior') is False:
+        try:
+            data_frame = self.posterior
+        except ValueError:
             data_frame = pd.DataFrame(
                 self.samples, columns=self.search_parameter_keys)
             for key in priors:
@@ -653,10 +755,6 @@ class Result(dict):
                     data_frame[key] = priors[key]
             data_frame['log_likelihood'] = getattr(
                 self, 'log_likelihood_evaluations', np.nan)
-            # remove the array of samples
-            del self.samples
-        else:
-            data_frame = self.posterior
         if conversion_function is not None:
             data_frame = conversion_function(data_frame, likelihood, priors)
         self.posterior = data_frame
@@ -679,7 +777,7 @@ class Result(dict):
                     self.prior_values[key]\
                         = priors[key].prob(self.posterior[key].values)
 
-    def check_attribute_match_to_other_object(self, name, other_object):
+    def _check_attribute_match_to_other_object(self, name, other_object):
         """ Check attribute name exists in other_object and is the same
 
         Parameters
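
Taken together, the Result object is now constructed, populated, saved and
re-loaded roughly as follows (a sketch assuming this patch is applied; the
label, outdir and values are hypothetical):

    import numpy as np
    import pandas as pd
    import bilby

    result = bilby.core.result.Result(
        label='example', outdir='outdir', sampler='dynesty',
        search_parameter_keys=['x'],
        priors=dict(x=bilby.core.prior.Uniform(0, 1, 'x')))

    # Sampling outputs are attached after construction
    result.posterior = pd.DataFrame(dict(x=np.random.uniform(0, 1, 100)))
    result.log_evidence = -1.2

    result.save_to_file()  # writes outdir/example_result.h5
    loaded = bilby.core.result.read_in_result(outdir='outdir', label='example')

    # Attributes that were never set raise ValueError instead of returning None
    try:
        loaded.samples
    except ValueError:
        print("no stored samples")
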
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index bf5f8391..12138af0 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -116,6 +116,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
             sampler_class = implemented_samplers[sampler.lower()]
             sampler = sampler_class(
                 likelihood, priors=priors, outdir=outdir, label=label,
+                injection_parameters=injection_parameters, meta_data=meta_data,
                 use_ratio=use_ratio, plot=plot, **kwargs)
         else:
             print(implemented_samplers)
@@ -125,6 +126,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         sampler = sampler.__init__(
             likelihood, priors=priors,
             outdir=outdir, label=label, use_ratio=use_ratio, plot=plot,
+            injection_parameters=injection_parameters, meta_data=meta_data,
             **kwargs)
     else:
         raise ValueError(
@@ -142,11 +144,6 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
     else:
         result = sampler.run_sampler()
 
-    if type(meta_data) == dict:
-        result.update(meta_data)
-
-    result.priors = priors
-
     end_time = datetime.datetime.now()
     result.sampling_time = (end_time - start_time).total_seconds()
     logger.info('Sampling time: {}'.format(end_time - start_time))
@@ -160,15 +157,14 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         result.log_noise_evidence = likelihood.noise_log_likelihood()
         result.log_bayes_factor = \
             result.log_evidence - result.log_noise_evidence
-    if injection_parameters is not None:
-        result.injection_parameters = injection_parameters
+
+    if result.injection_parameters is not None:
         if conversion_function is not None:
             result.injection_parameters = conversion_function(
                 result.injection_parameters)
-    result.fixed_parameter_keys = sampler.fixed_parameter_keys
+
     result.samples_to_posterior(likelihood=likelihood, priors=priors,
                                 conversion_function=conversion_function)
-    result.kwargs = sampler.kwargs
     if save:
         result.save_to_file()
         logger.info("Results saved to {}/".format(outdir))
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index c9f1190f..1f2ee876 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -29,6 +29,10 @@ class Sampler(object):
         or just the log-likelihood
     plot: bool, optional
         Switch to set whether or not you want to create traceplots
+    injection_parameters: dict, optional
+        A dictionary of the injection parameters
+    meta_data: dict, optional
+        A dictionary of extra meta data to store in the result
     **kwargs: dict
 
     Attributes
@@ -72,7 +76,8 @@ class Sampler(object):
 
     def __init__(
             self, likelihood, priors, outdir='outdir', label='label',
-            use_ratio=False, plot=False, skip_import_verification=False, **kwargs):
+            use_ratio=False, plot=False, skip_import_verification=False,
+            injection_parameters=None, meta_data=None, **kwargs):
         self.likelihood = likelihood
         if isinstance(priors, PriorDict):
             self.priors = priors
@@ -80,6 +85,8 @@ class Sampler(object):
             self.priors = PriorDict(priors)
         self.label = label
         self.outdir = outdir
+        self.injection_parameters = injection_parameters
+        self.meta_data = meta_data
         self.use_ratio = use_ratio
         if not skip_import_verification:
             self._verify_external_sampler()
@@ -186,19 +193,13 @@ class Sampler(object):
         bilby.core.result.Result: An initial template for the result
 
         """
-        result = Result()
-        result.sampler = self.__class__.__name__.lower()
-        result.search_parameter_keys = self.__search_parameter_keys
-        result.fixed_parameter_keys = self.__fixed_parameter_keys
-        result.parameter_labels = [
-            self.priors[k].latex_label for k in
-            self.__search_parameter_keys]
-        result.parameter_labels_with_unit = [
-            self.priors[k].latex_label_with_unit for k in
-            self.__search_parameter_keys]
-        result.label = self.label
-        result.outdir = self.outdir
-        result.kwargs = self.kwargs
+        result = Result(label=self.label, outdir=self.outdir,
+                        sampler=self.__class__.__name__.lower(),
+                        search_parameter_keys=self.__search_parameter_keys,
+                        fixed_parameter_keys=self.__fixed_parameter_keys,
+                        priors=self.priors, meta_data=self.meta_data,
+                        injection_parameters=self.injection_parameters,
+                        sampler_kwargs=self.kwargs)
         return result
 
     def _check_if_priors_can_be_sampled(self):
@@ -358,8 +359,9 @@ class Sampler(object):
             return
 
         try:
-            self.cached_result = read_in_result(self.outdir, self.label)
-        except ValueError:
+            self.cached_result = read_in_result(
+                outdir=self.outdir, label=self.label)
+        except IOError:
             self.cached_result = None
 
         if command_line_args.use_cached:
@@ -373,7 +375,7 @@ class Sampler(object):
                           'kwargs']
             use_cache = True
             for key in check_keys:
-                if self.cached_result.check_attribute_match_to_other_object(
+                if self.cached_result._check_attribute_match_to_other_object(
                         key, self) is False:
                     logger.debug("Cached value {} is unmatched".format(key))
                     use_cache = False
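
Because `read_in_result` now raises `IOError` (rather than `ValueError`) when
no file exists, code probing for a cached result should catch `IOError`, as in
the hunk above. For example:

    import bilby

    try:
        cached_result = bilby.core.result.read_in_result(
            outdir='outdir', label='label')
    except IOError:
        cached_result = None  # no previous run found; sample from scratch
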
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 1f512999..11197d1e 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -159,11 +159,13 @@ class Dynesty(NestedSampler):
 
         # self.result.sampler_output = out
         weights = np.exp(out['logwt'] - out['logz'][-1])
-        self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
-        self.result.nested_samples = DataFrame(
+        nested_samples = DataFrame(
             out.samples, columns=self.search_parameter_keys)
-        self.result.nested_samples['weights'] = weights
-        self.result.nested_samples['log_likelihood'] = out.logl
+        nested_samples['weights'] = weights
+        nested_samples['log_likelihood'] = out.logl
+
+        self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
+        self.result.nested_samples = nested_samples
         self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
             unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples,
             sorted_samples=self.result.samples)
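
The dynesty output is now packaged as a labelled DataFrame of nested samples
(with their weights and log-likelihoods) alongside the equally-weighted
posterior samples produced by `dynesty.utils.resample_equal`. A standalone
sketch of the same transformation, using made-up arrays in place of the real
sampler output:

    import numpy as np
    from pandas import DataFrame
    from dynesty.utils import resample_equal

    # Stand-ins for out.samples and out['logwt'] from a dynesty run
    samples = np.random.uniform(0, 1, (1000, 2))
    logwt = np.sort(np.random.uniform(-10, 0, 1000))
    logz = np.logaddexp.reduce(logwt)  # stand-in for out['logz'][-1]

    weights = np.exp(logwt - logz)
    nested_samples = DataFrame(samples, columns=['x', 'y'])
    nested_samples['weights'] = weights

    # Equally-weighted posterior samples, as stored in result.samples
    posterior_samples = resample_equal(samples, weights)
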
diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index bf6435ab..e1d70da5 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -105,16 +105,14 @@ class Pymc3(MCMCSampler):
         """
         Initialise results within Pymc3 subclass.
         """
-        result = Result()
-        result.sampler = self.__class__.__name__.lower()
-        result.search_parameter_keys = self.__search_parameter_keys
-        result.fixed_parameter_keys = self.__fixed_parameter_keys
-        result.parameter_labels = [
-            self.priors[k].latex_label for k in
-            self.__search_parameter_keys]
-        result.label = self.label
-        result.outdir = self.outdir
-        result.kwargs = self.kwargs
+
+        result = Result(label=self.label, outdir=self.outdir,
+                        sampler=self.__class__.__name__.lower(),
+                        search_parameter_keys=self.__search_parameter_keys,
+                        fixed_parameter_keys=self.__fixed_parameter_keys,
+                        priors=self.priors, meta_data=self.meta_data,
+                        injection_parameters=self.injection_parameters,
+                        sampler_kwargs=self.kwargs)
         return result
 
     @property
diff --git a/test/result_test.py b/test/result_test.py
index 8dabec04..d46ba82b 100644
--- a/test/result_test.py
+++ b/test/result_test.py
@@ -5,23 +5,33 @@ import unittest
 import numpy as np
 import pandas as pd
 import shutil
+import os
 
 
 class TestResult(unittest.TestCase):
 
     def setUp(self):
         bilby.utils.command_line_args.test = False
-        result = bilby.core.result.Result()
-        test_directory = 'test_directory'
-        result.outdir = test_directory
-        result.label = 'test'
+        priors = bilby.prior.PriorSet(dict(
+            x=bilby.prior.Uniform(0, 1, 'x', latex_label='$x$', unit='s'),
+            y=bilby.prior.Uniform(0, 1, 'y', latex_label='$y$', unit='m'),
+            c=1,
+            d=2))
+        result = bilby.core.result.Result(
+            label='label', outdir='outdir', sampler='nestle',
+            search_parameter_keys=['x', 'y'], fixed_parameter_keys=['c', 'd'],
+            priors=priors, sampler_kwargs=dict(test='test', func=lambda x: x),
+            injection_parameters=dict(x=0.5, y=0.5),
+            meta_data=dict(test='test'))
 
         N = 100
         posterior = pd.DataFrame(dict(x=np.random.normal(0, 1, N),
                                       y=np.random.normal(0, 1, N)))
-        result.search_parameter_keys = ['x', 'y']
-        result.parameter_labels_with_unit = ['x', 'y']
         result.posterior = posterior
+        result.log_evidence = 10
+        result.log_evidence_err = 11
+        result.log_bayes_factor = 12
+        result.log_noise_evidence = 13
         self.result = result
         pass
 
@@ -34,6 +44,158 @@ class TestResult(unittest.TestCase):
         del self.result
         pass
 
+    def test_result_file_name(self):
+        outdir = 'outdir'
+        label = 'label'
+        self.assertEqual(bilby.core.result.result_file_name(outdir, label),
+                         '{}/{}_result.h5'.format(outdir, label))
+
+    def test_fail_save_and_load(self):
+        with self.assertRaises(ValueError):
+            bilby.core.result.read_in_result()
+
+        with self.assertRaises(IOError):
+            bilby.core.result.read_in_result(filename='not/a/file')
+
+    def test_unset_priors(self):
+        result = bilby.core.result.Result(
+            label='label', outdir='outdir', sampler='nestle',
+            search_parameter_keys=['x', 'y'], fixed_parameter_keys=['c', 'd'],
+            priors=None, sampler_kwargs=dict(test='test'),
+            injection_parameters=dict(x=0.5, y=0.5),
+            meta_data=dict(test='test'))
+        with self.assertRaises(ValueError):
+            result.priors
+        self.assertEqual(result.parameter_labels, result.search_parameter_keys)
+        self.assertEqual(result.parameter_labels_with_unit, result.search_parameter_keys)
+
+    def test_unknown_priors_fail(self):
+        with self.assertRaises(ValueError):
+            bilby.core.result.Result(
+                label='label', outdir='outdir', sampler='nestle',
+                search_parameter_keys=['x', 'y'], fixed_parameter_keys=['c', 'd'],
+                priors=['a', 'b'], sampler_kwargs=dict(test='test'),
+                injection_parameters=dict(x=0.5, y=0.5),
+                meta_data=dict(test='test'))
+
+    def test_set_samples(self):
+        samples = [1, 2, 3]
+        self.result.samples = samples
+        self.assertEqual(samples, self.result.samples)
+
+    def test_set_nested_samples(self):
+        nested_samples = [1, 2, 3]
+        self.result.nested_samples = nested_samples
+        self.assertEqual(nested_samples, self.result.nested_samples)
+
+    def test_set_walkers(self):
+        walkers = [1, 2, 3]
+        self.result.walkers = walkers
+        self.assertEqual(walkers, self.result.walkers)
+
+    def test_set_nburn(self):
+        nburn = 1
+        self.result.nburn = nburn
+        self.assertEqual(nburn, self.result.nburn)
+
+    def test_unset_posterior(self):
+        self.result.posterior = None
+        with self.assertRaises(ValueError):
+            self.result.posterior
+
+    def test_save_and_load(self):
+        self.result.save_to_file()
+        loaded_result = bilby.core.result.read_in_result(
+            outdir=self.result.outdir, label=self.result.label)
+        self.assertTrue(
+            all(self.result.posterior == loaded_result.posterior))
+        self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys)
+        self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys)
+        self.assertEqual(self.result.meta_data, loaded_result.meta_data)
+        self.assertEqual(self.result.injection_parameters, loaded_result.injection_parameters)
+        self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
+        self.assertEqual(self.result.log_noise_evidence, loaded_result.log_noise_evidence)
+        self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
+        self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
+        self.assertEqual(self.result.priors['x'], loaded_result.priors['x'])
+        self.assertEqual(self.result.priors['y'], loaded_result.priors['y'])
+        self.assertEqual(self.result.priors['c'], loaded_result.priors['c'])
+        self.assertEqual(self.result.priors['d'], loaded_result.priors['d'])
+
+    def test_save_and_dont_overwrite(self):
+        shutil.rmtree(
+            '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
+            ignore_errors=True)
+        self.result.save_to_file(overwrite=False)
+        self.result.save_to_file(overwrite=False)
+        self.assertTrue(os.path.isfile(
+            '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
+
+    def test_save_and_overwrite(self):
+        shutil.rmtree(
+            '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
+            ignore_errors=True)
+        self.result.save_to_file(overwrite=True)
+        self.result.save_to_file(overwrite=True)
+        self.assertFalse(os.path.isfile(
+            '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
+
+    def test_save_samples(self):
+        self.result.save_posterior_samples()
+        filename = '{}/{}_posterior_samples.txt'.format(self.result.outdir, self.result.label)
+        self.assertTrue(os.path.isfile(filename))
+        df = pd.read_csv(filename)
+        self.assertTrue(all(self.result.posterior == df))
+
+    def test_samples_to_posterior(self):
+        self.result.posterior = None
+        x = [1, 2, 3]
+        y = [4, 6, 8]
+        log_likelihood = [6, 7, 8]
+        self.result.samples = np.array([x, y]).T
+        self.result.log_likelihood_evaluations = log_likelihood
+        self.result.samples_to_posterior(priors=self.result.priors)
+        self.assertTrue(all(self.result.posterior['x'] == x))
+        self.assertTrue(all(self.result.posterior['y'] == y))
+        self.assertTrue(
+            all(self.result.posterior['log_likelihood'] == log_likelihood))
+        self.assertTrue(
+            all(self.result.posterior['c'] == self.result.priors['c'].peak))
+        self.assertTrue(
+            all(self.result.posterior['d'] == self.result.priors['d'].peak))
+
+    def test_calculate_prior_values(self):
+        self.result.calculate_prior_values(priors=self.result.priors)
+        self.assertEqual(len(self.result.posterior), len(self.result.prior_values))
+
+    def test_plot_multiple(self):
+        filename = 'multiple.png'
+        bilby.core.result.plot_multiple([self.result, self.result],
+                                        filename=filename)
+        self.assertTrue(os.path.isfile(filename))
+        os.remove(filename)
+
+    def test_plot_walkers(self):
+        self.result.walkers = np.random.uniform(0, 1, (10, 11, 2))
+        self.result.nburn = 5
+        self.result.plot_walkers()
+        self.assertTrue(
+            os.path.isfile('{}/{}_walkers.png'.format(
+                self.result.outdir, self.result.label)))
+
+    def test_plot_with_data(self):
+        x = np.linspace(0, 1, 10)
+        y = np.linspace(0, 1, 10)
+
+        def model(x):
+            return x
+        self.result.plot_with_data(model, x, y, ndraws=10)
+        self.assertTrue(
+            os.path.isfile('{}/{}_plot_with_data.png'.format(
+                self.result.outdir, self.result.label)))
+        self.result.posterior['log_likelihood'] = np.random.uniform(0, 1, len(self.result.posterior))
+        self.result.plot_with_data(model, x, y, ndraws=10, xlabel='a', ylabel='y')
+
     def test_plot_corner(self):
         self.result.injection_parameters = dict(x=0.8, y=1.1)
         self.result.plot_corner()
@@ -68,6 +230,5 @@ class TestResult(unittest.TestCase):
         with self.assertRaises(ValueError):
             self.result.plot_corner(priors='test')
 
-
 if __name__ == '__main__':
     unittest.main()
diff --git a/test/sampler_test.py b/test/sampler_test.py
index d4cac1ae..f3ba3ffb 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -1,7 +1,6 @@
 from __future__ import absolute_import
 import bilby
 from bilby.core import prior
-from bilby.core.result import Result
 import unittest
 from mock import MagicMock
 import numpy as np
@@ -54,16 +53,6 @@ class TestSampler(unittest.TestCase):
     def test_label(self):
         self.assertEqual(self.sampler.label, 'label')
 
-    def test_result(self):
-        expected_result = Result()
-        expected_result.search_parameter_keys = ['c']
-        expected_result.fixed_parameter_keys = ['a']
-        expected_result.parameter_labels = [None]
-        expected_result.label = 'label'
-        expected_result.outdir = 'test_directory'
-        expected_result.kwargs = {}
-        self.assertDictEqual(self.sampler.result.__dict__, expected_result.__dict__)
-
     def test_prior_transform_transforms_search_parameter_keys(self):
         self.sampler.prior_transform([0])
         expected_prior = prior.Uniform(0, 1)
-- 
GitLab