Commit d9150de8 authored by Leo Pound Singer's avatar Leo Pound Singer

bayestar: add argparse GlobAction class for file globs

Original: 24b99a47975bf3d8811fc942f84ced7dd7d6533c
parent e29dd910
......@@ -62,7 +62,7 @@ if __name__ == '__main__':
'db', type=command.SQLiteType('r'), metavar='DB.sqlite',
help='Input SQLite database from search pipeline')
parser.add_argument(
'fitsfileglobs', metavar='GLOB.fits[.gz]', nargs='+',
'fitsfilenames', metavar='GLOB.fits[.gz]', nargs='+', action='glob',
help='Input FITS filenames and/or globs')
opts = parser.parse_args()
......@@ -146,9 +146,6 @@ if __name__ == '__main__':
(command.sqlite_get_filename(db), contours, modes, areas)
).imap_unordered
progress.update(-1, 'obtaining filenames of sky maps')
fitsfilenames = tuple(command.chainglob(opts.fitsfileglobs))
colnames = (
['coinc_event_id', 'simulation_id', 'far', 'snr', 'searched_area',
'searched_prob', 'offset', 'runtime', 'distmean', 'diststd',
......@@ -160,8 +157,8 @@ if __name__ == '__main__':
print(*colnames, sep="\t", file=opts.output)
count_records = 0
progress.max = len(fitsfilenames)
for record in map(process, fitsfilenames):
progress.max = len(opts.fitsfilenames)
for record in map(process, opts.fitsfilenames):
count_records += 1
progress.update(count_records, record[0])
print(*record, sep="\t", file=opts.output)
......@@ -31,7 +31,7 @@ parser.add_argument('--contour', metavar='PERCENT', type=float, default=90,
parser.add_argument('--alpha', metavar='ALPHA', type=float, default=0.1,
help='alpha blending for each sky map [default: %(default)s]')
parser.add_argument(
'fitsfileglobs', metavar='GLOB.fits[.gz]', nargs='+',
'fitsfilenames', metavar='GLOB.fits[.gz]', nargs='+', action='glob',
help='Input FITS filenames and/or globs')
parser.set_defaults(colormap=None)
opts = parser.parse_args()
......@@ -54,19 +54,16 @@ ax.grid()
progress = ProgressBar()
progress.update(-1, 'obtaining filenames of sky maps')
fitsfilenames = tuple(command.chainglob(opts.fitsfileglobs))
progress.max = len(fitsfilenames)
progress.max = len(opts.fitsfilenames)
matplotlib.rc('path', simplify=True, simplify_threshold=1)
if opts.colormap is None:
colors = ['k'] * len(fitsfilenames)
colors = ['k'] * len(opts.fitsfilenames)
else:
colors = matplotlib.cm.get_cmap(opts.colormap)
colors = colors(np.linspace(0, 1, len(fitsfilenames)))
for count_records, (color, fitsfilename) in enumerate(zip(colors, fitsfilenames)):
colors = colors(np.linspace(0, 1, len(opts.fitsfilenames)))
for count_records, (color, fitsfilename) in enumerate(zip(colors, opts.fitsfilenames)):
progress.update(count_records, fitsfilename)
skymap, metadata = fits.read_sky_map(fitsfilename, nest=None)
nside = hp.npix2nside(len(skymap))
......
......@@ -24,6 +24,7 @@ __author__ = "Leo Singer <leo.singer@ligo.org>"
import argparse
import contextlib
import copy
from distutils.dir_util import mkpath
from distutils.errors import DistutilsFileError
import errno
......@@ -59,9 +60,15 @@ def TemporaryDirectory(suffix='', prefix='tmp', dir=None, delete=True):
shutil.rmtree(dir)
def chainglob(patterns):
"""Generate a list of all files matching a list of globs."""
return itertools.chain.from_iterable(glob.iglob(s) for s in patterns)
class GlobAction(argparse._AppendAction):
"""Generate a list of filenames from a list of filenames and globs."""
def __call__(self, parser, namespace, values, *args, **kwargs):
values = tuple(
itertools.chain.from_iterable(glob.iglob(s) for s in values))
if values:
super(GlobAction, self).__call__(
parser, namespace, values, *args, **kwargs)
waveform_parser = argparse.ArgumentParser(add_help=False)
......@@ -279,6 +286,7 @@ class ArgumentParser(argparse.ArgumentParser):
conflict_handler=conflict_handler,
add_help=add_help)
self.add_argument('--version', action=VersionAction)
self.register('action', 'glob', GlobAction)
class DirType(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment