# -*- coding: utf-8 -*-
# Copyright (C) Brian Moe, Branson Stephens (2015)
#
# This file is part of gracedb
#
# gracedb is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# It is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with gracedb.  If not, see <http://www.gnu.org/licenses/>.
from base64 import b64encode
19 20 21
from cryptography import x509
from cryptography.hazmat.backends import default_backend
import datetime
22
import json
23
import mimetypes
24
import os
25
import six
26 27 28 29 30
from six.moves import map, http_client
from six.moves.urllib.parse import urlparse, urlencode
import socket
import ssl
import sys
31

32
from .exceptions import HTTPError
Tanner Prestegard's avatar
Tanner Prestegard committed
33
from .version import __version__
Tanner Prestegard's avatar
Tanner Prestegard committed
34
from .utils import event_or_superevent, safe_netrc
35

36
DEFAULT_SERVICE_URL = "https://gracedb.ligo.org/api/"
37

38
# --------------------------------------------------------------------
# This monkey patch makes TLSv1 the default protocol for
# ssl.wrap_socket on Python 2.6.6 and earlier.
# It was introduced because clients connecting from CIT *occasionally*
# try to use SSLv3.  See:
# http://stackoverflow.com/questions/18669457/python-httplib-ssl23-get-server-hellounknown-protocol
#
# Only (major, minor, micro) is compared: sys.version_info is a 5-tuple
# such as (2, 6, 6, 'final', 0), which compares *greater* than (2, 6, 6),
# so an unsliced comparison would skip 2.6.6 itself.
# --------------------------------------------------------------------
if sys.version_info[:3] <= (2, 6, 6):
    wrap_socket_orig = ssl.wrap_socket

    def wrap_socket_patched(sock, keyfile=None, certfile=None,
                            server_side=False, cert_reqs=ssl.CERT_NONE,
                            ssl_version=ssl.PROTOCOL_TLSv1, ca_certs=None,
                            do_handshake_on_connect=True,
                            suppress_ragged_eofs=True):
        """Call the original ``ssl.wrap_socket`` with TLSv1 as default.

        Same signature as the original; only the default of
        ``ssl_version`` differs, so explicit callers can still override it.
        """
        return wrap_socket_orig(sock, keyfile, certfile, server_side,
                                cert_reqs, ssl_version, ca_certs,
                                do_handshake_on_connect,
                                suppress_ragged_eofs)
    ssl.wrap_socket = wrap_socket_patched

# ----------------------------------------------------------------
# HTTP/S Proxy classes
# Taken from: http://code.activestate.com/recipes/456195/


class ProxyHTTPConnection(http_client.HTTPConnection):
    """HTTP connection that opens a CONNECT tunnel through a proxy.

    The connection object is constructed with the *proxy's* host and port.
    The real destination host/port is taken from the absolute URL passed to
    :meth:`request` and used in the CONNECT request sent by :meth:`connect`.
    """

    # Default ports for URL schemes that carry no explicit port.
    _ports = {'http': 80, 'https': 443}

    def request(self, method, url, body=None, headers=None):
        """Record the real destination from ``url``, then send the request.

        Args:
            method (str): HTTP method.
            url (str): absolute URL (``scheme://host[:port]/path``) of the
                real destination.
            body: request body, passed through to http_client.
            headers (dict, optional): request headers.

        Raises:
            ValueError: if ``url`` has no scheme, or has no explicit port
                and a scheme other than ``http``/``https``.
        """
        # request is called before connect, so can interpret url and get
        # real host/port to be used to make CONNECT request to proxy
        parsed = urlparse(url)
        proto = parsed.scheme
        port = parsed.port
        host = parsed.hostname
        # urlparse reports a missing scheme as '' (never None)
        if not proto:
            raise ValueError("unknown URL type: %s" % url)
        if port is None:
            try:
                port = self._ports[proto]
            except KeyError:
                raise ValueError("unknown protocol for: %s" % url)
        self._real_host = host
        self._real_port = port
        if headers is None:
            headers = {}
        http_client.HTTPConnection.request(self, method, url, body, headers)

    def connect(self):
        """Connect to the proxy and establish the CONNECT tunnel.

        Raises:
            socket.error: if the proxy answers with a status other than 200.
        """
        http_client.HTTPConnection.connect(self)
        # send proxy CONNECT request; sockets require bytes on Python 3
        connect_line = "CONNECT {0}:{1} HTTP/1.0\r\n\r\n".format(
            self._real_host, self._real_port)
        self.send(connect_line.encode('ascii'))
        # expect a HTTP/1.0 200 Connection established
        # Python 3 removed ``strict`` from HTTPConnection/HTTPResponse, so
        # only pass it where the attribute still exists (Python 2).
        response_kwargs = {'method': self._method}
        if hasattr(self, 'strict'):
            response_kwargs['strict'] = self.strict
        response = self.response_class(self.sock, **response_kwargs)
        try:
            (version, code, message) = response._read_status()
            # probably here we can handle auth requests...
            if code != 200:
                # proxy returned an error, abort connection, and raise
                # exception
                self.close()
                raise socket.error("Proxy connection failed: {0} {1}".format(
                                   code, message.strip()))
            # eat up header block from proxy....
            while True:
                # should not use directly fp probably
                line = response.fp.readline()
                # readline() yields bytes on Python 3 (b'' == '' on Python
                # 2). An empty read means the proxy closed the connection;
                # without this check the loop would never end.
                if line in (b'\r\n', b'\n', b''):
                    break
        finally:
            # Release the response's file object; the tunnelled socket
            # itself stays open in self.sock.
            response.close()
106 107


108
class ProxyHTTPSConnection(ProxyHTTPConnection):
    """HTTPS connection tunnelled through an HTTP proxy.

    :class:`ProxyHTTPConnection` establishes the plain-TCP CONNECT tunnel;
    this class then wraps the tunnelled socket in TLS.
    """

    default_port = 443

    def __init__(self, host, port=None, key_file=None, cert_file=None,
                 strict=None, context=None):
        """
        Args:
            host (str): proxy hostname.
            port (int, optional): proxy port.
            key_file (str, optional): client key file; only used when no
                ``context`` is given.
            cert_file (str, optional): client certificate file; only used
                when no ``context`` is given.
            strict: accepted for signature compatibility; ignored.
            context (:class:`ssl.SSLContext`, optional): context used to
                wrap the tunnelled socket.
        """
        ProxyHTTPConnection.__init__(self, host, port)
        self.key_file = key_file
        self.cert_file = cert_file
        self.context = context

    def connect(self):
        """Open the proxy tunnel, then make the socket TLS-aware."""
        ProxyHTTPConnection.connect(self)
        if self.context is not None:
            # server_hostname is required when context.check_hostname is
            # enabled (wrap_socket raises ValueError without it) and is
            # what gets sent via SNI. It must be the real destination,
            # not the proxy.
            self.sock = self.context.wrap_socket(
                self.sock, server_hostname=self._real_host)
        else:
            # Interpreters without ssl.SSLContext (before 2.7.9): wrap the
            # socket the way httplib.HTTPSConnection does there. No server
            # certificate verification is performed in this case.
            self.sock = ssl.wrap_socket(self.sock, self.key_file,
                                        self.cert_file)
126

127

128 129
# ----------------------------------------------------------------
# Generic GSI REST
130
class GsiRest(object):
    """Generic REST client with X.509 or HTTP basic authentication.

    Creates HTTPS connections to the service (optionally tunnelled through
    an HTTP proxy), looks up credentials, adds authentication and
    User-Agent headers to each request, and raises :class:`HTTPError` for
    error responses.
    """

    def __init__(self, url=DEFAULT_SERVICE_URL, proxy_host=None,
                 proxy_port=3128, cred=None, username=None, password=None,
                 force_noauth=False, fail_if_noauth=False,
                 reload_certificate=False, reload_buffer=300):
        """
        url (:obj:`str`, optional): URL of server API
        proxy_host (:obj:`str`, optional): proxy host
        proxy_port (:obj:`int`, optional): proxy port
        cred (:obj:`tuple` or :obj:`str`, optional): a tuple or list of
            (``/path/to/cert/file``, ``/path/to/key/file``) or a single path
            to a combined proxy file (if using an X.509 certificate for
            authentication)
        username (:obj:`str`, optional): username for basic auth
        password (:obj:`str`, optional): password for basic auth
        force_noauth (:obj:`bool`, optional): set to True if you want to skip
            credential lookup and use this client as an unauthenticated user
        fail_if_noauth (:obj:`bool`, optional): set to True if you want the
            constructor to fail if no authentication credentials are provided
            or found
        reload_certificate (:obj:`bool`, optional): if ``True``, your
            certificate will be checked before each request whether it is
            within ``reload_buffer`` seconds of expiration, and if so, it will
            be reloaded. Useful for processes which may live longer than the
            certificate lifetime and have an automated method for certificate
            renewal. The path to the new/renewed certificate **must** be the
            same as for the old certificate.
        reload_buffer (:obj:`int`, optional): buffer (in seconds) for reloading
            a certificate in advance of its expiration. Only used if
            ``reload_certificate`` is ``True``.

        Raises:
            ValueError: if ``fail_if_noauth`` and ``force_noauth`` are both
                True.
            RuntimeError: if only one of ``username``/``password`` is given,
                or if ``fail_if_noauth`` is True and no credentials are
                provided or found.

        Authentication details:
        You can:
            1. Provide a path to an X.509 certificate and key or a single
               combined proxy file
            2. Provide a username and password
        Or:
            The code will look for a certificate in a default location
                (``/tmp/x509up_u%d``, where ``%d`` is your user ID)
            The code will look for a username and password for the specified
                server in ``$HOME/.netrc``
        """
        # Process service URL
        o = urlparse(url)
        host = o.hostname
        port = o.port or 443

        # Store some of this information; it is needed again to rebuild the
        # connector when the certificate is reloaded.
        self._server_host = host
        self._server_port = port
        self._proxy_host = proxy_host
        self._proxy_port = proxy_port
        self._reload_certificate = reload_certificate
        self._reload_buffer = reload_buffer

        # Store information about credentials and authentication type
        self.credentials = {}
        self.auth_type = None

        # Fail if conflicting arguments: (fail if no auth, but force no auth)
        if fail_if_noauth and force_noauth:
            err_msg = ('You have provided conflicting parameters to the '
                       'client constructor: fail_if_noauth=True and '
                       'force_noauth=True.')
            raise ValueError(err_msg)

        # Try to get user-provided credentials, if we aren't forcing
        # no authentication
        credentials_provided = False
        if not force_noauth:
            credentials_provided = self._process_credentials(
                cred, username, password)

        # If the user didn't provide credentials in the constructor,
        # we try to look up the credentials
        if not force_noauth and not credentials_provided:
            # Look for X509 certificate and key
            cred = self._find_x509_credentials()
            if cred:
                self.credentials['cert_file'], self.credentials['key_file'] = \
                    cred
                self.auth_type = 'x509'
            else:
                # Look for basic auth credentials in .netrc file
                try:
                    basic_auth_tuple = safe_netrc().authenticators(host)
                except IOError:
                    # IOError = no .netrc file found, pass
                    pass
                else:
                    # If credentials were found for host, set them up!
                    if basic_auth_tuple is not None:
                        self.credentials['username'] = basic_auth_tuple[0]
                        self.credentials['password'] = basic_auth_tuple[2]
                        self.auth_type = 'basic'

        if fail_if_noauth and not self.credentials:
            raise RuntimeError('No authentication credentials found.')

        # If we are using basic auth, construct auth header
        if self.auth_type == 'basic':
            user_and_pass = b64encode('{username}:{password}'.format(
                username=self.credentials['username'],
                password=self.credentials['password']).encode()) \
                .decode('ascii')
            self.authn_header = {
                'Authorization': 'Basic {0}'.format(user_and_pass),
            }

        # If we are using X.509 auth, load the certificate with the
        # cryptography.x509 module
        if self.auth_type == 'x509':
            self._load_certificate()

        # Construct version header
        self.version_header = {'User-Agent': 'gracedb-client/{version}'.format(
            version=__version__)}

        # Set up SSL context and connector
        self.set_up_connector(host, port, proxy_host, proxy_port)

    def _load_certificate(self):
        """Load the X.509 certificate into ``self.certificate``.

        Raises:
            RuntimeError: if not using X.509 auth, or if the file can be
                parsed neither as PEM nor as DER.
        """
        if not self.auth_type == 'x509':
            raise RuntimeError("Can't load certificate for "
                               "non-X.509 authentication.")

        # Open cert file and load it as bytes
        with open(self.credentials['cert_file'], 'rb') as cf:
            cert_data = cf.read()

        # Certificates should be PEM, but just in case, we'll try
        # DER if loading a PEM certificate fails
        try:
            self.certificate = x509.load_pem_x509_certificate(
                cert_data, default_backend()
            )
        except ValueError:
            try:
                self.certificate = x509.load_der_x509_certificate(
                    cert_data, default_backend()
                )
            except ValueError:
                raise RuntimeError('Error importing certificate')

    def _check_certificate_expiration(self, reload_buffer=None):
        """Return True if the certificate expires within ``reload_buffer`` s.

        Args:
            reload_buffer (int, optional): seconds of margin; defaults to the
                value given to the constructor. ``0`` tests for actual
                expiration.

        Raises:
            RuntimeError: if not using X.509 authentication.
        """
        if reload_buffer is None:
            reload_buffer = self._reload_buffer
        if self.auth_type != 'x509':
            raise RuntimeError("Can't check certificate expiration for "
                               "non-X.509 authentication.")
        if not hasattr(self, 'certificate'):
            self._load_certificate()

        # Compare certificate expiration to current time (UTC)
        # (certs use UTC until 2050, see https://tools.ietf.org/html/rfc5280)
        # Newer cryptography releases provide a timezone-aware
        # ``not_valid_after_utc`` and deprecate the naive
        # ``not_valid_after``; prefer the former and compare against an
        # aware "now" in the same timezone, falling back for old versions.
        not_after = getattr(self.certificate, 'not_valid_after_utc', None)
        if not_after is not None:
            now = datetime.datetime.now(not_after.tzinfo)
        else:
            not_after = self.certificate.not_valid_after
            now = datetime.datetime.utcnow()
        time_to_expire = not_after - now
        return time_to_expire <= datetime.timedelta(seconds=reload_buffer)

    def set_up_connector(self, host, port, proxy_host, proxy_port):
        """Set ``self.connector`` to a factory of new HTTPS connections.

        Args:
            host (str): server hostname.
            port (int): server port.
            proxy_host (str or None): proxy hostname; if set, connections
                are tunnelled through the proxy.
            proxy_port (int): proxy port.
        """
        # Versions of Python earlier than 2.7.9 don't use SSL Context
        # objects for this purpose, and do not do any server cert
        # verification. (ssl.SSLContext, check_hostname and
        # load_default_certs all arrived in 2.7.9.)
        if sys.version_info >= (2, 7, 9):
            # Let client and server negotiate the highest TLS version both
            # support rather than pinning TLS 1.0. PROTOCOL_TLS_CLIENT is
            # the modern spelling; PROTOCOL_SSLv23 is its older alias.
            protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', None)
            if protocol is None:
                protocol = ssl.PROTOCOL_SSLv23
            ssl_context = ssl.SSLContext(protocol)
            if self.auth_type == 'x509':
                try:
                    ssl_context.load_cert_chain(self.credentials['cert_file'],
                                                self.credentials['key_file'])
                except ssl.SSLError:
                    msg = ("\nERROR: Unable to load cert/key pair.\n\nPlease "
                           "run ligo-proxy-init or grid-proxy-init again or "
                           "make sure your robot certificate is readable.\n\n")
                    self.output_and_die(msg)
            # Load and verify certificates
            ssl_context.verify_mode = ssl.CERT_REQUIRED
            ssl_context.check_hostname = True
            # Find the various CA cert bundles stored on the system
            ssl_context.load_default_certs()

            if proxy_host:
                self.connector = lambda: ProxyHTTPSConnection(
                    proxy_host, proxy_port, context=ssl_context)
            else:
                self.connector = lambda: http_client.HTTPSConnection(
                    host, port, context=ssl_context)
        else:
            # Using an older version of python. We'll pass in the cert and
            # key files.
            creds = self.credentials if self.auth_type == 'x509' else {}
            if proxy_host:
                self.connector = lambda: ProxyHTTPSConnection(
                    proxy_host, proxy_port, **creds)
            else:
                self.connector = lambda: http_client.HTTPSConnection(
                    host, port, **creds)

    def _process_credentials(self, cred, username, password):
        """Process credentials provided in the constructor.

        Returns:
            bool: True if X.509 or basic-auth credentials were provided.

        Raises:
            RuntimeError: if exactly one of ``username``/``password`` is None.
        """
        credentials_provided = False
        if cred:
            if isinstance(cred, (list, tuple)):
                self.credentials['cert_file'], self.credentials['key_file'] = \
                    cred
            else:
                self.credentials['cert_file'] = cred
                self.credentials['key_file'] = cred
            credentials_provided = True
            self.auth_type = 'x509'
        elif username and password:
            self.credentials['username'] = username
            self.credentials['password'] = password
            credentials_provided = True
            self.auth_type = 'basic'
        elif (username is None) ^ (password is None):
            raise RuntimeError('Must provide both username AND password for '
                               'basic auth.')

        return credentials_provided

    def _find_x509_credentials(self):
        """
        Tries to find a user's X509 certificate and key.  Checks environment
        variables first, then expected location for default proxy, then
        ``$HOME/.globus``.

        Returns:
            tuple or None: ``(cert_file, key_file)``, or None if nothing found.
        """
        proxyFile = os.environ.get('X509_USER_PROXY')
        certFile = os.environ.get('X509_USER_CERT')
        keyFile = os.environ.get('X509_USER_KEY')

        if certFile and keyFile:
            return certFile, keyFile

        if proxyFile:
            return proxyFile, proxyFile

        # Try default proxy (os.getuid does not exist on Windows)
        if hasattr(os, 'getuid'):
            proxyFile = os.path.join('/tmp', "x509up_u%d" % os.getuid())
            if os.path.exists(proxyFile):
                return proxyFile, proxyFile

        # Try default cert/key
        homeDir = os.environ.get('HOME', None)
        if homeDir:
            certFile = os.path.join(homeDir, '.globus', 'usercert.pem')
            keyFile = os.path.join(homeDir, '.globus', 'userkey.pem')

            if os.path.exists(certFile) and os.path.exists(keyFile):
                return certFile, keyFile

        return None

    def show_credentials(self, print_output=True):
        """Prints authentication type and credentials information.

        Note that for basic auth the output includes the password.
        Returns the dict instead of printing it if ``print_output`` is False.
        """
        output = {'auth_type': self.auth_type}
        output.update(self.credentials)

        if print_output:
            print(output)
        else:
            return output

    def get_user_info(self):
        """Get information from the server about your user account."""
        # NOTE(review): self.links is not set in this class; presumably a
        # subclass provides it.
        user_info_link = self.links.get('user-info', None)
        if user_info_link is None:
            raise RuntimeError('Server does not provide a user info endpoint')
        return self.get(user_info_link)

    def getConnection(self):
        """Return a new (not yet connected) HTTPS connection object."""
        return self.connector()

    # When there is a problem with the SSL connection or cert authentication,
    # either conn.request() or conn.getresponse() will throw an exception.
    # The following two wrappers are intended to catch these exceptions and
    # return an intelligible error message to the user.
    # A wrapper for getting the response:
    def get_response(self, conn):
        """Return ``conn.getresponse()``; SSL errors become RuntimeError."""
        try:
            return conn.getresponse()
        except ssl.SSLError as e:
            if self.auth_type == 'x509':
                # Check for a valid user proxy cert.
                expired = self._check_certificate_expiration(reload_buffer=0)

                if expired:
                    msg = ("\nERROR\n\nYour certificate or proxy has "
                           "expired. Please run ligo-proxy-init or "
                           "grid-proxy-init (as appropriate) to generate "
                           "a fresh one.\n\n")
                else:
                    msg = ("\nERROR\n\nYour certificate appears valid, "
                           "but there was a problem establishing a secure "
                           "connection: {e}").format(e=str(e))
            else:
                msg = ("\nERROR\n\nProblem establishing secure connection: "
                       "{e}\n\n").format(e=str(e))
            self.output_and_die(msg)

    # A wrapper for making the request.
    def make_request(self, conn, *args, **kwargs):
        """Call ``conn.request``; SSL errors become RuntimeError."""
        try:
            conn.request(*args, **kwargs)
        except ssl.SSLError as e:
            msg = "\nERROR \n\n"
            msg += "Problem establishing secure connection: %s \n\n" % str(e)
            self.output_and_die(msg)

    def _reload_certificate_if_needed(self):
        """Reload an expiring certificate and rebuild the connector.

        Only acts for X.509 auth with ``reload_certificate`` enabled.

        Returns:
            bool: True if a reload happened; connections created before it
            still use the old certificate.
        """
        if not (self.auth_type == 'x509' and self._reload_certificate):
            return False
        if not self._check_certificate_expiration():
            return False
        self._load_certificate()
        self.set_up_connector(
            self._server_host, self._server_port, self._proxy_host,
            self._proxy_port
        )
        return True

    def _send_and_receive(self, conn, method, url, body=None, headers=None):
        """Send one request on ``conn`` and return the raw response."""
        if headers is None:
            headers = {}
        self.make_request(conn, method, url, body=body, headers=headers)
        return self.get_response(conn)

    def make_request_and_get_response(self, conn, method, url, body=None,
                                      headers=None):
        """Send a request and return the raw response.

        For X.509 based auth: if the user has specified to reload the
        certificate (upon expiration), the certificate is checked first and,
        if reloaded, a fresh connection replaces ``conn``.
        """
        if self._reload_certificate_if_needed():
            # conn was built from the old certificate; discard it.
            conn.close()
            conn = self.getConnection()
        return self._send_and_receive(conn, method, url, body=body,
                                      headers=headers)

    def request(self, method, url, body=None, headers=None, priming_url=None):
        """Perform an HTTP request and return the adjusted response.

        Args:
            method (str): HTTP method.
            url (str): request URL.
            body: request body.
            headers (dict, optional): extra headers; not modified.
            priming_url (str, optional): for X.509 auth, a URL fetched with
                GET on the same connection before the actual request.

        Raises:
            HTTPError: if the server returns a status >= 400.
            RuntimeError: on connection/SSL errors, or a basic-auth 401.
        """
        # Bug in Python (versions < 2.7.1 (?))
        # http://bugs.python.org/issue11898
        # if the URL is unicode and the body of a request is binary,
        # the POST/PUT action fails because it tries to concatenate
        # the two which fails due to encoding problems.
        # Workaround is to cast all URLs to str.
        # This is probably bad in general,
        # but for our purposes, today, this will do.
        url = url and str(url)
        priming_url = priming_url and str(priming_url)

        # Work on a copy: the User-Agent and Authorization headers added
        # below must not leak back into the caller's dict.
        headers = dict(headers) if headers else {}

        # Reload an expiring certificate once, *before* opening the
        # connection, so the priming request and the actual request share
        # one connection built from the current certificate.
        self._reload_certificate_if_needed()
        conn = self.getConnection()

        # Add version string to user-agent header
        headers.update(self.version_header)

        # Add auth header for basic auth
        if self.auth_type == 'basic':
            headers.update(self.authn_header)

        # Set up priming URL for certain requests using X509 auth
        if self.auth_type == 'x509' and priming_url:
            priming_header = {'connection': 'keep-alive'}
            priming_header.update(self.version_header)
            priming_response = self._send_and_receive(
                conn, "GET", priming_url, headers=priming_header
            )
            if priming_response.status != 200:
                # Raises HTTPError for statuses >= 400
                self.adjustResponse(priming_response)
            # Throw away the body, but read it fully: http.client will not
            # hand out the next response on this connection until the
            # previous one has been consumed.
            priming_response.read()

        response = self._send_and_receive(
            conn, method, url, body=body, headers=headers
        )

        # Special handling of 401 unauthorized response for basic auth
        # to catch expired passwords
        if self.auth_type == 'basic' and response.status == 401:
            try:
                content = response.read()
                if isinstance(content, bytes):
                    content = content.decode('utf-8')
                msg = "\nERROR: {e}\n\n".format(
                    e=json.loads(content)['detail'])
            except Exception:
                # Body empty, not JSON, or without 'detail'
                msg = "\nERROR:\n\n"
            msg += ("\nERROR:\n\nPlease check the username/password in your "
                    ".netrc file. If your password is more than a year old, "
                    "you will need to use the web interface to generate a new "
                    "one.\n\n")
            self.output_and_die(msg)
        return self.adjustResponse(response)

    def adjustResponse(self, response):
        """Raise HTTPError for statuses >= 400, else attach ``.json()``.

        For throttled requests, the ``x-throttle-wait-seconds`` header value
        is added to a JSON error body under ``retry-after`` when possible.
        """
        # XXX WRONG.
        if response.status >= 400:
            response_content = response.read()
            if isinstance(response_content, bytes):
                response_content = response_content.decode('utf-8')
            if response.getheader('x-throttle-wait-seconds', None):
                try:
                    rdict = json.loads(response_content)
                    rdict['retry-after'] = response.getheader(
                        'x-throttle-wait-seconds')
                    response_content = json.dumps(rdict)
                except Exception:
                    # Best effort: keep the body unchanged if it is not a
                    # JSON object.
                    pass
            raise HTTPError(response.status, response.reason, response_content)
        response.json = lambda: self.load_json_or_die(response)
        return response

    def get(self, url, headers=None):
        return self.request("GET", url, headers=headers)

    def head(self, url, headers=None):
        return self.request("HEAD", url, headers=headers)

    def delete(self, url, headers=None):
        return self.request("DELETE", url, headers=headers)

    def options(self, url, headers=None):
        return self.request("OPTIONS", url, headers=headers)

    def post(self, *args, **kwargs):
        return self.post_or_put_or_patch("POST", *args, **kwargs)

    def put(self, *args, **kwargs):
        return self.post_or_put_or_patch("PUT", *args, **kwargs)

    def patch(self, *args, **kwargs):
        return self.post_or_put_or_patch("PATCH", *args, **kwargs)

    def post_or_put_or_patch(self, method, url, body=None, headers=None,
                             files=None):
        """Send a request with a JSON or multipart body.

        Without ``files``, a dict ``body`` is JSON-encoded (content type
        defaults to application/json). With ``files``, ``body`` and
        ``files`` are encoded as multipart/form-data. Caller headers are
        kept (and not modified), except that multipart encoding sets the
        content type and length.
        """
        headers = dict(headers) if headers else {}
        if not files:
            # Simple JSON body
            if isinstance(body, dict):
                # Header names are case-insensitive
                has_content_type = any(
                    key.lower() == 'content-type' for key in headers)
                if not has_content_type:
                    headers['content-type'] = "application/json"
                body = json.dumps(body)
        else:
            body = body or {}
            if isinstance(body, dict):
                body = list(body.items())
            content_type, body = encode_multipart_formdata(body, files)
            # The multipart encoding dictates content type and length.
            headers = dict(
                (key, value) for key, value in headers.items()
                if key.lower() not in ('content-type', 'content-length'))
            headers['content-type'] = content_type
            headers['content-length'] = str(len(body))
        return self.request(method, url, body, headers)

    # A utility for writing out an error message to the user and then stopping
    # execution. This seems to behave sensibly in both the interpreter and in
    # a script.
    @classmethod
    def output_and_die(cls, msg):
        raise RuntimeError(msg)

    # Given an HTTPResponse object, try to read its content and interpret as
    # JSON--or die trying.
    @classmethod
    def load_json_or_die(cls, response):
        """Read ``response`` and decode it as JSON (empty body -> ``{}``).

        Raises:
            ValueError: if ``response`` is falsy or the body is not JSON.
        """
        # First check that the response object actually exists.
        if not response:
            raise ValueError("No response object")

        # Next, try to read the content of the response.
        response_content = response.read()
        if isinstance(response_content, bytes):
            response_content = response_content.decode('utf-8')
        if not response_content:
            response_content = '{}'

        # Finally, try to create a dict by decoding the response as JSON.
        rdict = None
        try:
            rdict = json.loads(response_content)
        except ValueError:
            msg = "ERROR: got unexpected content from the server:\n"
            msg += response_content
            raise ValueError(msg)

        return rdict

611

612 613
# -----------------------------------------------------------------
# GraceDb REST client
614
class GraceDb(GsiRest):
615 616 617 618 619 620 621 622 623
    """GraceDb REST client class.

    Provides a client object for accessing the GraceDB server API.
    Various methods are provided for retrieving information about different
    objects and uploading information.

    Lookup of user credentials is done in the following order:

    #. If provided, import X.509 credentials from the certificate-key \
624
        pair or combined proxy file provided in the ``cred`` keyword arg.
625 626
    #. If provided, use the username and password provided in the \
        keyword arguments.
627 628 629 630 631 632 633 634
    #. If the ``X509_USER_CERT`` and ``X509_USER_KEY`` environment variables \
        are set, load the corresponding certificate and key.
    #. If the ``X509_USER_PROXY`` environment variable is set, load the \
        corresponding proxy file.
    #. Look for a X.509 proxy from ligo-proxy-init in the default location \
        (``/tmp/x509up_u${UID}``).
    #. Look for a certificate and key file in ``$HOME/.globus/usercert.pem`` \
        and ``$HOME/.globus/userkey.pem``.
635 636 637 638 639 640 641
    #. Look for a username and password for the server in ``$HOME/.netrc``.
    #. Continue with no authentication credentials.

    Args:
        service_url (:obj:`str`, optional): URL of server API root.
        proxy_host (:obj:`str`, optional): proxy hostname.
        proxy_port (:obj:`int`, optional): proxy port.
Tanner Prestegard's avatar
Tanner Prestegard committed
642
        cred (:obj:`tuple` or :obj:`str`, optional): a tuple or list of
643 644 645 646 647 648 649 650 651 652 653 654 655 656
            (``/path/to/cert/file``, ``/path/to/key/file``) or a single path to
            a combined proxy file. Used for X.509 authentication only.
        username (:obj:`str`, optional): username for basic auth.
        password (:obj:`str`, optional): password for basic auth.
        force_noauth (:obj:`bool`, optional): set to True if you want to
            skip credential lookup and use this client without
            authenticating to the server.
        fail_if_noauth (:obj:`bool`, optional): set to ``True`` if you want the
            client constructor to fail if no authentication credentials are
            provided or found.
        api_version (:obj:`str`, optional): choose the version of the server
            API to use.  At present, there is only one version, but this
            argument is provided with the expectation that this will change
            in the future.
657 658 659 660 661 662 663 664 665 666
        reload_certificate (:obj:`bool`, optional): if True, your certificate
            will be checked before each request whether it is within
            ``reload_buffer`` seconds of expiration, and if so, it will be
            reloaded. Useful for processes which may live longer than the
            certificate lifetime and have an automated method for certificate
            renewal. The path to the new/renewed certificate **must** be the
            same as for the old certificate.
        reload_buffer (:obj:`int`, optional): buffer (in seconds) for reloading
            a certificate in advance of its expiration. Only used if
            ``reload_certificate`` is ``True``.
667 668 669 670 671 672 673 674 675 676

    Examples:
        Instantiate a client to use the production GraceDB server:

        >>> g = GraceDb()

        Use another GraceDB server:

        >>> g = GraceDb(service_url='https://gracedb-playground.ligo.org/api/')

677
        Use a certificate and key in the non-default location:
678

679
        >>> g = GraceDb(cred=('/path/to/cert/file', '/path/to/key/file'))
680
    """
681
    def __init__(self, service_url=DEFAULT_SERVICE_URL, proxy_host=None,
                 proxy_port=3128, cred=None, username=None, password=None,
                 force_noauth=False, fail_if_noauth=False, api_version=None,
                 reload_certificate=False, reload_buffer=300):
        """Create a client instance.

        Connection and authentication arguments are forwarded unchanged to
        the parent class constructor; see the class docstring for their
        meaning.

        Raises:
            TypeError: if ``api_version`` is given and is not a string.
        """
        super(GraceDb, self).__init__(
            service_url,
            proxy_host=proxy_host,
            proxy_port=proxy_port,
            cred=cred,
            username=username,
            password=password,
            force_noauth=force_noauth,
            fail_if_noauth=fail_if_noauth,
            reload_certificate=reload_certificate,
            reload_buffer=reload_buffer,
        )

        # Only None (no explicit version) or a string is accepted.
        version_ok = (api_version is None
                      or isinstance(api_version, six.string_types))
        if not version_ok:
            raise TypeError('api_version should be a string')

        # Populates self._service_url and self._versioned_service_url.
        self._set_service_url(service_url, api_version)
        self._api_version = api_version

        # Fetched lazily from the server by the ``service_info`` property the
        # first time an action needs it.
        self._service_info = None

710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732
    def _set_service_url(self, service_url, api_version):
        """Store the unversioned and versioned API root URLs.

        ``self._service_url`` becomes ``service_url`` with a trailing slash
        guaranteed. ``self._versioned_service_url`` is that same URL with
        ``api_version + '/'`` appended, unless ``api_version`` is falsy or
        the literal ``'default'`` (meaning "use the default, non-versioned
        API"), in which case the two attributes are equal.
        """
        suffix = '' if service_url.endswith('/') else '/'
        base_url = service_url + suffix
        self._service_url = base_url

        wants_default_api = (not api_version) or api_version == 'default'
        if wants_default_api:
            self._versioned_service_url = base_url
        else:
            self._versioned_service_url = base_url + api_version + '/'

    @property
    def service_url(self):
        """Deprecated alias for ``_service_url`` (the unversioned API root).

        Will be removed in the future. The deprecation notice is emitted via
        :mod:`warnings` rather than printed to stdout, so it can be filtered
        or escalated by the caller. It also no longer corrupts programs that
        write machine-readable output to stdout.
        """
        import warnings
        warnings.warn(
            "DEPRECATED: this attribute has been moved to '_service_url'",
            DeprecationWarning, stacklevel=2)
        return self._service_url

733 734
    @property
    def service_info(self):
        """Gets the root API information, fetching it on first access.

        The JSON document from the (possibly versioned) API root is cached
        in ``self._service_info``. The metadata properties on this class
        (``groups``, ``pipelines``, ...) read their values from it.

        Raises:
            ValueError: if a specific API version was requested and the
                versioned API root returned 404 (version not available).
            ligo.gracedb.exceptions.HTTPError: for any other failed request,
                or a non-200 response from the API root.
        """
        if self._service_info is None:
            # try-except block takes user-specified API version to use and
            # checks whether that version is available on the server
            try:
                r = self.request("GET", self._versioned_service_url)
            except HTTPError as e:
                if e.status != 404:
                    # Not a 404 error, must be something else
                    raise
                if self._versioned_service_url == self._service_url:
                    # No API version was requested, so the 404 cannot be
                    # a bad-version problem; re-raise it instead of
                    # repeating the identical request.
                    raise
                # The versioned URL was not found; we assume the user asked
                # for an API version the server does not offer. Get the
                # available versions from the unversioned API root. The
                # ``api_versions`` property reads them under 'api-versions';
                # the older 'API_VERSIONS' key is kept as a fallback.
                root_info = self.request("GET", self._service_url).json()
                available_api_versions = (root_info.get('api-versions') or
                                          root_info.get('API_VERSIONS'))
                if available_api_versions:
                    err_msg = ('Bad API version. Available versions for '
                               'this server are: {0}').format(
                        available_api_versions)
                else:
                    # Case where server doesn't have versions, for some
                    # reason.
                    err_msg = ('This server does not have a versioned '
                               'API. Reinstantiate your client without a '
                               'version.')
                raise ValueError(err_msg)
            else:
                if r.status != 200:
                    raise HTTPError(r.status, r.reason, r.read())
            self._service_info = r.json()
        return self._service_info

771 772
    @property
    def api_versions(self):
        """API versions available on the server, or ``None`` if unreported."""
        info = self.service_info
        return info.get('api-versions')

    @property
    def server_version(self):
        """Code version running on the GraceDB server, as it reports it."""
        info = self.service_info
        return info.get('server-version')
780

781 782 783 784 785 786 787 788 789 790
    @property
    def links(self):
        """API resource links from the root API information.

        Looked up by resource name, e.g. ``self.links['events']`` gives the
        event-creation endpoint used by ``createEvent``.
        """
        info = self.service_info
        return info.get('links')

    @property
    def templates(self):
        """The ``'templates'`` entry of the server's root API information.

        NOTE(review): presumably URL templates for API resources; confirm
        against the server's root API response.
        """
        info = self.service_info
        return info.get('templates')

    @property
    def groups(self):
        """Analysis groups available on the server."""
        info = self.service_info
        return info.get('groups')
Tanner Prestegard's avatar
Tanner Prestegard committed
793

794 795
    @property
    def pipelines(self):
        """Analysis pipelines available on the server."""
        info = self.service_info
        return info.get('pipelines')

    @property
    def searches(self):
        """Search types available on the server."""
        info = self.service_info
        return info.get('searches')
803

804 805 806 807
    # Would ideally be named 'labels' like the other properties, but that
    # name is already taken by a method.
    @property
    def allowed_labels(self):
        """Labels that exist on the server."""
        info = self.service_info
        return info.get('labels')

811
    @property
    def em_groups(self):
        """EM groups available on the server."""
        info = self.service_info
        return info.get('em-groups')
815 816 817 818 819 820 821 822 823 824 825 826 827

    @property
    def wavebands(self):
        """The ``'wavebands'`` entry of the server's root API information."""
        info = self.service_info
        return info.get('wavebands')

    @property
    def eel_statuses(self):
        """The ``'eel-statuses'`` entry of the server's root API information."""
        info = self.service_info
        return info.get('eel-statuses')

    @property
    def obs_statuses(self):
        """The ``'obs-statuses'`` entry of the server's root API information."""
        info = self.service_info
        return info.get('obs-statuses')

828 829
    @property
    def voevent_types(self):
        """VOEvent types available on the server."""
        info = self.service_info
        return info.get('voevent-types')

833 834
    @property
    def superevent_categories(self):
        """Superevent categories available on the server."""
        info = self.service_info
        return info.get('superevent-categories')

838 839
    @property
    def instruments(self):
        """Instruments available on the server."""
        info = self.service_info
        return info.get('instruments')

    @property
    def signoff_types(self):
        """Signoff types available on the server."""
        info = self.service_info
        return info.get('signoff-types')

    @property
    def signoff_statuses(self):
        """Signoff statuses available on the server."""
        info = self.service_info
        return info.get('signoff-statuses')

853 854 855
    def request(self, method, url, body=None, headers=None, priming_url=None):
        """Send a request via the parent class.

        For ``POST``/``PUT`` requests (case-insensitive) on a client using
        X.509 authentication, ``priming_url`` is replaced by the unversioned
        API root before delegating. NOTE(review): presumably the parent
        contacts ``priming_url`` first to set up the connection; confirm in
        the parent ``request`` implementation.
        """
        is_write = method.upper() in ('POST', 'PUT')
        if is_write and self.auth_type == 'x509':
            priming_url = self._service_url
        parent = super(GraceDb, self)
        return parent.request(method, url, body, headers, priming_url)
858

859
    def _getCode(self, input_value, code_dict):
        """Return the code in ``code_dict`` that ``input_value`` refers to.

        ``code_dict`` maps ``{code: descriptive_name}``. A value that is
        already a key is returned as-is. Otherwise it is compared
        case-insensitively with each code and then each descriptive name.

        Returns:
            The matching code, or ``None`` if there is no match.
        """
        if input_value in code_dict:
            return input_value

        wanted = input_value.lower()
        matches = (
            code for code, display in six.iteritems(code_dict)
            # Short-circuit keeps display.lower() unevaluated on a code hit.
            if wanted == code.lower() or wanted == display.lower()
        )
        return next(matches, None)
878

879
    # Search and filecontents are optional when creating an event.
    def createEvent(self, group, pipeline, filename, search=None, labels=None,
                    offline=False, filecontents=None, **kwargs):
        """Create a new event on the server.

        All LIGO-Virgo users can create events in the 'Test' group. Special
        permissions are required to create non-test events.

        Args:
            group (str): name of the analysis group which identified the
                candidate.
            pipeline (str): name of the analysis pipeline which identified the
                candidate.
            filename (str): path to event file to be uploaded. Use ``'-'`` to
                read from stdin.
            search (:obj:`str`, optional): type of search being run by the
                analysis pipeline.
            labels (:obj:`str` or :obj:`list[str]`, optional): label(s) to
                attach to the event upon creation. Should be a string (single
                label) or a list/tuple of strings (multiple labels).
            offline (:obj:`bool`, optional): if ``True``, indicates that the
                event was found by an offline analysis.
            filecontents(:obj:`str`, optional): string to be uploaded to the
                server and saved into a file. If event data is uploaded via
                this mechanism, the ``filename`` argument is only used to
                set the name of the file once it is saved on the server.
            **kwargs: additional form fields sent with the request.

        Returns:
            :class:`httplib.HTTPResponse`

        Raises:
            TypeError: if ``labels`` is not a string, list or tuple.
            NameError: if a label does not exist on the server.
            ValueError: if the group, pipeline or search is unknown to the
                server, or ``offline`` is not a bool. (A subclass of
                ``Exception``, which was raised here previously.)
            ligo.gracedb.exceptions.HTTPError: if the response has a status
                code >= 400.

        Example:
            >>> g = GraceDb()
            >>> r = g.createEvent('CBC', 'gstlal', '/path/to/something.xml',
            ... labels='INJ', search='LowMass')
            >>> r.status
            201
        """
        errors = []
        if group not in self.groups:
            errors.append("bad group")
        if pipeline not in self.pipelines:
            errors.append("bad pipeline")
        if search and search not in self.searches:
            errors.append("bad search")
        # Process offline arg
        if not isinstance(offline, bool):
            errors.append("offline should be True or False")

        # Process label args - normalize non-empty input to a list so the
        # validation and form-building below see one consistent shape.
        if labels:
            if isinstance(labels, six.string_types):
                labels = [labels]
            elif isinstance(labels, (list, tuple)):
                labels = list(labels)
            else:
                # Raise immediately: iterating a labels arg of the wrong type
                # below would fail before the accumulated errors are reported.
                raise TypeError("labels arg is {0}, should be str or list"
                                .format(type(labels)))
            # Check labels against those in database
            allowed_labels = self.allowed_labels
            for label in labels:
                if label not in allowed_labels:
                    raise NameError(("Label '{0}' does not exist in the "
                                     "database").format(label))
        if errors:
            # ValueError subclasses Exception, so callers that catch
            # Exception keep working.
            raise ValueError(str(errors))

        if filecontents is None:
            if filename == '-':
                filename = 'initial.data'
                filecontents = sys.stdin.read()
            else:
                with open(filename, 'rb') as fh:
                    filecontents = fh.read()

        fields = [
            ('group', group),
            ('pipeline', pipeline),
            ('offline', offline),
        ]
        if search:
            fields.append(('search', search))
        if labels:
            # One 'labels' form field per label.
            fields.extend(('labels', label) for label in labels)

        # Update fields with additional keyword arguments
        fields.extend(six.iteritems(kwargs))

        files = [('eventFile', filename, filecontents)]
        # Python httplib bug?  unicode link
        uri = str(self.links['events'])
        return self.post(uri, fields, files=files)

    def replaceEvent(self, graceid, filename, filecontents=None):
981
        """Replace an existing event by uploading a new event file.
982

983 984
        The event's parameters are updated from the new file. Only the user
        who originally created the event can update it.
985

986 987 988 989 990 991 992
        Args:
            graceid (str): GraceDB ID of the existing event
            filename (str): path to new event file
            filecontents(:obj:`str`, optional): string to be uploaded to the
                server and saved into a file. If event data is uploaded via
                this mechanism, the ``filename`` argument is only used to
                set the name of the file once it is saved on the server.
Tanner Prestegard's avatar
Tanner Prestegard committed
993

994 995 996 997 998 999
        Returns:
            :class:`httplib.HTTPResponse`

        Raises:
            ligo.gracedb.exceptions.HTTPError: if the response has a status
                code >= 400.
1000

1001
        Example:
1002
            >>> g = GraceDb()
Tanner Prestegard's avatar
Tanner Prestegard committed
1003
            >>> r = g.replaceEvent('T101383', '/path/to/new/something.xml')
1004
        """
1005
        if filecontents is None:
1006 1007
            # Note: not allowing filename '-' here.  We want the event datafile
            # to be versioned.