From da685c533ffb83df2891c206758ece989f380eca Mon Sep 17 00:00:00 2001 From: Sylvia Biscoveanu <sylvia.biscoveanu@ligo.org> Date: Mon, 25 Feb 2019 17:33:48 -0600 Subject: [PATCH] Jsonify results --- bilby/core/result.py | 86 ++++++++++++++++++++++++++++++---- bilby/core/sampler/__init__.py | 6 ++- test/result_test.py | 58 ++++++++++++++++++++--- 3 files changed, 133 insertions(+), 17 deletions(-) diff --git a/bilby/core/result.py b/bilby/core/result.py index 5c95f1e83..3813e664e 100644 --- a/bilby/core/result.py +++ b/bilby/core/result.py @@ -8,6 +8,7 @@ import numpy as np import deepdish import pandas as pd import corner +import json import scipy.stats import matplotlib import matplotlib.pyplot as plt @@ -19,7 +20,7 @@ from .utils import (logger, infer_parameters_from_function, from .prior import Prior, PriorDict, DeltaFunction -def result_file_name(outdir, label): +def result_file_name(outdir, label, extension='json'): """ Returns the standard filename used for a result file Parameters @@ -28,17 +29,27 @@ def result_file_name(outdir, label): Name of the output directory label: str Naming scheme of the output file + extension: str, optional + Whether to save as `hdf5` or `json` Returns ------- str: File name of the output file """ - return '{}/{}_result.h5'.format(outdir, label) + if extension == 'hdf5': + return '{}/{}_result.h5'.format(outdir, label) + else: + return '{}/{}_result.json'.format(outdir, label) def read_in_result(filename=None, outdir=None, label=None): - """ Wrapper to bilby.core.result.Result.from_hdf5 """ - return Result.from_hdf5(filename=filename, outdir=outdir, label=label) + """ Wrapper to bilby.core.result.Result.from_hdf5 + or bilby.core.result.Result.from_json """ + try: + result = Result.from_json(filename=filename, outdir=outdir, label=label) + except (IOError, ValueError): + result = Result.from_hdf5(filename=filename, outdir=outdir, label=label) + return result class Result(object): @@ -155,7 +166,7 @@ class Result(object): if (outdir is None) and (label is None): raise ValueError("No information given to load file") else: - filename = result_file_name(outdir, label) + filename = result_file_name(outdir, label, extension='hdf5') if os.path.isfile(filename): dictionary = deepdish.io.load(filename) # Some versions of deepdish/pytables return the dictionanary as @@ -169,6 +180,50 @@ class Result(object): else: raise IOError("No result '{}' found".format(filename)) + @classmethod + def from_json(cls, filename=None, outdir=None, label=None): + """ Read in a saved .json data file + + Parameters + ---------- + filename: str + If given, try to load from this filename + outdir, label: str + If given, use the default naming convention for saved results file + + Returns + ------- + result: bilby.core.result.Result + + Raises + ------- + ValueError: If no filename is given and either outdir or label is None + If no bilby.core.result.Result is found in the path + + """ + if filename is None: + if (outdir is None) and (label is None): + raise ValueError("No information given to load file") + else: + filename = result_file_name(outdir, label) + if os.path.isfile(filename): + dictionary = json.load(open(filename, 'r')) + for key in dictionary.keys(): + # Convert some dictionaries back to DataFrames + if key in ['posterior', 'nested_samples']: + dictionary[key] = pd.DataFrame.from_dict(dictionary[key]) + # Convert the loaded priors to bilby prior type + if key == 'priors': + for param in dictionary[key].keys(): + dictionary[key][param] = str(dictionary[key][param]) + dictionary[key] = PriorDict(dictionary[key]) + try: + return cls(**dictionary) + except TypeError as e: + raise IOError("Unable to load dictionary, error={}".format(e)) + else: + raise IOError("No result '{}' found".format(filename)) + def __str__(self): """Print a summary """ if getattr(self, 'posterior', None) is not None: @@ -303,9 +358,9 @@ class Result(object): pass return dictionary - def save_to_file(self, overwrite=False, outdir=None): + def save_to_file(self, overwrite=False, outdir=None, extension='json'): """ - Writes the Result to a deepdish h5 file + Writes the Result to a json or deepdish h5 file Parameters ---------- @@ -314,9 +369,11 @@ class Result(object): default=False outdir: str, optional Path to the outdir. Default is the one stored in the result object. + extension: str, optional + Whether to save as hdf5 instead of json """ outdir = self._safe_outdir_creation(outdir, self.save_to_file) - file_name = result_file_name(outdir, self.label) + file_name = result_file_name(outdir, self.label, extension) if os.path.isfile(file_name): if overwrite: @@ -341,8 +398,19 @@ class Result(object): if hasattr(dictionary['sampler_kwargs'][key], '__call__'): dictionary['sampler_kwargs'][key] = str(dictionary['sampler_kwargs']) + # Convert to json saveable format + if extension != 'hdf5': + for key in dictionary.keys(): + if isinstance(dictionary[key], pd.core.frame.DataFrame): + dictionary[key] = dictionary[key].to_dict() + elif isinstance(dictionary[key], np.ndarray): + dictionary[key] = dictionary[key].tolist() + try: - deepdish.io.save(file_name, dictionary) + if extension == 'hdf5': + deepdish.io.save(file_name, dictionary) + else: + json.dump(dictionary, open(file_name, 'w'), indent=2) except Exception as e: logger.error("\n\n Saving the data has failed with the " "following message:\n {} \n\n".format(e)) diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 7a15b410c..147d95c5d 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -85,6 +85,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', overwritten. save: bool If true, save the priors and results to disk. + If hdf5, save as an hdf5 file instead of json. result_class: bilby.core.result.Result, or child of The result class to use. By default, `bilby.core.result.Result` is used, but objects which inherit from this class can be given providing @@ -183,7 +184,10 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', result.samples_to_posterior(likelihood=likelihood, priors=priors, conversion_function=conversion_function) - if save: + if save == 'hdf5': + result.save_to_file(extension='hdf5') + logger.info("Results saved to {}/".format(outdir)) + elif save: result.save_to_file() logger.info("Results saved to {}/".format(outdir)) if plot: diff --git a/test/result_test.py b/test/result_test.py index 6dbc5a234..2aea03571 100644 --- a/test/result_test.py +++ b/test/result_test.py @@ -45,10 +45,16 @@ class TestResult(unittest.TestCase): del self.result pass - def test_result_file_name(self): + def test_result_file_name_default(self): outdir = 'outdir' label = 'label' self.assertEqual(bilby.core.result.result_file_name(outdir, label), + '{}/{}_result.json'.format(outdir, label)) + + def test_result_file_name_hdf5(self): + outdir = 'outdir' + label = 'label' + self.assertEqual(bilby.core.result.result_file_name(outdir, label, extension='hdf5'), '{}/{}_result.h5'.format(outdir, label)) def test_fail_save_and_load(self): @@ -104,8 +110,8 @@ class TestResult(unittest.TestCase): with self.assertRaises(ValueError): _ = self.result.posterior - def test_save_and_load(self): - self.result.save_to_file() + def test_save_and_load_hdf5(self): + self.result.save_to_file(extension='hdf5') loaded_result = bilby.core.result.read_in_result( outdir=self.result.outdir, label=self.result.label) self.assertTrue(pd.DataFrame.equals @@ -123,23 +129,61 @@ class TestResult(unittest.TestCase): self.assertEqual(self.result.priors['c'], loaded_result.priors['c']) self.assertEqual(self.result.priors['d'], loaded_result.priors['d']) - def test_save_and_dont_overwrite(self): + def test_save_and_load_default(self): + self.result.save_to_file() + loaded_result = bilby.core.result.read_in_result( + outdir=self.result.outdir, label=self.result.label) + self.assertTrue(np.array_equal + (self.result.posterior.sort_values(by=['x']), + loaded_result.posterior.sort_values(by=['x']))) + self.assertTrue(self.result.fixed_parameter_keys == loaded_result.fixed_parameter_keys) + self.assertTrue(self.result.search_parameter_keys == loaded_result.search_parameter_keys) + self.assertEqual(self.result.meta_data, loaded_result.meta_data) + self.assertEqual(self.result.injection_parameters, loaded_result.injection_parameters) + self.assertEqual(self.result.log_evidence, loaded_result.log_evidence) + self.assertEqual(self.result.log_noise_evidence, loaded_result.log_noise_evidence) + self.assertEqual(self.result.log_evidence_err, loaded_result.log_evidence_err) + self.assertEqual(self.result.log_bayes_factor, loaded_result.log_bayes_factor) + self.assertEqual(self.result.priors['x'], loaded_result.priors['x']) + self.assertEqual(self.result.priors['y'], loaded_result.priors['y']) + self.assertEqual(self.result.priors['c'], loaded_result.priors['c']) + self.assertEqual(self.result.priors['d'], loaded_result.priors['d']) + + def test_save_and_dont_overwrite_default(self): shutil.rmtree( - '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label), + '{}/{}_result.json.old'.format(self.result.outdir, self.result.label), ignore_errors=True) self.result.save_to_file(overwrite=False) self.result.save_to_file(overwrite=False) + self.assertTrue(os.path.isfile( + '{}/{}_result.json.old'.format(self.result.outdir, self.result.label))) + + def test_save_and_dont_overwrite_hdf5(self): + shutil.rmtree( + '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label), + ignore_errors=True) + self.result.save_to_file(overwrite=False, extension='hdf5') + self.result.save_to_file(overwrite=False, extension='hdf5') self.assertTrue(os.path.isfile( '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label))) - def test_save_and_overwrite(self): + def test_save_and_overwrite_hdf5(self): shutil.rmtree( '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label), ignore_errors=True) + self.result.save_to_file(overwrite=True, extension='hdf5') + self.result.save_to_file(overwrite=True, extension='hdf5') + self.assertFalse(os.path.isfile( + '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label))) + + def test_save_and_overwrite_default(self): + shutil.rmtree( + '{}/{}_result.json.old'.format(self.result.outdir, self.result.label), + ignore_errors=True) self.result.save_to_file(overwrite=True) self.result.save_to_file(overwrite=True) self.assertFalse(os.path.isfile( - '{}/{}_result.h5.old'.format(self.result.outdir, self.result.label))) + '{}/{}_result.json.old'.format(self.result.outdir, self.result.label))) def test_save_samples(self): self.result.save_posterior_samples() -- GitLab