#
# Copyright (C) 2013-2016  Leo Singer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
"""
Functions that support the command line interface.
"""
from __future__ import print_function
__author__ = "Leo Singer <leo.singer@ligo.org>"

import argparse
import contextlib
from distutils.dir_util import mkpath
from distutils.errors import DistutilsFileError
import errno
import functools
import glob
import inspect
import itertools
import os
from select import select
import shutil
import stat
import sys
import tempfile

import matplotlib
from matplotlib import cm

from ..plot import cmap


# Set no-op Matplotlib backend to defer importing anything that requires a GUI
# until we have determined that it is necessary based on the command line
# arguments.
if 'matplotlib.pyplot' in sys.modules:
    # pyplot is already loaded, so switch its active backend directly.
    from matplotlib import pyplot as plt
    plt.switch_backend('Template')
else:
    try:
        matplotlib.use('Template', warn=False, force=True)
    except TypeError:
        # Matplotlib 3.3 removed the ``warn`` keyword argument of use().
        matplotlib.use('Template', force=True)


@contextlib.contextmanager
def TemporaryDirectory(suffix='', prefix='tmp', dir=None, delete=True):
    """Create a temporary directory and yield its path.

    ``suffix``, ``prefix`` and ``dir`` are passed to tempfile.mkdtemp; ``dir``
    is the parent in which the new directory is created. If ``delete`` is
    true (the default), the directory and everything in it are removed when
    the ``with`` block exits, whether normally or via an exception.
    """
    # Create the directory *before* entering the try block. If mkdtemp fails
    # there is nothing to clean up, and shutil.rmtree must never be pointed
    # at the caller-supplied parent directory ``dir``.
    path = tempfile.mkdtemp(suffix=suffix, prefix=prefix, dir=dir)
    try:
        yield path
    finally:
        if delete:
            shutil.rmtree(path)


def chainglob(patterns):
    """Lazily iterate over the paths matching any of the given glob patterns.

    Patterns are expanded one after another, in the order given. A path that
    matches several patterns is produced once for each pattern it matches.
    """
    per_pattern = (glob.iglob(pattern) for pattern in patterns)
    return itertools.chain.from_iterable(per_pattern)


# Options that control template waveform generation. The parser is built with
# add_help=False so it can be combined into other parsers via ``parents=``.
waveform_parser = argparse.ArgumentParser(add_help=False)
group = waveform_parser.add_argument_group(
    'waveform options', 'Options that affect template waveform generation')
# FIXME: The O1 uberbank high-mass template, SEOBNRv2_ROM_DoubleSpin, does
# not support frequencies less than 30 Hz.
group.add_argument(
    '--f-low', type=float, metavar='Hz', default=30,
    help='Low frequency cutoff [default: %(default)s]')
group.add_argument(
    '--f-high-truncate', type=float, default=0.95,
    help='Truncate waveform at this fraction of the maximum frequency of the '
    'PSD [default: %(default)s]')
group.add_argument(
    '--waveform', default='o2-uberbank',
    help='Template waveform approximant (e.g., TaylorF2threePointFivePN) '
    '[default: O2 uberbank mass-dependent waveform]')
del group


# Options that affect the BAYESTAR likelihood and its distance prior. The
# parser is built with add_help=False so it can be used via ``parents=``.
prior_parser = argparse.ArgumentParser(add_help=False)
group = prior_parser.add_argument_group(
    'prior options', 'Options that affect the BAYESTAR likelihood')
group.add_argument(
    '--phase-convention', default='antifindchirp',
    choices=('findchirp', 'antifindchirp'),
    help='Phase convention [default: %(default)s]')
group.add_argument(
    '--min-distance', type=float, metavar='Mpc',
    help='Minimum distance of prior in megaparsecs '
    '[default: infer from effective distance]')
group.add_argument(
    '--max-distance', type=float, metavar='Mpc',
    help='Maximum distance of prior in megaparsecs '
    '[default: infer from effective distance]')
group.add_argument(
    '--prior-distance-power', type=int, metavar='-1|2', default=2,
    help='Distance prior '
    '[-1 for uniform in log, 2 for uniform in volume, default: %(default)s]')
group.add_argument(
    '--enable-snr-series', default=False, action='store_true',
    help='Enable input of SNR time series (WARNING: UNREVIEWED!) '
    '[default: no]')
del group


# Parent parser (add_help=False) holding the sky map output options.
skymap_parser = argparse.ArgumentParser(add_help=False)
group = skymap_parser.add_argument_group(
    'sky map output options',
    'Options that affect sky map output')
group.add_argument(
    '--nside', '-n',
    type=int,
    default=-1,
    help='HEALPix resolution [default: auto]')
group.add_argument(
    '--chain-dump',
    action='store_true',
    default=False,
    help='For MCMC methods, dump the sample chain to disk [default: no]')
del group


class MatplotlibFigureType(argparse.FileType):
    """Argument type for a figure output destination.

    Calling an instance with a command-line string returns a zero-argument
    callable that renders the current Matplotlib figure:

    - ``'-'`` switches back to the backend in ``matplotlib.rcParamsOrig`` and
      returns a callable that calls ``matplotlib.pyplot.show()``.
    - Any other string is first opened for writing through
      argparse.FileType, so path problems are reported as argument errors.
      It then selects the ``agg`` backend and returns a callable that calls
      ``matplotlib.pyplot.savefig()`` on that path.
    """

    def __init__(self):
        super(MatplotlibFigureType, self).__init__('wb')

    @staticmethod
    def __show():
        from matplotlib import pyplot as plt
        return plt.show()

    def __save(self):
        from matplotlib import pyplot as plt
        return plt.savefig(self.string)

    def __call__(self, string):
        from matplotlib import pyplot as plt
        if string == '-':
            plt.switch_backend(matplotlib.rcParamsOrig['backend'])
            return self.__show
        else:
            # Open (and immediately close) the file only to check that it is
            # writable; opening in 'wb' mode creates/truncates it.
            with super(MatplotlibFigureType, self).__call__(string):
                pass
            plt.switch_backend('agg')
            # NOTE(review): the path is stored on this type instance, which
            # argparse shares across parses; only the latest path is kept.
            self.string = string
            return self.__save


class HelpChoicesAction(argparse.Action):
    """Argparse action that lists the allowed values of another option.

    Register it as ``--help-NAME`` with ``choices=...``. When given on the
    command line, it prints the supported values for ``--NAME`` (one per line)
    to stdout and then calls ``parser.exit()``.
    """

    def __init__(self,
                 option_strings,
                 choices=(),
                 dest=argparse.SUPPRESS,
                 default=argparse.SUPPRESS):
        # e.g. '--help-colormap' -> 'colormap'
        name = option_strings[0].replace('--help-', '')
        super(HelpChoicesAction, self).__init__(
            option_strings=option_strings,
            dest=dest,
            default=default,
            nargs=0,
            help='show supported values for --' + name + ' and exit')
        self._name = name
        self._choices = choices

    def __call__(self, parser, namespace, values, option_string=None):
        print('Supported values for --' + self._name + ':')
        for choice in self._choices:
            print(choice)
        parser.exit()


def type_with_sideeffect(type):
    """Decorator factory for argparse ``type=`` callables with side effects.

    The decorated function ``sideeffect`` becomes a converter. It applies
    ``type`` to the raw command-line string, passes the converted value to
    ``sideeffect`` (whose return value is ignored), and returns the converted
    value. If ``type`` raises, ``sideeffect`` is not called.
    """
    def decorator(sideeffect):
        # Copy __name__ etc. from the side-effect function; argparse uses the
        # type's __name__ in its "invalid <name> value" error message.
        @functools.wraps(sideeffect)
        def func(value):
            ret = type(value)
            sideeffect(ret)
            return ret
        return func
    return decorator


@type_with_sideeffect(str)
def colormap(value):
    """Side effect for ``--colormap``: make ``value`` the default image
    colormap (rcParams['image.cmap'])."""
    from matplotlib import rcParams as rc
    rc['image.cmap'] = value

@type_with_sideeffect(float)
def figwidth(value):
    """Side effect for ``--figure-width``: set the default figure width.

    ``value`` (inches) has already been converted to float by the decorator.
    It is written into element 0 of the rcParams['figure.figsize'] list.
    """
    from matplotlib import rcParams
    rcParams['figure.figsize'][0] = value


@type_with_sideeffect(float)
def figheight(value):
    """Side effect for ``--figure-height``: set the default figure height.

    ``value`` (inches) has already been converted to float by the decorator.
    It is written into element 1 of the rcParams['figure.figsize'] list.
    """
    from matplotlib import rcParams
    rcParams['figure.figsize'][1] = value


@type_with_sideeffect(int)
def dpi(value):
    """Side effect for ``--dpi``: set both rcParams['figure.dpi'] and
    rcParams['savefig.dpi'].

    ``value`` has already been converted to int by the decorator. It is stored
    as a float.
    """
    from matplotlib import rcParams
    rcParams['figure.dpi'] = rcParams['savefig.dpi'] = float(value)


# Parent parser (add_help=False) holding options that control figure output.
figure_parser = argparse.ArgumentParser(add_help=False)
try:
    colormap_choices = sorted(cm.cmap_d.keys())
except AttributeError:
    # ``matplotlib.cm.cmap_d`` is gone in newer Matplotlib releases; use the
    # colormap registry (Matplotlib >= 3.5), which iterates over names.
    colormap_choices = sorted(matplotlib.colormaps)
group = figure_parser.add_argument_group(
    'figure options', 'Options that affect figure output format')
group.add_argument(
    '-o', '--output', metavar='FILE.{pdf,png}',
    default='-', type=MatplotlibFigureType(),
    help='name of output file [default: plot to screen]')
group.add_argument(
    '--colormap', default='cylon', choices=colormap_choices,
    type=colormap, metavar='CMAP',
    help='name of matplotlib colormap [default: %(default)s]')
group.add_argument(
    '--help-colormap', action=HelpChoicesAction, choices=colormap_choices)
group.add_argument(
    '--figure-width', metavar='INCHES', type=figwidth, default='8',
    help='width of figure in inches [default: %(default)s]')
group.add_argument(
    '--figure-height', metavar='INCHES', type=figheight, default='6',
    help='height of figure in inches [default: %(default)s]')
# NOTE(review): argparse runs only *string* defaults through ``type``. With
# the integer default below, dpi()'s rcParams side effect therefore happens
# only when --dpi is given explicitly, unlike the string defaults of
# --figure-width/--figure-height above. Confirm whether the advertised
# default of 300 is meant to apply.
group.add_argument(
    '--dpi', metavar='PIXELS', type=dpi, default=300,
    help='resolution of figure in dots per inch [default: %(default)s]')
del colormap_choices
del group


# Defer loading SWIG bindings until version string is needed.
class VersionAction(argparse._VersionAction):
    """``--version`` action that reports the LALInference version.

    InferenceVCSInfo is imported only when the option is actually used, so
    the bindings are not loaded for ordinary runs.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        from .. import InferenceVCSInfo
        self.version = 'LALInference ' + InferenceVCSInfo.version
        super(VersionAction, self).__call__(
            parser, namespace, values, option_string)


class ArgumentParser(argparse.ArgumentParser):
    """
    An ArgumentParser subclass with some sensible defaults.

    - Any ``.py`` suffix is stripped from the program name, because the
      program is probably being invoked from the stub shell script.

    - The description is taken from the docstring of the file in which the
      ArgumentParser is created.

    - If the description is taken from the docstring, then whitespace in
      the description is preserved; otherwise the standard help formatter is
      used. An explicitly passed ``formatter_class`` always takes precedence.

    - A ``--version`` option is added that prints the version of LALInference.
    """
    def __init__(self,
                 prog=None,
                 usage=None,
                 description=None,
                 epilog=None,
                 parents=None,
                 formatter_class=None,
                 prefix_chars='-',
                 fromfile_prefix_chars=None,
                 argument_default=None,
                 conflict_handler='error',
                 add_help=True):
        if prog is None:
            prog = os.path.basename(sys.argv[0])
            # Strip only a trailing suffix, not '.py' appearing elsewhere.
            if prog.endswith('.py'):
                prog = prog[:-len('.py')]
        if parents is None:
            # Fresh list per call rather than a shared mutable default.
            parents = []
        if description is None:
            # Use ``__doc__`` from the calling scope; when the parser is
            # created at module level this is the module's docstring.
            parent_frame = inspect.currentframe().f_back
            description = parent_frame.f_locals.get('__doc__', None)
            if formatter_class is None:
                # Keep the docstring's own line breaks and indentation.
                formatter_class = argparse.RawDescriptionHelpFormatter
        if formatter_class is None:
            formatter_class = argparse.HelpFormatter
        super(ArgumentParser, self).__init__(
            prog=prog,
            usage=usage,
            description=description,
            epilog=epilog,
            parents=parents,
            formatter_class=formatter_class,
            prefix_chars=prefix_chars,
            fromfile_prefix_chars=fromfile_prefix_chars,
            argument_default=argument_default,
            conflict_handler=conflict_handler,
            add_help=add_help)
        self.add_argument('--version', action=VersionAction)


class DirType(object):
    """Factory for directory arguments.

    ``DirType(create=True)`` creates the directory, including any missing
    parents, if it does not already exist. ``DirType()`` requires an existing
    directory that can be listed. In both cases the original string is
    returned. Failures are raised as argparse.ArgumentTypeError so that
    argparse reports them as a usage error.
    """

    def __init__(self, create=False):
        self._create = create

    def __call__(self, string):
        if self._create:
            # os.makedirs instead of distutils' mkpath: distutils was removed
            # in Python 3.12, and on Python 3 its exceptions have no
            # ``.message`` attribute.
            try:
                if string and not os.path.isdir(string):
                    os.makedirs(string)
            except OSError as e:
                # Tolerate the directory having been created concurrently.
                if e.errno != errno.EEXIST or not os.path.isdir(string):
                    raise argparse.ArgumentTypeError(str(e))
        else:
            try:
                os.listdir(string)
            except OSError as e:
                raise argparse.ArgumentTypeError(str(e))
        return string


class SQLiteType(argparse.FileType):
    """Argument type that opens an SQLite database, failing if it is missing.

    The path is first opened through argparse.FileType, in binary mode, so
    that path problems become argument errors. With a writable mode such as
    ``'w'`` this creates or truncates the file. A sqlite3 connection is
    returned; ``'-'`` (stdin/stdout) is rejected.

    FIXME: use SQLite URI when we drop support for Python < 3.4.
    See: https://docs.python.org/3.4/whatsnew/3.4.html#sqlite3
    """

    def __init__(self, mode='r'):
        binary_mode = mode + 'b'
        super(SQLiteType, self).__init__(binary_mode)

    def __call__(self, string):
        uses_stdio = (string == '-')
        if not uses_stdio:
            probe = super(SQLiteType, self).__call__(string)
            with probe:
                import sqlite3
                connection = sqlite3.connect(string)
            return connection
        raise argparse.ArgumentTypeError(
            'Cannot open stdin/stdout as an SQLite database')


def sqlite_get_filename(connection):
    """Return the file name reported by ``pragma database_list``.

    Raises RuntimeError unless the pragma lists exactly one database.
    """
    rows = connection.execute('pragma database_list').fetchall()
    try:
        [(_seq, _name, path)] = rows
    except ValueError:
        raise RuntimeError('Expected exactly one attached database')
    return path


def rename(src, dst):
    """Like os.rename(src, dst), but works across different devices because it
    catches and handles EXDEV ('Invalid cross-device link') errors.

    On EXDEV, ``src`` is copied with shutil.copy2 to a temporary file in the
    destination directory. That file is renamed onto ``dst``, and ``src`` is
    then removed, as os.rename would have done. If the copy or the final
    rename fails, the temporary file is deleted and the error is re-raised.
    Errors other than EXDEV are re-raised unchanged.
    """
    try:
        os.rename(src, dst)
    except OSError as e:
        if e.errno != errno.EXDEV:
            raise
        dst_dir, suffix = os.path.split(dst)
        tmpfid, tmpdst = tempfile.mkstemp(dir=dst_dir, suffix=suffix)
        try:
            os.close(tmpfid)
            shutil.copy2(src, tmpdst)
            os.rename(tmpdst, dst)
        except BaseException:
            # Do not leave a partially copied temporary file behind.
            os.remove(tmpdst)
            raise
        os.remove(src)


def register_to_xmldoc(xmldoc, parser, opts, **kwargs):
    """Record this program invocation in a LIGO-LW document.

    Calls ``glue.ligolw.utils.process.register_to_xmldoc`` with ``parser.prog``
    as the program name and the parsed options ``opts`` as the parameters.
    Option values with a ``read`` attribute (open files) are recorded by their
    ``name``. Extra keyword arguments are forwarded to the glue function.
    """
    from glue.ligolw.utils import process
    params = {key: (value.name if hasattr(value, 'read') else value)
              for key, value in vars(opts).items()}
    return process.register_to_xmldoc(xmldoc, parser.prog, params, **kwargs)


def iterlines(file,
              start_message=('Waiting for input on stdin. Type control-D '
                             'followed by a newline to terminate.'),
              stop_message='Reached end of file. Exiting.'):
    """Safely iterate over non-empty lines in a file. Works around buffering
    issues with `for line in sys.stdin`. Also works around early closing of
    fifos (named pipes).

    Each yielded line has surrounding whitespace, including the newline,
    stripped, and blank lines are skipped. If ``file`` is a terminal,
    ``start_message`` is printed to stderr before reading and
    ``stop_message`` is printed after end of file. If ``file`` is a FIFO/pipe,
    end of file does not stop the iteration; reading simply continues.
    """
    fd = file.fileno()
    # Determine if the file is a FIFO (named pipe) or a TTY (terminal).
    is_fifo = stat.S_ISFIFO(os.fstat(fd).st_mode)
    is_tty = os.isatty(fd)

    if is_tty:
        print(start_message, file=sys.stderr)

    while True:
        # Read a line; readline() blocks until a complete line or EOF. The
        # descriptor is deliberately not select()ed first. readline() may pull
        # several lines into the file object's internal buffer at once, and
        # select() on the descriptor would then block even though a buffered
        # line is ready to be returned.
        line = file.readline()

        if not line:
            if is_fifo:
                # If we reached EOF, and this is a FIFO, then just keep
                # reading.
                # NOTE(review): once the last writer has closed, readline()
                # returns '' immediately, so this spins until a new writer
                # opens the FIFO.
                continue
            else:
                # If we reached EOF, and this is not a FIFO, then exit.
                break

        # Strip off the trailing newline and any whitespace.
        line = line.strip()

        # Emit the line if it is not empty.
        if line:
            yield line

    if is_tty:
        print(stop_message, file=sys.stderr)