Gitlab will migrate to a new storage backend starting 0300 UTC on 2020-04-04. We do not anticipate a maintenance window for this migration. Performance may be impacted over the weekend. Thanks for your patience.

Commit c414cad5 authored by Leo Pound Singer's avatar Leo Pound Singer

Refactor sqlite opener shims to eliminate matplotlib import

Fixes #27.
parent fd68d3e8
......@@ -25,6 +25,7 @@ AC_CONFIG_FILES([ \
python/lalinference/popprior/Makefile \
python/lalinference/rapid_pe/Makefile \
python/lalinference/tiger/Makefile \
python/lalinference/util/Makefile \
src/LALInferenceVCSInfo.c \
src/LALInferenceVCSInfo.h \
src/Makefile \
......
......@@ -72,6 +72,7 @@ import sqlite3
import numpy as np
from lalinference.io import fits
from lalinference.bayestar.postprocess import find_injection_moc
from lalinference.util import sqlite
def startup(dbfilename, opts_contour, opts_modes, opts_area):
......@@ -108,7 +109,7 @@ def process(fitsfilename):
if row is None:
raise ValueError(
"No database record found for event '{0}' in '{1}'".format(
coinc_event_id, command.sqlite_get_filename(db)))
coinc_event_id, sqlite.get_filename(db)))
simulation_id, true_ra, true_dec, true_dist, far, snr = row
searched_area, searched_prob, offset, searched_modes, contour_areas, \
area_probs, contour_modes, searched_prob_dist, contour_dists, \
......@@ -153,7 +154,7 @@ if __name__ == '__main__':
from multiprocessing import Pool
map = Pool(
opts.jobs, startup,
(command.sqlite_get_filename(db), contours, modes, areas)
(sqlite.get_filename(db), contours, modes, areas)
).imap
colnames = (
......
......@@ -12,6 +12,7 @@ SUBDIRS = \
popprior \
rapid_pe \
tiger \
util \
$(END_OF_LIST)
vcs_info_sources = git_version.py
......
......@@ -31,12 +31,12 @@ import itertools
import logging
import os
import shutil
import sqlite3
import sys
import tempfile
import matplotlib
from matplotlib import cm
from ..plot import cmap
from ..util import sqlite
# Set no-op Matplotlib backend to defer importing anything that requires a GUI
......@@ -410,25 +410,6 @@ class DirType(object):
return string
def sqlite_open_a(string):
return sqlite3.connect(string)
def sqlite_open_r(string):
if (sys.version_info.major, sys.version_info.minor) >= (3, 4):
return sqlite3.connect('file:{}?mode=ro'.format(string), uri=True)
else: # FIXME: remove this code path when we drop Python < 3.4
fd = os.open(string, os.O_RDONLY)
try:
return sqlite3.connect('/dev/fd/{}'.format(fd))
finally:
os.close(fd)
def sqlite_open_w(string):
with open(string, 'wb') as f:
pass
return sqlite3.connect(string)
class SQLiteType(argparse.FileType):
"""Open an SQLite database, or fail if it does not exist.
FIXME: use SQLite URI when we drop support for Python < 3.4.
......@@ -446,6 +427,7 @@ class SQLiteType(argparse.FileType):
argparse.ArgumentTypeError: ...
If the file already exists, then it's fine:
>>> import sqlite3
>>> filetype = SQLiteType('r')
>>> with tempfile.NamedTemporaryFile() as f:
... with sqlite3.connect(f.name) as db:
......@@ -498,26 +480,10 @@ class SQLiteType(argparse.FileType):
self.mode = mode
def __call__(self, string):
if string in {'-', '/dev/stdin', '/dev/stdout'}:
raise argparse.ArgumentTypeError(
'Cannot open stdin/stdout as an SQLite database')
openers = {'a': sqlite_open_a, 'r': sqlite_open_r, 'w': sqlite_open_w}
opener = openers[self.mode]
try:
return opener(string)
except (OSError, sqlite3.Error) as e:
raise argparse.ArgumentTypeError(
'Failed to open database {}: {}'.format(string, e))
def sqlite_get_filename(connection):
"""Get the name of the file associated with an SQLite connection"""
result = connection.execute('pragma database_list').fetchall()
try:
(_, _, filename), = result
except ValueError:
raise RuntimeError('Expected exactly one attached database')
return filename
return sqlite.open(string, self.mode)
except OSError as e:
raise argparse.ArgumentTypeError(e)
def rename(src, dst):
......
......@@ -23,7 +23,7 @@ import sys
from glue.ligolw import dbtables
from ...bayestar.command import sqlite_open_r, sqlite_get_filename
from ...util import sqlite
from .ligolw import LigoLWEventSource
__all__ = ('SQLiteEventSource',)
......@@ -34,14 +34,14 @@ class SQLiteEventSource(LigoLWEventSource):
def __init__(self, f, *args, **kwargs):
if isinstance(f, sqlite3.Connection):
db = f
filename = sqlite_get_filename(f)
filename = sqlite.get_filename(f)
else:
if hasattr(f, 'read'):
filename = f.name
f.close()
else:
filename = f
db = sqlite_open_r(filename)
db = sqlite.open(filename, 'r')
super(SQLiteEventSource, self).__init__(dbtables.get_xml(db),
*args, **kwargs)
self._fallbackpath = os.path.dirname(filename) if filename else None
......
BUILT_SOURCES =
MOSTLYCLEANFILES =
EXTRA_DIST =
include $(top_srcdir)/gnuscripts/lalsuite_python.am
if HAVE_PYTHON
pymoduledir = $(pkgpythondir)/util
pymodule_PYTHON = \
__init__.py \
sqlite.py \
$(END_OF_LIST)
endif
#
# Copyright (C) 2018 Leo Singer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
"""Tools for reading and writing SQLite databases"""
import os
import sqlite3
import sys
_open = open
def _open_a(string):
return sqlite3.connect(string)
def _open_r(string):
if (sys.version_info.major, sys.version_info.minor) >= (3, 4):
return sqlite3.connect('file:{}?mode=ro'.format(string), uri=True)
else: # FIXME: remove this code path when we drop Python < 3.4
fd = os.open(string, os.O_RDONLY)
try:
return sqlite3.connect('/dev/fd/{}'.format(fd))
finally:
os.close(fd)
def _open_w(string):
with _open(string, 'wb') as f:
pass
return sqlite3.connect(string)
_openers = {'a': _open_a, 'r': _open_r, 'w': _open_w}
def open(string, mode):
if string in {'-', '/dev/stdin', '/dev/stdout'}:
raise ValueError('Cannot open stdin/stdout as an SQLite database')
try:
opener = _openers[mode]
except KeyError:
raise ValueError('Invalid mode "{}". Must be one of "{}".'.format(
mode, ''.join(_openers.keys())))
try:
return opener(string)
except (OSError, sqlite3.Error) as e:
raise OSError('Failed to open database {}: {}'.format(string, e))
def get_filename(connection):
"""Get the name of the file associated with an SQLite connection"""
result = connection.execute('pragma database_list').fetchall()
try:
(_, _, filename), = result
except ValueError:
raise RuntimeError('Expected exactly one attached database')
return filename
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment