# Copyright (C) 2007-2018  Kipp Cannon
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#


"""
This module provides an implementation of the Table element that uses a
database engine for storage.  On top of that it then re-implements a number
of the tables from the lsctables module to provide versions of their
methods that work against the SQL database.
"""


import errno
import itertools
import operator
import os
import re
import shutil
import signal
import sys
import tempfile
import threading
import time
import warnings
from xml.sax.xmlreader import AttributesImpl


from . import __author__, __date__, __version__
from . import ligolw
from . import table
from . import lsctables
from . import types as ligolwtypes
from . import utils as ligolw_utils


#
# =============================================================================
#
#                                  Connection
#
# =============================================================================
#


def connection_db_type(connection):
	"""
	A totally broken attempt to determine what type of database a
	connection object is attached to.  Don't use this.

	The input is a DB API 2.0 compliant connection object, the return
	value is one of the strings "sqlite" or "mysql".  Raises TypeError
	when the database type cannot be determined.
	"""
	if "sqlite" in repr(connection):
		return "sqlite"
	if "mysql" in repr(connection):
		return "mysql"
	raise TypeError(connection)


#
# work with database file in scratch space
#


class workingcopy(object):
	"""
	Manage a working copy of an sqlite database file.  This is used
	when a large enough number of manipulations are being performed on
	a database file that the total network I/O would be higher than
	that of copying the entire file to a local disk, doing the
	manipulations locally, then copying the file back.  It is also
	useful in unburdening a file server when large numbers of read-only
	operations are being performed on the same file by many different
	machines.
	"""

	def __init__(self, filename, tmp_path = None, replace_file = False, discard = False, verbose = False):
		"""
		filename:  the name of the sqlite database file.

		tmp_path:  the directory to use for the working copy.  If
		None (the default), the system's default location for
		temporary files is used.  If set to the special value
		"_CONDOR_SCRATCH_DIR" then the value of the environment
		variable of that name will be used (to use a directory
		literally named _CONDOR_SCRATCH_DIR set tmp_path to
		"./_CONDOR_SCRATCH_DIR").

		replace_file:  if True, filename is truncated in place
		before manipulation;  if False (the default), the file is
		not modified before use.  This is used when the original
		file is being overwritten with the working copy, and it is
		necessary to ensure that a malfunction or crash (which
		might prevent the working copy from overwriting the
		original) does not leave behind the unmodified original,
		which could subsequently be mistaken for valid output.

		discard:  if True the working copy is simply deleted
		instead of being copied back to the original location;  if
		False (the default) the working copy overwrites the
		original.  This is used to improve read-only operations,
		when it is not necessary to pay the I/O cost of moving an
		unmodified file a second time.  The .discard attribute can
		be set at any time while the context manager is in use,
		before the .__exit__() method is invoked.

		verbose:  print messages to stderr.

		NOTES:

		- When replace_file mode is enabled, any failures that
		  prevent the original file from being truncated are
		  ignored.  The inability to truncate the file is
		  considered non-fatal.

		- If the operation to copy the file to the working path
		  fails then a working copy is not used, the original file
		  is used in place.  If the failure that prevents copying
		  the file to the working path is potentially transient,
		  for example "permission denied" or "no space on device",
		  the code sleeps for a brief period of time and then tries
		  again.  Only after the potentially transient failure
		  persists for several attempts is the working copy
		  abandoned and the original copy used instead.

		- When the working copy is moved back to the original
		  location, if a file with the same name but ending in
		  -journal is present in the working directory then it is
		  deleted.

		- The name of the working copy can be obtained by
		  converting the workingcopy object to a string.
		"""
		self.filename = filename
		self.tmp_path = tmp_path if tmp_path != "_CONDOR_SCRATCH_DIR" else os.getenv("_CONDOR_SCRATCH_DIR")
		self.replace_file = replace_file
		self.discard = discard
		self.verbose = verbose


	@staticmethod
	def truncate(filename, verbose = False):
		"""
		Truncate a file to 0 size, ignoring all errors.  This is
		used internally to implement the "replace_file" feature.
		"""
		if verbose:
			sys.stderr.write("'%s' exists, truncating ... " % filename)
		try:
			fd = os.open(filename, os.O_WRONLY | os.O_TRUNC)
		except Exception as e:
			if verbose:
				sys.stderr.write("cannot truncate '%s': %s\n" % (filename, str(e)))
			return
		os.close(fd)
		if verbose:
			sys.stderr.write("done.\n")


	@staticmethod
	def cpy(srcname, dstname, attempts = 5, verbose = False):
		"""
		Copy a file to a destination preserving permission if
		possible.  If the operation fails for a non-fatal reason
		then several attempts are made with a pause between each.
		The return value is dstname if the operation was successful
		or srcname if a non-fatal failure caused the operation to
		terminate.  Fatal failures raise an exception.
		"""
		if verbose:
			sys.stderr.write("copying '%s' to '%s' ... " % (srcname, dstname))
		for i in itertools.count(1):
			try:
				shutil.copy2(srcname, dstname)
				# if we get here it worked
				break
			except IOError as e:
				# anything other than permission-denied
				# or out-of-space is a real error
				if e.errno not in (errno.EPERM, errno.ENOSPC):
					raise
				if verbose:
					sys.stderr.write("warning: attempt %d: %s: \r" % (i, errno.errorcode[e.errno]))
				# if we've run out of attempts, fall back
				# to the original file
				if i >= attempts:
					if verbose:
						sys.stderr.write("working with original file '%s'\n" % srcname)
					return srcname
				# otherwise sleep and try again
				if verbose:
					sys.stderr.write("sleeping and trying again ...\n")
				time.sleep(10)
		if verbose:
			sys.stderr.write("done.\n")
		try:
			# try to preserve permission bits.  according to
			# the documentation, copy() and copy2() are
			# supposed to preserve them but don't.  maybe they
			# don't preserve them if the destination file
			# already exists?
			shutil.copystat(srcname, dstname)
		except Exception as e:
			if verbose:
				sys.stderr.write("warning: ignoring failure to copy permission bits from '%s' to '%s': %s\n" % (srcname, dstname, str(e)))
		return dstname


	def __enter__(self):
		database_exists = os.access(self.filename, os.F_OK)

		if self.tmp_path is not None:
			# create the temporary file and retain a reference
			# to prevent its removal.  for suffix, can't use
			# splitext() because it only keeps the last bit,
			# e.g. won't give ".xml.gz" but just ".gz"

			self.temporary_file = tempfile.NamedTemporaryFile(suffix = ".".join(os.path.split(self.filename)[-1].split(".")[1:]), dir = self.tmp_path)
			self.target = self.temporary_file.name
			if self.verbose:
				sys.stderr.write("using '%s' as workspace\n" % self.target)

			# mkstemp() ignores umask, creates all files accessible
			# only by owner;  we should respect umask.  note that
			# os.umask() sets it, too, so we have to set it back after
			# we know what it is

			umsk = os.umask(0o777)
			os.umask(umsk)
			os.chmod(self.target, 0o666 & ~umsk)

			if database_exists:
				# if the file is being replaced then
				# truncate the database so that if this job
				# fails the user won't think the database
				# file is valid, otherwise copy the
				# existing database to the work space for
				# modification
				if self.replace_file:
					self.truncate(self.filename, verbose = self.verbose)
				elif self.cpy(self.filename, self.target, verbose = self.verbose) == self.filename:
					# non-fatal errors have caused us
					# to fall-back to the file in its
					# original location
					self.target = self.filename
					del self.temporary_file
		else:
			self.target = self.filename
			if database_exists and self.replace_file:
				self.truncate(self.target, verbose = self.verbose)

		return self


	def __str__(self):
		return self.target


	def __exit__(self, exc_type, exc_val, exc_tb):
		"""
		Restore the working copy to its original location if the
		two are different.

		During the move operation, this function traps the signals
		used by Condor to evict jobs.  This reduces the risk of
		corrupting a document by the job terminating part-way
		through the restoration of the file to its original
		location.  When the move operation is concluded, the
		original signal handlers are restored and if any signals
		were trapped they are resent to the current process in
		order.  Typically this will result in the signal handlers
		installed by the install_signal_trap() function being
		invoked, meaning any other scratch files that might be in
		use get deleted and the current process is terminated.
		"""
		# when removed, must also delete a -journal partner, ignore
		# all errors

		try:
			os.unlink("%s-journal" % self)
		except:
			pass

		# restore the file to its original location

		if self.target != self.filename:
			with ligolw_utils.SignalsTrap():
				if not self.discard:
					# move back to original location

					if self.verbose:
						sys.stderr.write("moving '%s' to '%s' ... " % (self.target, self.filename))
					shutil.move(self.target, self.filename)
					if self.verbose:
						sys.stderr.write("done.\n")

					# next we will trigger the
					# temporary file removal.  because
					# we've just deleted that file,
					# this will produce an annoying but
					# harmless message about an ignored
					# OSError.  so to silence the
					# warning we create a dummy file
					# for the TemporaryFile to delete.
					# ignore any errors that occur when
					# trying to make the dummy file.
					# FIXME:  this is stupid, find a
					# better way to shut TemporaryFile
					# up

					try:
						open(self.target, "w").close()
					except:
						pass

				# remove reference to
				# tempfile.TemporaryFile object.  this
				# triggers the removal of the file.

				del self.temporary_file

		# if an exception terminated the code block, re-raise the
		# exception

		return False


	def set_temp_store_directory(self, connection, verbose = False):
		"""
		Sets the temp_store_directory parameter in sqlite.
		"""
		if verbose:
			sys.stderr.write("setting the temp_store_directory to %s ... " % self.tmp_path)
		cursor = connection.cursor()
		cursor.execute("PRAGMA temp_store_directory = '%s'" % self.tmp_path)
		cursor.close()
		if verbose:
			sys.stderr.write("done\n")


#
# backwards compatibility for old code.  FIXME:  delete in next release
#


def get_connection_filename(*args, **kwargs):
	return workingcopy(*args, **kwargs).__enter__()


def put_connection_filename(ignored, target, verbose = False):
	target.verbose = verbose
	target.__exit__(None, None, None)


def discard_connection_filename(ignored, target, verbose = False):
	target.discard = True
	target.verbose = verbose
	target.__exit__(None, None, None)


def set_temp_store_directory(connection, temp_store_directory, verbose = False):
	if temp_store_directory == "_CONDOR_SCRATCH_DIR":
		temp_store_directory = os.getenv("_CONDOR_SCRATCH_DIR")
	if verbose:
		sys.stderr.write("setting the temp_store_directory to %s ... " % temp_store_directory)
	cursor = connection.cursor()
	cursor.execute("PRAGMA temp_store_directory = '%s'" % temp_store_directory)
	cursor.close()
	if verbose:
		sys.stderr.write("done\n")


#
# =============================================================================
#
#                                  ID Mapping
#
# =============================================================================
#


def idmap_create(connection):
	"""
	Create the _idmap_ table.  This table has columns "table_name",
	"old", and "new" mapping old IDs to new IDs for each table.  The
	(table_name, old) column pair is a primary key (is indexed and must
	contain unique entries).  The table is created as a temporary
	table, so it will be automatically dropped when the database
	connection is closed.

	This function is for internal use;  it forms part of the code used
	to re-map row IDs when merging multiple documents.
	"""
	connection.cursor().execute("CREATE TEMPORARY TABLE _idmap_ (table_name TEXT NOT NULL, old INTEGER NOT NULL, new INTEGER NOT NULL, PRIMARY KEY (table_name, old))")


def idmap_reset(connection):
	"""
	Erase the contents of the _idmap_ table, but leave the table in
	place.

	This function is for internal use;  it forms part of the code used
	to re-map row IDs when merging multiple documents.
	"""
	connection.cursor().execute("DELETE FROM _idmap_")


def idmap_sync(connection):
	"""
	Iterate over the tables in the database, ensure that there exists a
	custom DBTable class for each, and synchronize that table's ID
	generator to the ID values in the database.
	"""
	xmldoc = get_xml(connection)
	for tbl in xmldoc.getElementsByTagName(DBTable.tagName):
		tbl.sync_next_id()
	xmldoc.unlink()


def idmap_get_new(cursor, table_name, old, tbl):
	"""
	From the old ID string, obtain a replacement ID string by either
	grabbing it from the _idmap_ table if one has already been assigned
	to the old ID, or by using the current value of the Table
	instance's next_id class attribute.  In the latter case, the new ID
	is recorded in the _idmap_ table, and the class attribute
	incremented by 1.

	This function is for internal use;  it forms part of the code used
	to re-map row IDs when merging multiple documents.
	"""
	cursor.execute("SELECT new FROM _idmap_ WHERE table_name == ? AND old == ?", (table_name, old))
	new = cursor.fetchone()
	if new is not None:
		# a new ID has already been created for this old ID
		return new[0]
	# this ID was not found in the _idmap_ table, assign a new ID and
	# record it
	new = tbl.get_next_id()
	cursor.execute("INSERT INTO _idmap_ VALUES (?, ?, ?)", (table_name, old, new))
	return new
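

#
# How the pieces above fit together when re-mapping IDs during a document
# merge (a sketch of the intended calling sequence, not a complete
# program):
#
#	idmap_create(connection)	# create the temporary _idmap_ table
#	idmap_sync(connection)		# sync each table's next_id generator
#	idmap_reset(connection)		# erase the map before each document
#	# while rows are being appended, DBTable._remapping_append() calls
#	# idmap_get_new() to assign replacement IDs, and afterwards each
#	# table's .applyKeyMapping() method rewrites cross-references to
#	# match.
#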


#
# =============================================================================
#
#                             Database Information
#
# =============================================================================
#


#
# SQL parsing
#

_sql_create_table_pattern = re.compile(r"CREATE\s+TABLE\s+(?P<name>\w+)\s*\((?P<coldefs>.*)\)", re.IGNORECASE)
_sql_coldef_pattern = re.compile(r"\s*(?P<name>\w+)\s+(?P<type>\w+)[^,]*")


#
# Database info extraction utils
#


def get_table_names(connection):
	"""
	Return a list of the table names in the database.
	"""
	cursor = connection.cursor()
	cursor.execute("SELECT name FROM sqlite_master WHERE type == 'table'")
	return [name for (name,) in cursor]


def get_column_info(connection, table_name):
	"""
	Return an in-order list of (name, type) tuples describing the
	columns in the given table.
	"""
	cursor = connection.cursor()
	cursor.execute("SELECT sql FROM sqlite_master WHERE type == 'table' AND name == ?", (table_name,))
	statement, = cursor.fetchone()
	coldefs = re.match(_sql_create_table_pattern, statement).groupdict()["coldefs"]
	return [(coldef.groupdict()["name"], coldef.groupdict()["type"]) for coldef in re.finditer(_sql_coldef_pattern, coldefs) if coldef.groupdict()["name"].upper() not in ("PRIMARY", "UNIQUE", "CHECK")]
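

#
# Example (a sketch;  assumes the database contains a table created with
# "CREATE TABLE process (program TEXT, process_id INTEGER)"):
#
#	>>> get_column_info(connection, "process")
#	[('program', 'TEXT'), ('process_id', 'INTEGER')]
#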


def get_xml(connection, table_names = None):
	"""
	Construct an XML document tree wrapping around the contents of the
	database.  On success the return value is a ligolw.LIGO_LW element
	containing the tables as children.  Arguments are a connection to
	a database, and an optional list of table names to dump.  If
	table_names is not provided the set is obtained from
	get_table_names().
	"""
	ligo_lw = ligolw.LIGO_LW()

	if table_names is None:
		table_names = get_table_names(connection)

	for table_name in table_names:
		# build the table document tree.  copied from
		# lsctables.New()
		try:
			cls = TableByName[table_name]
		except KeyError:
			cls = DBTable
		table_elem = cls(AttributesImpl({u"Name": u"%s:table" % table_name}), connection = connection)
		destrip = {}
		if table_elem.validcolumns is not None:
			for name in table_elem.validcolumns:
				destrip[table.Column.ColumnName(name)] = name
		for column_name, column_type in get_column_info(connection, table_elem.Name):
			if table_elem.validcolumns is not None:
				try:
					column_name = destrip[column_name]
				except KeyError:
					raise ValueError("invalid column '%s' in table '%s'" % (column_name, table_name))
				# use the pre-defined column type
				column_type = table_elem.validcolumns[column_name]
			else:
				# guess the column type
				column_type = ligolwtypes.FromSQLiteType[column_type]
			table_elem.appendChild(table.Column(AttributesImpl({u"Name": column_name, u"Type": column_type})))
		table_elem._end_of_columns()
		table_elem.appendChild(table.TableStream(AttributesImpl({u"Name": u"%s:table" % table_name, u"Delimiter": table.TableStream.Delimiter.default, u"Type": table.TableStream.Type.default})))
		ligo_lw.appendChild(table_elem)
	return ligo_lw
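

#
# Example (a sketch;  "database.sqlite" is a hypothetical file name):
#
#	import sqlite3
#	connection = sqlite3.connect("database.sqlite")
#	xmldoc = get_xml(connection)
#	# ... process the document tree ...
#	xmldoc.unlink()
#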


#
# =============================================================================
#
#                            DBTable Element Class
#
# =============================================================================
#


# FIXME:  is this needed?
class DBTableStream(table.TableStream):
	def endElement(self):
		super(DBTableStream, self).endElement()
		if hasattr(self.parentNode, "connection"):
			self.parentNode.connection.commit()


class DBTable(table.Table):
	"""
	A version of the Table class using an SQL database for storage.
	Many of the features of the Table class are not available here, but
	instead the user can use SQL to query the table's contents.

	The constraints attribute can be set to a text string that will be
	added to the table's CREATE statement where constraints go, for
	example you might wish to set this to "PRIMARY KEY (event_id)" for
	a table with an event_id column.

	Note:  because the table is stored in an SQL database, the use of
	this class imposes the restriction that table names be unique
	within a document.

	Also note that at the present time there is really only proper
	support for the pre-defined tables in the lsctables module.  It is
	possible to load unrecognized tables into a database from LIGO
	Light Weight XML files, but without developer intervention there is
	no way to indicate the constraints that should be imposed on the
	columns, for example which columns should be used as primary keys
	and so on.  This can result in poor query performance.  It is also
	possible to extract a database's contents to a LIGO Light Weight
	XML file even when the database contains unrecognized tables, but
	without developer intervention the column types will be guessed
	using a generic mapping of SQL types to LIGO Light Weight types.

	Each instance of this class must be connected to a database.  The
	(Python DBAPI 2.0 compatible) connection object is passed to the
	class via the connection parameter at instance creation time.

	Example:

	>>> import sqlite3
	>>> connection = sqlite3.connect(":memory:")
	>>> tbl = dbtables.DBTable(AttributesImpl({u"Name": u"process:table"}), connection = connection)

	A custom content handler must be created in order to pass the
	connection keyword argument to the DBTable class when instances are
	created, since the default content handler does not do this.  See
	the use_in() function defined in this module for information on how
	to create such a content handler.

	If a custom ligo.lw.Table subclass is defined in ligo.lw.lsctables
	whose name matches the name of the DBTable being constructed, the
	lsctables class is added to the list of parent classes.  This
	allows the lsctables class' methods to be used with the DBTable
	instances but not all of the methods will necessarily work with the
	database-backed version of the class.  Your mileage may vary.
	"""
	def __new__(cls, *args, **kwargs):
		# does this class already have table-specific metadata?
		if not hasattr(cls, "tableName"):
			# no, try to retrieve it from lsctables
			attrs, = args
			name = table.Table.TableName(attrs[u"Name"])
			if name in lsctables.TableByName:
				# found metadata in lsctables, construct
				# custom subclass.  the class from
				# lsctables is added as a parent class to
				# allow methods from that class to be used
				# with this class, however there is no
				# guarantee that all parent class methods
				# will be appropriate for use with the
				# DB-backend object.
				lsccls = lsctables.TableByName[name]
				class CustomDBTable(cls, lsccls):
					tableName = lsccls.tableName
					validcolumns = lsccls.validcolumns
					loadcolumns = lsccls.loadcolumns
					constraints = lsccls.constraints
					next_id = lsccls.next_id
					RowType = lsccls.RowType
					how_to_index = lsccls.how_to_index

				# save for re-use (required for ID
				# remapping across multiple documents in
				# ligolw_sqlite)
				TableByName[name] = CustomDBTable

				# replace input argument with new class
				cls = CustomDBTable
		return table.Table.__new__(cls, *args)

	def __init__(self, *args, **kwargs):
		# chain to parent class
		table.Table.__init__(self, *args)

		# retrieve connection object from kwargs
		self.connection = kwargs.pop("connection")

		# pre-allocate a cursor for internal queries
		self.cursor = self.connection.cursor()

	def copy(self, *args, **kwargs):
		"""
		This method is not implemented.  See ligo.lw.table.Table
		for more information.
		"""
		raise NotImplementedError

	def _end_of_columns(self):
		table.Table._end_of_columns(self)
		# dbcolumnnames and types have the "not loaded" columns
		# removed
		if self.loadcolumns is not None:
			self.dbcolumnnames = [name for name in self.columnnames if name in self.loadcolumns]
			self.dbcolumntypes = [name for i, name in enumerate(self.columntypes) if self.columnnames[i] in self.loadcolumns]
		else:
			self.dbcolumnnames = self.columnnames
			self.dbcolumntypes = self.columntypes

		# create the table
		ToSQLType = {
			"sqlite": ligolwtypes.ToSQLiteType,
			"mysql": ligolwtypes.ToMySQLType
		}[connection_db_type(self.connection)]
		try:
			statement = "CREATE TABLE IF NOT EXISTS " + self.Name + " (" + ", ".join(map(lambda n, t: "%s %s" % (n, ToSQLType[t]), self.dbcolumnnames, self.dbcolumntypes))
		except KeyError as e:
			raise ValueError("column type '%s' not supported" % str(e))
		if self.constraints is not None:
			statement += ", " + self.constraints
		statement += ")"
		self.cursor.execute(statement)

		# row ID where remapping is to start
		self.remap_first_rowid = None

		# construct the SQL to be used to insert new rows
		params = {
			"sqlite": ",".join("?" * len(self.dbcolumnnames)),
			"mysql": ",".join(["%s"] * len(self.dbcolumnnames))
		}[connection_db_type(self.connection)]
		self.append_statement = "INSERT INTO %s (%s) VALUES (%s)" % (self.Name, ",".join(self.dbcolumnnames), params)
		self.append_attrgetter = operator.attrgetter(*self.dbcolumnnames)

	def sync_next_id(self):
		if self.next_id is not None:
			maxid = self.cursor.execute("SELECT MAX(%s) FROM %s" % (self.next_id.column_name, self.Name)).fetchone()[0]
			if maxid is not None:
				# type conversion not needed for
				# .set_next_id(), but needed so we can do
				# arithmetic on the thing
				maxid = type(self.next_id)(maxid) + 1
				if maxid > self.next_id:
					self.set_next_id(maxid)
		return self.next_id

	def maxrowid(self):
		self.cursor.execute("SELECT MAX(ROWID) FROM %s" % self.Name)
		return self.cursor.fetchone()[0]

	def __len__(self):
		self.cursor.execute("SELECT COUNT(*) FROM %s" % self.Name)
		return self.cursor.fetchone()[0]

	def __iter__(self):
		cursor = self.connection.cursor()
		cursor.execute("SELECT * FROM %s ORDER BY rowid ASC" % self.Name)
		for values in cursor:
			yield self.row_from_cols(values)

	def __reversed__(self):
		cursor = self.connection.cursor()
		cursor.execute("SELECT * FROM %s ORDER BY rowid DESC" % self.Name)
		for values in cursor:
			yield self.row_from_cols(values)

	# FIXME:  is adding this a good idea?
	#def __delslice__(self, i, j):
	#	# sqlite numbers rows starting from 1:  [0:10] becomes
	#	# "rowid between 1 and 10" which means 1 <= rowid <= 10,
	#	# which is the intended range
	#	self.cursor.execute("DELETE FROM %s WHERE ROWID BETWEEN %d AND %d" % (self.Name, i + 1, j))

	def _append(self, row):
		"""
		Standard .append() method.  This method is intended for
		internal use only.
		"""
		self.cursor.execute(self.append_statement, self.append_attrgetter(row))

	def _remapping_append(self, row):
		"""
		Replacement for the standard .append() method.  This
		version performs on the fly row ID reassignment, and so
		also performs the function of the updateKeyMapping()
		method.  SQLite does not permit the PRIMARY KEY of a row to
		be modified, so it needs to be done prior to insertion.
		This method is intended for internal use only.
		"""
		if self.next_id is not None:
			# assign (and record) a new ID before inserting the
			# row to avoid collisions with existing rows
			setattr(row, self.next_id.column_name, idmap_get_new(self.cursor, self.Name, getattr(row, self.next_id.column_name), self))
		self._append(row)
		if self.remap_first_rowid is None:
			self.remap_first_rowid = self.maxrowid()
			assert self.remap_first_rowid is not None

	append = _append

	def row_from_cols(self, values):
		"""
		Given an iterable of values in the order of columns in the
		database, construct and return a row object.  This is a
		convenience function for turning the results of database
		queries into Python objects.
		"""
		row = self.RowType()
		for c, v in zip(self.dbcolumnnames, values):
			setattr(row, c, v)
		return row
	# backwards compatibility
	_row_from_cols = row_from_cols

	def unlink(self):
		table.Table.unlink(self)
		self.connection = None
		self.cursor = None

	def applyKeyMapping(self):
		"""
		Used as the second half of the key reassignment algorithm.
		Loops over each row in the table, replacing references to
		old row keys with the new values from the _idmap_ table.
		"""
		if self.remap_first_rowid is None:
			# no rows have been added since we processed this
			# table last
			return
		assignments = []
		for colname in self.dbcolumnnames:
			column = self.getColumnByName(colname)
			try:
				table_name = column.table_name
			except ValueError:
				# if we get here the column's name does not
				# have a table name component, so by
				# convention it cannot contain IDs pointing
				# to other tables
				continue
			# make sure it's not our own ID column (by
			# convention this should not be possible, but it
			# doesn't hurt to check)
			if self.next_id is not None and colname == self.next_id.column_name:
				continue
			assignments.append("%s = (SELECT new FROM _idmap_ WHERE _idmap_.table_name == \"%s\" AND _idmap_.old == %s)" % (colname, table_name, colname))
		assignments = ", ".join(assignments)
		if assignments:
			# SQLite documentation says ROWID is monotonically
			# increasing starting at 1 for the first row unless
			# it ever wraps around, then it is randomly
			# assigned.  ROWID is a 64 bit integer, so the only
			# way it will wrap is if somebody sets it to a very
			# high number manually.  This library does not do
			# that, so I don't bother checking.
			self.cursor.execute("UPDATE %s SET %s WHERE ROWID >= %d" % (self.Name, assignments, self.remap_first_rowid))
		self.remap_first_rowid = None


#
# =============================================================================
#
#                                  LSC Tables
#
# =============================================================================
#


class CoincMapTable(DBTable):
	tableName = lsctables.CoincMapTable.tableName
	validcolumns = lsctables.CoincMapTable.validcolumns
	constraints = lsctables.CoincMapTable.constraints
	next_id = lsctables.CoincMapTable.next_id
	RowType = lsctables.CoincMapTable.RowType
	how_to_index = lsctables.CoincMapTable.how_to_index

	def applyKeyMapping(self):
		if self.remap_first_rowid is not None:
			self.cursor.execute("UPDATE coinc_event_map SET event_id = (SELECT new FROM _idmap_ WHERE _idmap_.table_name == coinc_event_map.table_name AND old == event_id), coinc_event_id = (SELECT new FROM _idmap_ WHERE _idmap_.table_name == 'coinc_event' AND old == coinc_event_id) WHERE ROWID >= ?", (self.remap_first_rowid,))
			self.remap_first_rowid = None


class TimeSlideTable(DBTable):
	tableName = lsctables.TimeSlideTable.tableName
	validcolumns = lsctables.TimeSlideTable.validcolumns
	constraints = lsctables.TimeSlideTable.constraints
	next_id = lsctables.TimeSlideTable.next_id
	RowType = lsctables.TimeSlideTable.RowType
	how_to_index = lsctables.TimeSlideTable.how_to_index

	def as_dict(self):
		"""
		Return a dictionary mapping time slide IDs to offset
		dictionaries.
		"""
		# import is done here to reduce risk of a cyclic
		# dependency.  at the time of writing there is not one, but
		# we can help prevent it in the future by putting this
		# here.
		from lalburst import offsetvector
		return dict((time_slide_id, offsetvector.offsetvector((instrument, offset) for time_slide_id, instrument, offset in values)) for time_slide_id, values in itertools.groupby(self.cursor.execute("SELECT time_slide_id, instrument, offset FROM time_slide ORDER BY time_slide_id"), lambda time_slide_id_instrument_offset: time_slide_id_instrument_offset[0]))

	def get_time_slide_id(self, offsetdict, create_new = None, superset_ok = False, nonunique_ok = False):
		"""
		Return the time_slide_id corresponding to the offset vector
		described by offsetdict, a dictionary of instrument/offset
		pairs.

		If the optional create_new argument is None (the default),
		then the table must contain a matching offset vector.  The
		return value is the ID of that vector.  If the table does
		not contain a matching offset vector then KeyError is
		raised.

		If the optional create_new argument is set to a Process
		object (or any other object with a process_id attribute),
		then if the table does not contain a matching offset vector
		a new one will be added to the table and marked as having
		been created by the given process.  The return value is the
		ID of the (possibly newly created) matching offset vector.

		If the optional superset_ok argument is False (the default)
		then an offset vector in the table is considered to "match"
		the requested offset vector only if they contain the exact
		same set of instruments.  If the superset_ok argument is
		True, then an offset vector in the table is considered to
		match the requested offset vector as long as it provides
		the same offsets for the same instruments as the requested
		vector, even if it provides offsets for other instruments
		as well.

		More than one offset vector in the table might match the
		requested vector.  If the optional nonunique_ok argument is
		False (the default), then KeyError will be raised if more
		than one offset vector in the table is found to match the
		requested vector.  If the optional nonunique_ok is True
		then the return value is the ID of one of the matching
		offset vectors selected at random.
910
		"""
911 912
		# look for matching offset vectors
		if superset_ok:
913
			ids = [time_slide_id for time_slide_id, slide in self.as_dict().items() if offsetdict == dict((instrument, offset) for instrument, offset in slide.items() if instrument in offsetdict)]
914
		else:
915
			ids = [time_slide_id for time_slide_id, slide in self.as_dict().items() if offsetdict == slide]
916 917 918 919 920 921
		if len(ids) > 1:
			# found more than one
			if nonunique_ok:
				# and that's OK
				return ids[0]
			# and that's not OK
922
			raise KeyError(offsetdict)
923 924 925 926
		if len(ids) == 1:
			# found one
			return ids[0]
		# offset vector not found in table
927
		if create_new is None:
928
			# and that's not OK
929
			raise KeyError(offsetdict)
930
		# that's OK, create new vector
931
		time_slide_id = self.get_next_id()
932
		for instrument, offset in offsetdict.items():
933 934
			row = self.RowType()
			row.process_id = create_new.process_id
935
			row.time_slide_id = time_slide_id
936 937 938 939 940
			row.instrument = instrument
			row.offset = offset
			self.append(row)

		# return new ID
941
		return time_slide_id
942

943 944 945 946 947 948 949 950 951 952

#
# =============================================================================
#
#                                Table Metadata
#
# =============================================================================
#


953
def build_indexes(connection, verbose = False):
954 955
	"""
	Using the how_to_index annotations in the table class definitions,
956
	construct a set of indexes for the database at the given
957 958
	connection.
	"""
959
	cursor = connection.cursor()
960
	for table_name in get_table_names(connection):
961 962
		# FIXME:  figure out how to do this extensibly
		if table_name in TableByName:
963
			how_to_index = TableByName[table_name].how_to_index
964 965 966
		elif table_name in lsctables.TableByName:
			how_to_index = lsctables.TableByName[table_name].how_to_index
		else:
967
			continue
968
		if how_to_index is not None:
969
			if verbose:
970
				sys.stderr.write("indexing %s table ...\n" % table_name)
971
			for index_name, cols in how_to_index.items():
972
				cursor.execute("CREATE INDEX IF NOT EXISTS %s ON %s (%s)" % (index_name, table_name, ",".join(cols)))
Kipp Cannon's avatar
Kipp Cannon committed
973
	connection.commit()
974 975 976 977 978 979 980 981 982 983 984


#
# =============================================================================
#
#                                Table Metadata
#                              Table Name Mapping
# =============================================================================
#


#
# Table name ---> table type mapping.
#


TableByName = {
	CoincMapTable.tableName: CoincMapTable,
	TimeSlideTable.tableName: TimeSlideTable
}


#
# =============================================================================
#
#                               Content Handler
#
# =============================================================================
#


#
# Override portions of a ligolw.LIGOLWContentHandler class
#


def use_in(ContentHandler):
	"""
	Modify ContentHandler, a sub-class of
	ligo.lw.ligolw.LIGOLWContentHandler, to cause it to use the DBTable
	class defined in this module when parsing XML documents.  Instances
	of the class must provide a connection attribute.  When a document
	is parsed, the value of this attribute will be passed to the
	DBTable class' .__init__() method as each table object is created,
	and thus sets the database connection for all table objects in the
	document.

	Example:

	>>> import sqlite3
	>>> from ligo.lw import ligolw
	>>> class MyContentHandler(ligolw.LIGOLWContentHandler):
	...	def __init__(self, *args):
	...		super(MyContentHandler, self).__init__(*args)
	...		self.connection = sqlite3.connect("database.sqlite")
	...
	>>> use_in(MyContentHandler)

	Multiple database files can be in use at once by creating a content
	handler class for each one.
	"""
	ContentHandler = lsctables.use_in(ContentHandler)

	def startStream(self, parent, attrs, __orig_startStream = ContentHandler.startStream):
		if parent.tagName == ligolw.Table.tagName:
			parent._end_of_columns()
			return DBTableStream(attrs).config(parent)
		return __orig_startStream(self, parent, attrs)

	def startTable(self, parent, attrs):
		name = table.Table.TableName(attrs[u"Name"])
		if name in TableByName:
			return TableByName[name](attrs, connection = self.connection)
		return DBTable(attrs, connection = self.connection)

	ContentHandler.startStream = startStream
	ContentHandler.startTable = startTable

	return ContentHandler
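

#
# Example of loading a LIGO Light Weight XML document into a database
# through a content handler prepared with use_in() (a sketch;  "doc.xml"
# and "database.sqlite" are hypothetical file names, and the loader is
# assumed to be this package's utils.load_filename()):
#
#	import sqlite3
#	from ligo.lw import ligolw
#	connection = sqlite3.connect("database.sqlite")
#	class MyContentHandler(ligolw.LIGOLWContentHandler):
#		def __init__(self, *args):
#			super(MyContentHandler, self).__init__(*args)
#			self.connection = connection
#	use_in(MyContentHandler)
#	xmldoc = ligolw_utils.load_filename("doc.xml", contenthandler = MyContentHandler)
#	build_indexes(connection, verbose = True)
#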