Commit d57ea0c3 authored by Matthew David Pitkin, committed by Colm Talbot

Allow JSON results files to be gzipped

parent 8c03c68f
......@@ -20,7 +20,7 @@ from .utils import (logger, infer_parameters_from_function,
from .prior import Prior, PriorDict, DeltaFunction
def result_file_name(outdir, label, extension='json'):
def result_file_name(outdir, label, extension='json', gzip=False):
""" Returns the standard filename used for a result file
Parameters
......@@ -31,18 +31,23 @@ def result_file_name(outdir, label, extension='json'):
Naming scheme of the output file
extension: str, optional
Whether to save as `hdf5` or `json`
gzip: bool, optional
Set to True to append `.gz` to the extension for saving in gzipped format
Returns
-------
str: File name of the output file
"""
if extension in ['json', 'hdf5']:
return '{}/{}_result.{}'.format(outdir, label, extension)
if extension == 'json' and gzip:
return '{}/{}_result.{}.gz'.format(outdir, label, extension)
else:
return '{}/{}_result.{}'.format(outdir, label, extension)
else:
raise ValueError("Extension type {} not understood".format(extension))
def _determine_file_name(filename, outdir, label, extension):
def _determine_file_name(filename, outdir, label, extension, gzip):
""" Helper method to determine the filename """
if filename is not None:
return filename
......@@ -50,10 +55,10 @@ def _determine_file_name(filename, outdir, label, extension):
if (outdir is None) and (label is None):
raise ValueError("No information given to load file")
else:
return result_file_name(outdir, label, extension)
return result_file_name(outdir, label, extension, gzip)
def read_in_result(filename=None, outdir=None, label=None, extension='json'):
def read_in_result(filename=None, outdir=None, label=None, extension='json', gzip=False):
""" Reads in a stored bilby result object
Parameters
......@@ -65,10 +70,13 @@ def read_in_result(filename=None, outdir=None, label=None, extension='json'):
naming scheme.
"""
filename = _determine_file_name(filename, outdir, label, extension)
filename = _determine_file_name(filename, outdir, label, extension, gzip)
# Get the actual extension (may differ from the default extension if the filename is given)
extension = os.path.splitext(filename)[1].lstrip('.')
if extension == 'gz': # gzipped file
extension = os.path.splitext(os.path.splitext(filename)[0])[1].lstrip('.')
if 'json' in extension:
result = Result.from_json(filename=filename)
elif ('hdf5' in extension) or ('h5' in extension):
......@@ -91,7 +99,7 @@ class Result(object):
log_prior_evaluations=None, sampling_time=None, nburn=None,
walkers=None, max_autocorrelation_time=None,
parameter_labels=None, parameter_labels_with_unit=None,
version=None):
gzip=False, version=None):
""" A class to store the results of the sampling run
Parameters
......@@ -129,6 +137,8 @@ class Result(object):
The estimated maximum autocorrelation time for MCMC samplers
parameter_labels, parameter_labels_with_unit: list
Lists of the latex-formatted parameter labels
gzip: bool
Set to True to gzip the results file (if using json format)
version: str,
Version information for software used to generate the result. Note,
this information is generated when the result object is initialized
......@@ -191,7 +201,7 @@ class Result(object):
"""
import deepdish
filename = _determine_file_name(filename, outdir, label, 'hdf5')
filename = _determine_file_name(filename, outdir, label, 'hdf5', False)
if os.path.isfile(filename):
dictionary = deepdish.io.load(filename)
......@@ -209,7 +219,7 @@ class Result(object):
raise IOError("No result '{}' found".format(filename))
@classmethod
def from_json(cls, filename=None, outdir=None, label=None):
def from_json(cls, filename=None, outdir=None, label=None, gzip=False):
""" Read in a saved .json data file
Parameters
......@@ -229,11 +239,17 @@ class Result(object):
If no bilby.core.result.Result is found in the path
"""
filename = _determine_file_name(filename, outdir, label, 'json')
filename = _determine_file_name(filename, outdir, label, 'json', gzip)
if os.path.isfile(filename):
with open(filename, 'r') as file:
dictionary = json.load(file, object_hook=decode_bilby_json)
if gzip or os.path.splitext(filename)[1].lstrip('.') == 'gz':
import gzip
with gzip.GzipFile(filename, 'r') as file:
json_str = file.read().decode('utf-8')
dictionary = json.loads(json_str, object_hook=decode_bilby_json)
else:
with open(filename, 'r') as file:
dictionary = json.load(file, object_hook=decode_bilby_json)
for key in dictionary.keys():
# Convert the loaded priors to bilby prior type
if key == 'priors':
......@@ -381,7 +397,7 @@ class Result(object):
pass
return dictionary
def save_to_file(self, overwrite=False, outdir=None, extension='json'):
def save_to_file(self, overwrite=False, outdir=None, extension='json', gzip=False):
"""
Writes the Result to a json or deepdish h5 file
......@@ -394,9 +410,12 @@ class Result(object):
Path to the outdir. Default is the one stored in the result object.
extension: str, optional {json, hdf5}
Determines the method to use to store the data
gzip: bool, optional
If true, and outputting to a json file, this will gzip the resulting
file and add '.gz' to the file extension.
"""
outdir = self._safe_outdir_creation(outdir, self.save_to_file)
file_name = result_file_name(outdir, self.label, extension)
file_name = result_file_name(outdir, self.label, extension, gzip)
if os.path.isfile(file_name):
if overwrite:
......@@ -423,8 +442,15 @@ class Result(object):
try:
if extension == 'json':
with open(file_name, 'w') as file:
json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder)
if gzip:
import gzip
# encode to a string
json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode('utf-8')
with gzip.GzipFile(file_name, 'w') as file:
file.write(json_str)
else:
with open(file_name, 'w') as file:
json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder)
elif extension == 'hdf5':
import deepdish
for key in dictionary:
......
......@@ -44,8 +44,8 @@ if command_line_args.sampler_help:
def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
sampler='dynesty', use_ratio=None, injection_parameters=None,
conversion_function=None, plot=False, default_priors_file=None,
clean=None, meta_data=None, save=True, result_class=None,
**kwargs):
clean=None, meta_data=None, save=True, gzip=False,
result_class=None, **kwargs):
"""
The primary interface to easy parameter estimation
......@@ -88,6 +88,8 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
save: bool
If true, save the priors and results to disk.
If hdf5, save as an hdf5 file instead of json.
gzip: bool
If true, and save is true, gzip the saved results file (json output only; not applied when saving as hdf5).
result_class: bilby.core.result.Result, or child of
The result class to use. By default, `bilby.core.result.Result` is used,
but objects which inherit from this class can be given providing
......@@ -190,7 +192,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
result.save_to_file(extension='hdf5')
logger.info("Results saved to {}/".format(outdir))
elif save:
result.save_to_file()
result.save_to_file(gzip=gzip)
logger.info("Results saved to {}/".format(outdir))
if plot:
result.plot_corner()
......
......@@ -194,6 +194,26 @@ class TestResult(unittest.TestCase):
self.assertEqual(self.result.priors['c'], loaded_result.priors['c'])
self.assertEqual(self.result.priors['d'], loaded_result.priors['d'])
def test_save_and_load_gzip(self):
    """Save the result as gzipped json, read it back, and compare the fields."""
    self.result.save_to_file(gzip=True)
    reloaded = bilby.core.result.read_in_result(
        outdir=self.result.outdir, label=self.result.label, gzip=True)

    # Both posteriors are sorted by 'x' before being compared.
    expected_posterior = self.result.posterior.sort_values(by=['x'])
    actual_posterior = reloaded.posterior.sort_values(by=['x'])
    self.assertTrue(np.array_equal(expected_posterior, actual_posterior))

    # Parameter key collections are compared with plain `==`.
    for attr in ('fixed_parameter_keys', 'search_parameter_keys'):
        self.assertTrue(getattr(self.result, attr) == getattr(reloaded, attr))

    # Metadata and evidence values must survive the round trip unchanged.
    scalar_attrs = (
        'meta_data',
        'injection_parameters',
        'log_evidence',
        'log_noise_evidence',
        'log_evidence_err',
        'log_bayes_factor',
    )
    for attr in scalar_attrs:
        self.assertEqual(getattr(self.result, attr), getattr(reloaded, attr))

    # Each prior must be equal after reloading.
    for name in ('x', 'y', 'c', 'd'):
        self.assertEqual(self.result.priors[name], reloaded.priors[name])
def test_save_and_dont_overwrite_default(self):
shutil.rmtree(
'{}/{}_result.json.old'.format(self.result.outdir, self.result.label),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment