Skip to content
Snippets Groups Projects
Commit 4b38827b authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'add-cli-for-merging' into 'master'

Adds option to merge runs from the command line

See merge request !467
parents 03219f50 b3ec4358
No related branches found
No related tags found
1 merge request!467Adds option to merge runs from the command line
Pipeline #60011 passed
...@@ -52,7 +52,7 @@ class Likelihood(object): ...@@ -52,7 +52,7 @@ class Likelihood(object):
@property @property
def meta_data(self): def meta_data(self):
return self._meta_data return getattr(self, '_meta_data', None)
@meta_data.setter @meta_data.setter
def meta_data(self, meta_data): def meta_data(self, meta_data):
......
...@@ -408,12 +408,15 @@ class Result(object): ...@@ -408,12 +408,15 @@ class Result(object):
pass pass
return dictionary return dictionary
def save_to_file(self, overwrite=False, outdir=None, extension='json', gzip=False): def save_to_file(self, filename=None, overwrite=False, outdir=None,
extension='json', gzip=False):
""" """
Writes the Result to a json or deepdish h5 file Writes the Result to a json or deepdish h5 file
Parameters Parameters
---------- ----------
filename: optional,
Filename to write to (overwrites the default)
overwrite: bool, optional overwrite: bool, optional
Whether or not to overwrite an existing result file. Whether or not to overwrite an existing result file.
default=False default=False
...@@ -431,19 +434,20 @@ class Result(object): ...@@ -431,19 +434,20 @@ class Result(object):
extension = "json" extension = "json"
outdir = self._safe_outdir_creation(outdir, self.save_to_file) outdir = self._safe_outdir_creation(outdir, self.save_to_file)
file_name = result_file_name(outdir, self.label, extension, gzip) if filename is None:
filename = result_file_name(outdir, self.label, extension, gzip)
if os.path.isfile(file_name): if os.path.isfile(filename):
if overwrite: if overwrite:
logger.debug('Removing existing file {}'.format(file_name)) logger.debug('Removing existing file {}'.format(filename))
os.remove(file_name) os.remove(filename)
else: else:
logger.debug( logger.debug(
'Renaming existing file {} to {}.old'.format(file_name, 'Renaming existing file {} to {}.old'.format(filename,
file_name)) filename))
os.rename(file_name, file_name + '.old') os.rename(filename, filename + '.old')
logger.debug("Saving result to {}".format(file_name)) logger.debug("Saving result to {}".format(filename))
# Convert the prior to a string representation for saving on disk # Convert the prior to a string representation for saving on disk
dictionary = self._get_save_data_dictionary() dictionary = self._get_save_data_dictionary()
...@@ -462,17 +466,17 @@ class Result(object): ...@@ -462,17 +466,17 @@ class Result(object):
import gzip import gzip
# encode to a string # encode to a string
json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode('utf-8') json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode('utf-8')
with gzip.GzipFile(file_name, 'w') as file: with gzip.GzipFile(filename, 'w') as file:
file.write(json_str) file.write(json_str)
else: else:
with open(file_name, 'w') as file: with open(filename, 'w') as file:
json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder) json.dump(dictionary, file, indent=2, cls=BilbyJsonEncoder)
elif extension == 'hdf5': elif extension == 'hdf5':
import deepdish import deepdish
for key in dictionary: for key in dictionary:
if isinstance(dictionary[key], pd.DataFrame): if isinstance(dictionary[key], pd.DataFrame):
dictionary[key] = dictionary[key].to_dict() dictionary[key] = dictionary[key].to_dict()
deepdish.io.save(file_name, dictionary) deepdish.io.save(filename, dictionary)
else: else:
raise ValueError("Extension type {} not understood".format(extension)) raise ValueError("Extension type {} not understood".format(extension))
except Exception as e: except Exception as e:
...@@ -1293,7 +1297,7 @@ class ResultList(list): ...@@ -1293,7 +1297,7 @@ class ResultList(list):
result = copy(self[0]) result = copy(self[0])
if result.label is not None: if result.label is not None:
result.label += 'combined' result.label += '_combined'
self._check_consistent_sampler() self._check_consistent_sampler()
self._check_consistent_data() self._check_consistent_data()
......
...@@ -29,14 +29,17 @@ import bilby ...@@ -29,14 +29,17 @@ import bilby
def setup_command_line_args(): def setup_command_line_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Helper tool for bilby result files", description="Helper tool for bilby result files")
epilog=print(__doc__))
parser.add_argument("-r", "--results", nargs='+', required=True, parser.add_argument("-r", "--results", nargs='+', required=True,
help="List of results files.") help="List of results files.")
parser.add_argument("-c", "--convert", type=str, choices=['json', 'hdf5'], parser.add_argument("-c", "--convert", type=str, choices=['json', 'hdf5'],
help="Convert all results.", default=False) help="Convert all results.", default=False)
parser.add_argument("-m", "--merge", action='store_true',
help="Merge the set of runs, output saved using the outdir and label")
parser.add_argument("-o", "--outdir", type=str, default=None, parser.add_argument("-o", "--outdir", type=str, default=None,
help="Output directory.") help="Output directory.")
parser.add_argument("-l", "--label", type=str, default=None,
help="New label for output result object")
parser.add_argument("-b", "--bayes", action='store_true', parser.add_argument("-b", "--bayes", action='store_true',
help="Print all Bayes factors.") help="Print all Bayes factors.")
parser.add_argument("-p", "--print", nargs='+', default=None, parser.add_argument("-p", "--print", nargs='+', default=None,
...@@ -55,7 +58,7 @@ def read_in_results(filename_list): ...@@ -55,7 +58,7 @@ def read_in_results(filename_list):
results_list = [] results_list = []
for filename in filename_list: for filename in filename_list:
results_list.append(bilby.core.result.read_in_result(filename=filename)) results_list.append(bilby.core.result.read_in_result(filename=filename))
return results_list return bilby.core.result.ResultList(results_list)
def print_bayes_factors(results_list): def print_bayes_factors(results_list):
...@@ -97,3 +100,10 @@ def main(): ...@@ -97,3 +100,10 @@ def main():
print_bayes_factors(results_list) print_bayes_factors(results_list)
if args.ipython: if args.ipython:
drop_to_ipython(results_list) drop_to_ipython(results_list)
if args.merge:
result = results_list.combine()
if args.label is not None:
result.label = args.label
if args.outdir is not None:
result.outdir = args.outdir
result.save_to_file()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment