#
# Copyright (C) 2013-2017  Leo Singer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
"""
Functions that support the command line interface.
"""

Leo Pound Singer's avatar
Leo Pound Singer committed
22
from __future__ import print_function
23
import argparse
Leo Pound Singer's avatar
Leo Pound Singer committed
24 25
from distutils.dir_util import mkpath
from distutils.errors import DistutilsFileError
26
import glob
27
import inspect
28
import itertools
29
import logging
30 31
import os
import sys
32
import tempfile
33
import matplotlib
34
from matplotlib import cm
35
from ..plot import cmap
36
from ..util import sqlite
37

38

39 40 41
# Set no-op Matplotlib backend to defer importing anything that requires a GUI
# until we have determined that it is necessary based on the command line
# arguments.
if 'matplotlib.pyplot' in sys.modules:
    # pyplot is already loaded, so the backend has to be switched through it.
    from matplotlib import pyplot as plt
    plt.switch_backend('Template')
else:
    try:
        matplotlib.use('Template', warn=False, force=True)
    except TypeError:
        # Newer Matplotlib versions no longer accept the ``warn`` keyword.
        matplotlib.use('Template', force=True)
47

48

49 50 51 52 53 54
# FIXME: Remove this after all Matplotlib monkeypatches are obsolete.
# ``mpl_version`` is a LooseVersion, so it can be compared directly with
# version strings such as '1.4'.
import distutils.version
import matplotlib
mpl_version = distutils.version.LooseVersion(matplotlib.__version__)


55 56 57 58 59
def get_version():
    """Return the package name and version, separated by a single space.

    Both values come from ``InferenceVCSInfo`` in the parent package. It is
    imported here rather than at module level so that the lookup only happens
    when a version string is actually needed.
    """
    from .. import InferenceVCSInfo as vcs_info
    return ' '.join((vcs_info.name, vcs_info.version))


60
class GlobAction(argparse._StoreAction):
    """Generate a list of filenames from a list of filenames and globs.

    Each command line value is expanded with :func:`glob.iglob`. Values that
    match nothing contribute no filenames. If any filenames result, they are
    stored as by the ``store`` action. The number of stored filenames is then
    checked against ``nargs``, and :class:`argparse.ArgumentError` is raised
    if it does not fit.
    """

    def __call__(self, parser, namespace, values, *args, **kwargs):
        filenames = [
            filename for pattern in values for filename in glob.iglob(pattern)]
        if filenames:
            super(GlobAction, self).__call__(
                parser, namespace, filenames, *args, **kwargs)

        stored = getattr(namespace, self.dest)
        count = len(stored) if stored is not None else 0

        expected = self.nargs
        if expected == argparse.OPTIONAL:
            problem = 'expected at most one file' if count > 1 else None
        elif expected == argparse.ONE_OR_MORE:
            problem = 'expected at least one file' if count < 1 else None
        elif expected == argparse.ZERO_OR_MORE:
            problem = None
        elif int(expected) == count:
            problem = None
        else:
            problem = 'expected exactly %s file' % expected
            if expected != 1:
                problem += 's'

        if problem is None:
            return
        problem += ', but found {} file'.format(count)
        if count != 1:
            problem += 's'
        raise argparse.ArgumentError(self, problem)
95 96


97 98
# Parent parser holding the template waveform generation options.
waveform_parser = argparse.ArgumentParser(add_help=False)
group = waveform_parser.add_argument_group(
    title='waveform options',
    description='Options that affect template waveform generation')
# FIXME: The O1 uberbank high-mass template, SEOBNRv2_ROM_DoubleSpin, does
# not support frequencies less than 30 Hz.
group.add_argument(
    '--f-low',
    type=float,
    default=30,
    metavar='Hz',
    help='Low frequency cutoff [default: %(default)s]')
group.add_argument(
    '--f-high-truncate',
    type=float,
    default=0.95,
    help=('Truncate waveform at this fraction of the maximum frequency of '
          'the PSD [default: %(default)s]'))
group.add_argument(
    '--waveform',
    default='o2-uberbank',
    help=('Template waveform approximant (e.g., TaylorF2threePointFivePN) '
          '[default: O2 uberbank mass-dependent waveform]'))
del group


# Parent parser holding the options that shape the BAYESTAR prior.
prior_parser = argparse.ArgumentParser(add_help=False)
group = prior_parser.add_argument_group(
    title='prior options',
    description='Options that affect the BAYESTAR likelihood')
group.add_argument(
    '--min-distance',
    type=float,
    metavar='Mpc',
    help=('Minimum distance of prior in megaparsecs '
          '[default: infer from effective distance]'))
group.add_argument(
    '--max-distance',
    type=float,
    metavar='Mpc',
    help=('Maximum distance of prior in megaparsecs '
          '[default: infer from effective distance]'))
group.add_argument(
    '--prior-distance-power',
    type=int,
    default=2,
    metavar='-1|2',
    help=('Distance prior [-1 for uniform in log, 2 for uniform in volume, '
          'default: %(default)s]'))
group.add_argument(
    '--cosmology',
    action='store_true',
    default=False,
    help='Use cosmological comoving volume prior [default: %(default)s]')
# Stored under the positive name, so the attribute defaults to True.
group.add_argument(
    '--disable-snr-series',
    action='store_false',
    dest='enable_snr_series',
    help=('Disable input of SNR time series (WARNING: UNREVIEWED!) '
          '[default: enabled]'))
del group


141 142 143
# Parent parser holding the sky map output options.
skymap_parser = argparse.ArgumentParser(add_help=False)
group = skymap_parser.add_argument_group(
    title='sky map output options',
    description='Options that affect sky map output')
# The -1 default is presented to users as "auto" in the help text.
group.add_argument(
    '--nside', '-n',
    type=int,
    default=-1,
    help='HEALPix resolution [default: auto]')
group.add_argument(
    '--chain-dump',
    action='store_true',
    default=False,
    help='For MCMC methods, dump the sample chain to disk [default: no]')
del group


153
class MatplotlibFigureType(argparse.FileType):
    """Argument type for a Matplotlib figure output destination.

    Converting an argument returns a callable that takes no arguments:

    - For ``'-'``, the original Matplotlib backend is restored, and the
      callable shows the current figure with ``pyplot.show()``.
    - Any other value is treated as an output filename. The file is opened
      for writing once (creating or truncating it) so that an unwritable path
      is reported while parsing arguments. The ``agg`` backend is then
      selected, and the callable saves the current figure to that file.
    """

    def __init__(self):
        super(MatplotlibFigureType, self).__init__('wb')

    @staticmethod
    def __show():
        from matplotlib import pyplot as plt
        return plt.show()

    @staticmethod
    def __save_to(filename):
        """Save the current figure to ``filename``.

        The command line is recorded in the ``Title`` metadata field. The
        program version is recorded in ``Software`` (PNG) or ``Creator``
        (PDF, PS, EPS).
        """
        from matplotlib import pyplot as plt
        _, ext = os.path.splitext(filename)
        ext = ext.lower()
        program, _ = os.path.splitext(os.path.basename(sys.argv[0]))
        cmdline = ' '.join([program] + sys.argv[1:])
        metadata = {'Title': cmdline}
        if ext == '.png':
            metadata['Software'] = get_version()
        elif ext in {'.pdf', '.ps', '.eps'}:
            metadata['Creator'] = get_version()
        return plt.savefig(filename, metadata=metadata)

    def __call__(self, string):
        from matplotlib import pyplot as plt
        if string == '-':
            plt.switch_backend(matplotlib.rcParamsOrig['backend'])
            return self.__show
        # Open (and immediately close) the file so that FileType reports an
        # unwritable path as an argument error at parse time.
        with super(MatplotlibFigureType, self).__call__(string):
            pass
        plt.switch_backend('agg')
        # Kept for backwards compatibility. The returned callable captures the
        # filename itself, so converting another argument with this (shared)
        # type instance cannot redirect an earlier result.
        self.string = string

        def save():
            return self.__save_to(string)
        return save

Leo Pound Singer's avatar
Leo Pound Singer committed
188

189
class HelpChoicesAction(argparse.Action):
    """Action for ``--help-<name>`` options that list supported values.

    The option takes no arguments. When given, it prints the values passed
    as ``choices``, one per line, and exits the parser. With the default
    ``dest``/``default`` of ``argparse.SUPPRESS``, nothing is stored in the
    parsed namespace.
    """

    def __init__(self,
                 option_strings,
                 choices=(),
                 dest=argparse.SUPPRESS,
                 default=argparse.SUPPRESS):
        # e.g. '--help-colormap' -> 'colormap'
        name = option_strings[0].replace('--help-', '')
        super(HelpChoicesAction, self).__init__(
            option_strings=option_strings,
            dest=dest,
            default=default,
            nargs=0,
            help='show supported values for --' + name + ' and exit')
        self._name = name
        self._choices = choices

    def __call__(self, parser, namespace, values, option_string=None):
        print('Supported values for --' + self._name + ':')
        for choice in self._choices:
            print(choice)
        parser.exit()

Leo Pound Singer's avatar
Leo Pound Singer committed
212

213 214 215 216 217 218 219 220 221
def type_with_sideeffect(type):
    """Decorator factory for argparse ``type`` callables with a side effect.

    The decorated function ``sideeffect`` is replaced by a converter. The
    converter calls ``type(value)``, passes the converted value to
    ``sideeffect``, and returns the converted value. The return value of
    ``sideeffect`` is ignored.

    The converter carries the decorated function's name and docstring.
    argparse uses that name in its "invalid <name> value" error messages,
    which previously named an anonymous inner function ("func").
    """
    import functools

    def decorator(sideeffect):
        @functools.wraps(sideeffect)
        def func(value):
            ret = type(value)
            sideeffect(ret)
            return ret
        return func
    return decorator

Leo Pound Singer's avatar
Leo Pound Singer committed
222

223 224 225 226 227
@type_with_sideeffect(str)
def colormap(value):
    """Make ``value`` Matplotlib's default colormap (``image.cmap``)."""
    from matplotlib import rcParams as rc_params
    rc_params['image.cmap'] = value

Leo Pound Singer's avatar
Leo Pound Singer committed
228

229
@type_with_sideeffect(float)
def figwidth(value):
    """Set the default figure width (``figure.figsize[0]``) in inches."""
    from matplotlib import rcParams
    # Build and assign a new [width, height] list rather than mutating the
    # current one in place. Assignment goes through rcParams validation, and
    # the existing list object may be shared with other rc dictionaries
    # (such as a shallow copy of the defaults).
    figsize = list(rcParams['figure.figsize'])
    figsize[0] = float(value)
    rcParams['figure.figsize'] = figsize
233

Leo Pound Singer's avatar
Leo Pound Singer committed
234

235 236 237
@type_with_sideeffect(float)
def figheight(value):
    """Set the default figure height (``figure.figsize[1]``) in inches."""
    from matplotlib import rcParams
    # Build and assign a new [width, height] list rather than mutating the
    # current one in place. Assignment goes through rcParams validation, and
    # the existing list object may be shared with other rc dictionaries
    # (such as a shallow copy of the defaults).
    figsize = list(rcParams['figure.figsize'])
    figsize[1] = float(value)
    rcParams['figure.figsize'] = figsize
239

Leo Pound Singer's avatar
Leo Pound Singer committed
240

241 242 243
@type_with_sideeffect(int)
def dpi(value):
    """Use ``value`` dots per inch for both displayed and saved figures."""
    from matplotlib import rcParams
    resolution = float(value)
    rcParams['figure.dpi'] = resolution
    rcParams['savefig.dpi'] = resolution
245

Leo Pound Singer's avatar
Leo Pound Singer committed
246

247 248 249 250 251
@type_with_sideeffect(int)
def transparent(value):
    """Turn transparent savefig backgrounds on (nonzero) or off (zero)."""
    from matplotlib import rcParams
    rcParams['savefig.transparent'] = value != 0

Leo Pound Singer's avatar
Leo Pound Singer committed
252

253 254 255 256 257 258 259 260 261
# Parent parser holding the figure output options.
figure_parser = argparse.ArgumentParser(add_help=False)
try:
    colormap_choices = sorted(cm.cmap_d.keys())
except AttributeError:
    # Newer Matplotlib versions removed ``cm.cmap_d``; use the colormap
    # registry that replaced it.
    colormap_choices = sorted(matplotlib.colormaps)
group = figure_parser.add_argument_group(
    'figure options', 'Options that affect figure output format')
group.add_argument(
    '-o', '--output', metavar='FILE.{pdf,png}',
    default='-', type=MatplotlibFigureType(),
    help='name of output file [default: plot to screen]')
group.add_argument(
    '--colormap', default='cylon', choices=colormap_choices,
    type=colormap, metavar='CMAP',
    help='name of matplotlib colormap [default: %(default)s]')
group.add_argument(
    '--help-colormap', action=HelpChoicesAction, choices=colormap_choices)
# The string defaults below are deliberate: argparse passes string defaults
# through ``type``, so the rcParams side effects also run for defaults.
group.add_argument(
    '--figure-width', metavar='INCHES', type=figwidth, default='8',
    help='width of figure in inches [default: %(default)s]')
group.add_argument(
    '--figure-height', metavar='INCHES', type=figheight, default='6',
    help='height of figure in inches [default: %(default)s]')
# NOTE(review): unlike the options above, this default is not a string, so
# argparse does not pass it through ``type``. The ``dpi`` side effect
# therefore only runs when --dpi is given explicitly. Confirm whether this
# is intended.
group.add_argument(
    '--dpi', metavar='PIXELS', type=dpi, default=300,
    help='resolution of figure in dots per inch [default: %(default)s]')
# FIXME: the savefig.transparent rcparam was added in Matplotlib 1.4,
# but we have to support Matplotlib 1.2 for Scientific Linux 7.
if mpl_version >= '1.4':
    group.add_argument(
        '--transparent', const='1', default='0', nargs='?', type=transparent,
        help='Save image with transparent background [default: false]')
del colormap_choices
del group


286 287 288
# Defer loading SWIG bindings until version string is needed.
class VersionAction(argparse._VersionAction):
    """A ``--version`` action that looks up its version string lazily."""

    def __call__(self, parser, namespace, values, option_string=None):
        # Resolve the version only when --version is actually given, not
        # when the parser is constructed.
        self.version = get_version()
        super(VersionAction, self).__call__(
            parser, namespace, values, option_string=option_string)


294 295 296 297 298 299 300 301 302 303 304 305 306 307
@type_with_sideeffect(str)
def loglevel_type(value):
    """Configure the root logger from a level number or level name.

    A value that parses as an integer is used as a numeric level. Anything
    else is upper-cased, so level names are case-insensitive. Note that
    ``logging.basicConfig`` does nothing if the root logger already has
    handlers.
    """
    try:
        level = int(value)
    except ValueError:
        level = value.upper()
    logging.basicConfig(level=level)


class LogLevelAction(argparse._StoreAction):
    """Store action for a logging level option.

    The option value is always converted with ``loglevel_type``, regardless
    of any ``type`` passed in. The metavar is always the '|'-separated list
    of level names known to the logging module, regardless of any
    ``metavar`` passed in.
    """

    def __init__(
            self, option_strings, dest, nargs=None, const=None, default=None,
            type=None, choices=None, required=False, help=None, metavar=None):
        # FIXME: this broke because of internal changes in the Python standard
        # library logging module between Python 2.7 and 3.6. We should not rely
        # on these undocumented module variables in the first place.
        legacy_names = getattr(logging, '_levelNames', None)
        if legacy_names is None:
            # Python 3: maps level number -> level name.
            level_names = list(logging._levelToName.values())
        else:
            # Python 2: maps both ways, so keep only the string keys.
            level_names = [
                key for key in legacy_names.keys() if isinstance(key, str)]
        super(LogLevelAction, self).__init__(
            option_strings, dest, nargs=nargs, const=const, default=default,
            type=loglevel_type, choices=choices, required=required, help=help,
            metavar='|'.join(level_names))


325
class ArgumentParser(argparse.ArgumentParser):
    """
    An ArgumentParser subclass with some sensible defaults.

    - Any ``.py`` suffix is stripped from the program name, because the
      program is probably being invoked from the stub shell script.

    - The description is taken from the docstring of the file in which the
      ArgumentParser is created.

    - If the description is taken from the docstring, then whitespace in
      the description is preserved. Otherwise the standard
      :class:`argparse.HelpFormatter` is used, unless ``formatter_class`` is
      given explicitly.

    - A ``--version`` option is added that prints the version of LALInference.

    - A ``-l``/``--loglevel`` option is added (default: ``INFO``), and the
      ``glob`` action is registered.

    Any additional keyword arguments are passed through to
    :class:`argparse.ArgumentParser`.
    """
    def __init__(self,
                 prog=None,
                 usage=None,
                 description=None,
                 epilog=None,
                 parents=None,
                 formatter_class=None,
                 prefix_chars='-',
                 fromfile_prefix_chars=None,
                 argument_default=None,
                 conflict_handler='error',
                 add_help=True,
                 **kwargs):
        if parents is None:
            parents = []
        if prog is None:
            prog = os.path.basename(sys.argv[0])
            # Strip only a trailing ".py". A plain str.replace would also
            # mangle names that merely contain ".py" somewhere else.
            if prog.endswith('.py'):
                prog = prog[:-len('.py')]
        if description is None:
            # Use the docstring visible to the code that created this parser.
            # Check the caller's locals first (a module or class body), then
            # fall back to the caller's module globals (e.g. for a parser
            # created inside a function).
            frame = inspect.currentframe()
            caller = frame.f_back if frame is not None else None
            try:
                if caller is not None:
                    if '__doc__' in caller.f_locals:
                        description = caller.f_locals['__doc__']
                    else:
                        description = caller.f_globals.get('__doc__')
            finally:
                # Drop frame references promptly to avoid reference cycles.
                del frame, caller
            if formatter_class is None:
                formatter_class = argparse.RawDescriptionHelpFormatter
        if formatter_class is None:
            formatter_class = argparse.HelpFormatter
        super(ArgumentParser, self).__init__(
                 prog=prog,
                 usage=usage,
                 description=description,
                 epilog=epilog,
                 parents=parents,
                 formatter_class=formatter_class,
                 prefix_chars=prefix_chars,
                 fromfile_prefix_chars=fromfile_prefix_chars,
                 argument_default=argument_default,
                 conflict_handler=conflict_handler,
                 add_help=add_help,
                 **kwargs)
        self.register('action', 'glob', GlobAction)
        self.register('action', 'loglevel', LogLevelAction)
        self.register('action', 'version', VersionAction)
        self.add_argument('--version', action='version')
        self.add_argument('-l', '--loglevel', action='loglevel', default='INFO')
378 379


Leo Pound Singer's avatar
Leo Pound Singer committed
380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399
class DirType(object):
    """Factory for directory arguments.

    Parameters
    ----------
    create : bool
        If true, create the directory (and any missing parents) when it does
        not already exist. If false, the directory must already exist and be
        listable.

    Calling an instance with a path returns the path unchanged. If the
    checks above fail, it raises :class:`argparse.ArgumentTypeError`.
    """

    def __init__(self, create=False):
        self._create = create

    def __call__(self, string):
        if self._create:
            # An empty string is accepted without creating anything.
            if string and not os.path.isdir(string):
                try:
                    os.makedirs(string)
                except OSError as e:
                    # Tolerate the directory appearing concurrently. Anything
                    # else (a file in the way, permissions, ...) is an error.
                    if not os.path.isdir(string):
                        raise argparse.ArgumentTypeError(e)
        else:
            try:
                os.listdir(string)
            except OSError as e:
                raise argparse.ArgumentTypeError(e)
        return string


400 401 402
class SQLiteType(argparse.FileType):
    """Argument type that opens an SQLite database.

    ``mode`` must be one of:

    - ``'r'``: open an existing database; fail if it does not exist.
    - ``'w'``: open for writing, overwriting the file if it exists.
    - ``'a'``: open for appending, without clobbering existing tables.

    The database is opened with this package's ``sqlite.open``. An
    :class:`OSError` raised there is reported as
    :class:`argparse.ArgumentTypeError`.

    FIXME: use SQLite URI when we drop support for Python < 3.4.
    See: https://docs.python.org/3.4/whatsnew/3.4.html#sqlite3

    Here is an example of trying to open a file that does not exist for
    reading (mode='r'). It should raise an exception:

    >>> import tempfile
    >>> filetype = SQLiteType('r')
    >>> filename = tempfile.mktemp()
    >>> # Note, simply check for a FileNotFound error in Python 3.
    >>> filetype(filename)
    Traceback (most recent call last):
      ...
    argparse.ArgumentTypeError: ...

    If the file already exists, then it's fine:

    >>> import sqlite3
    >>> filetype = SQLiteType('r')
    >>> with tempfile.NamedTemporaryFile() as f:
    ...     with sqlite3.connect(f.name) as db:
    ...         _ = db.execute('create table foo (bar char)')
    ...     filetype(f.name)
    <sqlite3.Connection object at ...>

    Here is an example of opening a file for writing (mode='w'), which should
    overwrite the file if it exists. Even if the file was not an SQLite
    database beforehand, this should work:

    >>> filetype = SQLiteType('w')
    >>> with tempfile.NamedTemporaryFile(mode='w') as f:
    ...     print('This is definitely not an SQLite file.', file=f)
    ...     f.flush()
    ...     with filetype(f.name) as db:
    ...         db.execute('create table foo (bar char)')
    <sqlite3.Cursor object at ...>

    Here is an example of opening a file for appending (mode='a'), which should
    NOT overwrite the file if it exists. If the file was not an SQLite database
    beforehand, this should raise an exception.

    >>> filetype = SQLiteType('a')
    >>> with tempfile.NamedTemporaryFile(mode='w') as f:
    ...     print('This is definitely not an SQLite file.', file=f)
    ...     f.flush()
    ...     with filetype(f.name) as db:
    ...         db.execute('create table foo (bar char)')
    Traceback (most recent call last):
      ...
    sqlite3.DatabaseError: ...

    And if the database did exist beforehand, then opening for appending
    (mode='a') should not clobber existing tables.

    >>> filetype = SQLiteType('a')
    >>> with tempfile.NamedTemporaryFile() as f:
    ...     with sqlite3.connect(f.name) as db:
    ...         _ = db.execute('create table foo (bar char)')
    ...     with filetype(f.name) as db:
    ...         db.execute('select count(*) from foo').fetchone()
    (0,)
    """

    def __init__(self, mode):
        # Compare against whole mode strings. A substring test
        # (``mode in 'arw'``) would also accept '', 'ar', 'rw', ...
        if mode not in ('a', 'r', 'w'):
            raise ValueError('Unknown file mode: {}'.format(mode))
        # Initialize FileType's attributes so that inherited methods such as
        # __repr__ work.
        super(SQLiteType, self).__init__(mode)
        self.mode = mode

    def __call__(self, string):
        try:
            return sqlite.open(string, self.mode)
        except OSError as e:
            raise argparse.ArgumentTypeError(e)
475 476


477 478 479 480 481 482 483 484 485 486 487
def _sanitize_arg_value_for_xmldoc(value):
    if hasattr(value, 'read'):
        return value.name
    elif isinstance(value, tuple):
        return tuple(_sanitize_arg_value_for_xmldoc(v) for v in value)
    elif isinstance(value, list):
        return [_sanitize_arg_value_for_xmldoc(v) for v in value]
    else:
        return value


488 489
def register_to_xmldoc(xmldoc, parser, opts, **kwargs):
    """Record this program and its parsed options in ``xmldoc``.

    Delegates to ``glue.ligolw.utils.process.register_to_xmldoc``, using
    ``parser.prog`` as the program name. The attributes of ``opts`` are the
    parameters, with open file objects replaced by their names. Extra
    keyword arguments are forwarded, and the delegate's result is returned.
    """
    from glue.ligolw.utils import process
    params = dict(
        (name, _sanitize_arg_value_for_xmldoc(val))
        for name, val in vars(opts).items())
    return process.register_to_xmldoc(xmldoc, parser.prog, params, **kwargs)
Leo Pound Singer's avatar
Leo Pound Singer committed
493 494 495 496 497


start_msg = ('Waiting for input on stdin. '
             'Type control-D followed by a newline to terminate.')
stop_msg = 'Reached end of file. Exiting.'


def iterlines(file, start_message=start_msg, stop_message=stop_msg):
    """Iterate over non-empty lines in a file.

    Parameters
    ----------
    file : file-like
        Object with a ``readline()`` method. Reading stops when
        ``readline()`` returns an empty (false) value, i.e. at end of file.
    start_message, stop_message : str
        Printed to stderr before reading starts and after end of file is
        reached, respectively. Both are printed only when ``file`` is an
        interactive terminal.

    Yields
    ------
    str
        Each line with leading and trailing whitespace removed. Lines that
        are empty after stripping are skipped.
    """
    # Prefer the stream's own isatty() so that in-memory streams without a
    # file descriptor (e.g. io.StringIO) are accepted. Fall back to the file
    # descriptor for objects that only provide fileno().
    isatty = getattr(file, 'isatty', None)
    if isatty is not None:
        is_tty = isatty()
    else:
        is_tty = os.isatty(file.fileno())

    if is_tty:
        print(start_message, file=sys.stderr)

    while True:
        line = file.readline()
        if not line:
            # Reached end of file.
            break
        # Strip off the trailing newline and any surrounding whitespace, and
        # emit the line only if something is left.
        line = line.strip()
        if line:
            yield line

    if is_tty:
        print(stop_message, file=sys.stderr)
524 525 526 527


# Call the deprecation helper from lalinference.bayestar.deprecation as soon
# as this module is imported.
# NOTE(review): presumably this warns that this module is deprecated in favor
# of ligo.skymap.tool -- confirm against that helper.
from lalinference.bayestar.deprecation import warn
warn('ligo.skymap.tool')