From da685c533ffb83df2891c206758ece989f380eca Mon Sep 17 00:00:00 2001
From: Sylvia Biscoveanu <sylvia.biscoveanu@ligo.org>
Date: Mon, 25 Feb 2019 17:33:48 -0600
Subject: [PATCH] Jsonify results

---
 bilby/core/result.py           | 86 ++++++++++++++++++++++++++++++----
 bilby/core/sampler/__init__.py |  6 ++-
 test/result_test.py            | 58 ++++++++++++++++++++---
 3 files changed, 133 insertions(+), 17 deletions(-)

diff --git a/bilby/core/result.py b/bilby/core/result.py
index 5c95f1e83..3813e664e 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -8,6 +8,7 @@ import numpy as np
 import deepdish
 import pandas as pd
 import corner
+import json
 import scipy.stats
 import matplotlib
 import matplotlib.pyplot as plt
@@ -19,7 +20,7 @@ from .utils import (logger, infer_parameters_from_function,
 from .prior import Prior, PriorDict, DeltaFunction
 
 
-def result_file_name(outdir, label):
+def result_file_name(outdir, label, extension='json'):
     """ Returns the standard filename used for a result file
 
     Parameters
@@ -28,17 +29,27 @@ def result_file_name(outdir, label):
         Name of the output directory
     label: str
         Naming scheme of the output file
+    extension: str, optional
+        Whether to save as `hdf5` or `json`
 
     Returns
     -------
     str: File name of the output file
     """
-    return '{}/{}_result.h5'.format(outdir, label)
+    if extension == 'hdf5':
+        return '{}/{}_result.h5'.format(outdir, label)
+    else:
+        return '{}/{}_result.json'.format(outdir, label)
 
 
 def read_in_result(filename=None, outdir=None, label=None):
-    """ Wrapper to bilby.core.result.Result.from_hdf5 """
-    return Result.from_hdf5(filename=filename, outdir=outdir, label=label)
+    """ Wrapper that tries bilby.core.result.Result.from_json first and
+        falls back to bilby.core.result.Result.from_hdf5 if that fails """
+    try:
+        result = Result.from_json(filename=filename, outdir=outdir, label=label)
+    except (IOError, ValueError):
+        result = Result.from_hdf5(filename=filename, outdir=outdir, label=label)
+    return result
 
 
 class Result(object):
@@ -155,7 +166,7 @@ class Result(object):
             if (outdir is None) and (label is None):
                 raise ValueError("No information given to load file")
             else:
-                filename = result_file_name(outdir, label)
+                filename = result_file_name(outdir, label, extension='hdf5')
         if os.path.isfile(filename):
             dictionary = deepdish.io.load(filename)
             # Some versions of deepdish/pytables return the dictionanary as
@@ -169,6 +180,50 @@ class Result(object):
         else:
             raise IOError("No result '{}' found".format(filename))
 
+    @classmethod
+    def from_json(cls, filename=None, outdir=None, label=None):
+        """ Read in a saved .json data file
+
+        Parameters
+        ----------
+        filename: str
+            If given, try to load from this filename
+        outdir, label: str
+            If given, use the default naming convention for saved results file
+
+        Returns
+        -------
+        result: bilby.core.result.Result
+
+        Raises
+        -------
+        ValueError: If no filename is given and both outdir and label are None
+        IOError: If the file does not exist or its contents cannot be loaded
+
+        """
+        if filename is None:
+            if (outdir is None) and (label is None):
+                raise ValueError("No information given to load file")
+            else:
+                filename = result_file_name(outdir, label)
+        if os.path.isfile(filename):
+            dictionary = json.load(open(filename, 'r'))
+            for key in dictionary.keys():
+                # Convert some dictionaries back to DataFrames
+                if key in ['posterior', 'nested_samples']:
+                    dictionary[key] = pd.DataFrame.from_dict(dictionary[key])
+                # Convert the loaded priors to bilby prior type
+                if key == 'priors':
+                    for param in dictionary[key].keys():
+                        dictionary[key][param] = str(dictionary[key][param])
+                    dictionary[key] = PriorDict(dictionary[key])
+            try:
+                return cls(**dictionary)
+            except TypeError as e:
+                raise IOError("Unable to load dictionary, error={}".format(e))
+        else:
+            raise IOError("No result '{}' found".format(filename))
+
     def __str__(self):
         """Print a summary """
         if getattr(self, 'posterior', None) is not None:
@@ -303,9 +358,9 @@ class Result(object):
                 pass
         return dictionary
 
-    def save_to_file(self, overwrite=False, outdir=None):
+    def save_to_file(self, overwrite=False, outdir=None, extension='json'):
         """
-        Writes the Result to a deepdish h5 file
+        Writes the Result to a json or deepdish h5 file
 
         Parameters
         ----------
@@ -314,9 +369,11 @@ class Result(object):
             default=False
         outdir: str, optional
             Path to the outdir. Default is the one stored in the result object.
+        extension: str, optional
+            'hdf5' saves a deepdish h5 file; any other value saves json
         """
         outdir = self._safe_outdir_creation(outdir, self.save_to_file)
-        file_name = result_file_name(outdir, self.label)
+        file_name = result_file_name(outdir, self.label, extension)
 
         if os.path.isfile(file_name):
             if overwrite:
@@ -341,8 +398,19 @@ class Result(object):
                 if hasattr(dictionary['sampler_kwargs'][key], '__call__'):
                     dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs'])
 
+        # Convert to json saveable format
+        if extension != 'hdf5':
+            for key in dictionary.keys():
+                if isinstance(dictionary[key], pd.core.frame.DataFrame):
+                    dictionary[key] = dictionary[key].to_dict()
+                elif isinstance(dictionary[key], np.ndarray):
+                    dictionary[key] = dictionary[key].tolist()
+
         try:
-            deepdish.io.save(file_name, dictionary)
+            if extension == 'hdf5':
+                deepdish.io.save(file_name, dictionary)
+            else:
+                json.dump(dictionary, open(file_name, 'w'), indent=2)
         except Exception as e:
             logger.error("\n\n Saving the data has failed with the "
                          "following message:\n {} \n\n".format(e))
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index 7a15b410c..147d95c5d 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -85,6 +85,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         overwritten.
     save: bool
         If true, save the priors and results to disk.
+        If the string 'hdf5', save as an hdf5 file instead of json.
     result_class: bilby.core.result.Result, or child of
         The result class to use. By default, `bilby.core.result.Result` is used,
         but objects which inherit from this class can be given providing
@@ -183,7 +184,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
 
     result.samples_to_posterior(likelihood=likelihood, priors=priors,
                                 conversion_function=conversion_function)
-    if save:
+    if save == 'hdf5':
+        result.save_to_file(extension='hdf5')
+        logger.info("Results saved to {}/".format(outdir))
+    elif save:
         result.save_to_file()
         logger.info("Results saved to {}/".format(outdir))
     if plot:
diff --git a/test/result_test.py b/test/result_test.py
index 6dbc5a234..2aea03571 100644
--- a/test/result_test.py
+++ b/test/result_test.py
@@ -45,10 +45,16 @@ class TestResult(unittest.TestCase):
         del self.result
         pass
 
-    def test_result_file_name(self):
+    def test_result_file_name_default(self):
         outdir = 'outdir'
         label = 'label'
         self.assertEqual(bilby.core.result.result_file_name(outdir, label),
+                         '{}/{}_result.json'.format(outdir, label))
+
+    def test_result_file_name_hdf5(self):
+        outdir = 'outdir'
+        label = 'label'
+        self.assertEqual(bilby.core.result.result_file_name(outdir, label, extension='hdf5'),
                          '{}/{}_result.h5'.format(outdir, label))
 
     def test_fail_save_and_load(self):
@@ -104,8 +110,8 @@ class TestResult(unittest.TestCase):
         with self.assertRaises(ValueError):
             _ = self.result.posterior
 
-    def test_save_and_load(self):
-        self.result.save_to_file()
+    def test_save_and_load_hdf5(self):
+        self.result.save_to_file(extension='hdf5')
         loaded_result = bilby.core.result.read_in_result(
             outdir=self.result.outdir, label=self.result.label)
         self.assertTrue(pd.DataFrame.equals
@@ -123,23 +129,61 @@ class TestResult(unittest.TestCase):
         self.assertEqual(self.result.priors['c'], loaded_result.priors['c'])
         self.assertEqual(self.result.priors['d'], loaded_result.priors['d'])
 
-    def test_save_and_dont_overwrite(self):
+    def test_save_and_load_default(self):
+        self.result.save_to_file()
+        loaded_result = bilby.core.result.read_in_result(
+            outdir=self.result.outdir, label=self.result.label)
+        self.assertTrue(np.array_equal
+                        (self.result.posterior.sort_values(by=['x']),
+                            loaded_result.posterior.sort_values(by=['x'])))
+        self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys)
+        self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys)
+        self.assertEqual(self.result.meta_data, loaded_result.meta_data)
+        self.assertEqual(self.result.injection_parameters, loaded_result.injection_parameters)
+        self.assertEqual(self.result.log_evidence, loaded_result.log_evidence)
+        self.assertEqual(self.result.log_noise_evidence, loaded_result.log_noise_evidence)
+        self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err)
+        self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor)
+        self.assertEqual(self.result.priors['x'], loaded_result.priors['x'])
+        self.assertEqual(self.result.priors['y'], loaded_result.priors['y'])
+        self.assertEqual(self.result.priors['c'], loaded_result.priors['c'])
+        self.assertEqual(self.result.priors['d'], loaded_result.priors['d'])
+
+    def test_save_and_dont_overwrite_default(self):
         shutil.rmtree(
-            '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
+            '{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
             ignore_errors=True)
         self.result.save_to_file(overwrite=False)
         self.result.save_to_file(overwrite=False)
+        self.assertTrue(os.path.isfile(
+            '{}/{}_result.json.old'.format(self.result.outdir, self.result.label)))
+
+    def test_save_and_dont_overwrite_hdf5(self):
+        shutil.rmtree(
+            '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
+            ignore_errors=True)
+        self.result.save_to_file(overwrite=False, extension='hdf5')
+        self.result.save_to_file(overwrite=False, extension='hdf5')
         self.assertTrue(os.path.isfile(
             '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
 
-    def test_save_and_overwrite(self):
+    def test_save_and_overwrite_hdf5(self):
         shutil.rmtree(
             '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label),
             ignore_errors=True)
+        self.result.save_to_file(overwrite=True, extension='hdf5')
+        self.result.save_to_file(overwrite=True, extension='hdf5')
+        self.assertFalse(os.path.isfile(
+            '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
+
+    def test_save_and_overwrite_default(self):
+        shutil.rmtree(
+            '{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
+            ignore_errors=True)
         self.result.save_to_file(overwrite=True)
         self.result.save_to_file(overwrite=True)
         self.assertFalse(os.path.isfile(
-            '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label)))
+            '{}/{}_result.json.old'.format(self.result.outdir, self.result.label)))
 
     def test_save_samples(self):
         self.result.save_posterior_samples()
-- 
GitLab