diff --git a/python/gstlal_svd_bank.py b/python/gstlal_svd_bank.py index 0d8384c0363a46934f215b27495444f9fd7f40a3..5d6f84024465cd9b5c7d6278a6ff656a05b484b0 100755 --- a/python/gstlal_svd_bank.py +++ b/python/gstlal_svd_bank.py @@ -29,15 +29,13 @@ __all__ = ("Bank", "build_bank", "read_bank", "write_bank") import numpy import sys -import cPickle try: all except NameError: # Python < 2.5 compatibility from glue.iterutils import all -from glue.ligolw import lsctables -from glue.ligolw import utils +from glue.ligolw import ligolw, lsctables, array, param, utils, types from glue.ligolw.utils import process as ligolw_process @@ -198,17 +196,94 @@ def build_bank(template_bank_filename, psd, flow, ortho_gate_fap, snr_threshold, def write_bank(filename, bank): - f = open(filename, "wb") - try: - cPickle.dump(bank, f, -1) - finally: - f.close() + """Write an SVD bank to a LIGO_LW xml file.""" + + # Create new document + xmldoc = ligolw.Document() + root = ligolw.LIGO_LW() + + # Add root-level scalar params + root.appendChild(param.new_param('filter_length', types.FromPyType[float], bank.filter_length)) + root.appendChild(param.new_param('gate_threshold', types.FromPyType[float], bank.gate_threshold)) + root.appendChild(param.new_param('logname', types.FromPyType[str], bank.logname)) + root.appendChild(param.new_param('snr_threshold', types.FromPyType[float], bank.snr_threshold)) + root.appendChild(param.new_param('template_bank_filename', types.FromPyType[str], bank.template_bank_filename)) + + # Add root-level arrays + root.appendChild(array.from_array('autocorrelation_bank_real', bank.autocorrelation_bank.real)) + root.appendChild(array.from_array('autocorrelation_bank_imag', bank.autocorrelation_bank.imag)) + root.appendChild(array.from_array('sigmasq', numpy.array(bank.sigmasq))) + + # Write bank fragments + for i, frag in enumerate(bank.bank_fragments): + # Start new container + el = ligolw.LIGO_LW() + + # Add scalar params + el.appendChild(param.new_param('start', types.FromPyType[float], frag.start)) + el.appendChild(param.new_param('end', types.FromPyType[float], frag.end)) + el.appendChild(param.new_param('rate', types.FromPyType[int], frag.rate)) + + # Add arrays + el.appendChild(array.from_array('chifacs', frag.chifacs)) + el.appendChild(array.from_array('mix_matrix', frag.mix_matrix)) + el.appendChild(array.from_array('orthogonal_template_bank', frag.orthogonal_template_bank)) + el.appendChild(array.from_array('singular_values', frag.singular_values)) + el.appendChild(array.from_array('sum_of_squares_weights', frag.sum_of_squares_weights)) + + # Add bank fragment container to root container + root.appendChild(el) + + # Add root container to document + xmldoc.appendChild(root) + + # Write to file + utils.write_filename(xmldoc, filename, gz=filename.endswith('.gz')) def read_bank(filename): - f = open(filename, "rb") - try: - bank = cPickle.load(f) - finally: - f.close() + """Read an SVD bank from a LIGO_LW xml file.""" + + # Load document + xmldoc = utils.load_filename(filename, gz=filename.endswith('.gz')) + root = xmldoc.childNodes[0] + + # Create new SVD bank object + bank = Bank.__new__(Bank) + + # Read root-level scalar parameters + bank.filter_length = param.get_pyvalue(root, 'filter_length') + bank.gate_threshold = param.get_pyvalue(root, 'gate_threshold') + bank.logname = param.get_pyvalue(root, 'logname') + bank.snr_threshold = param.get_pyvalue(root, 'snr_threshold') + bank.template_bank_filename = param.get_pyvalue(root, 'template_bank_filename') + + # Read root-level arrays + autocorrelation_bank_real = array.get_array(root, 'autocorrelation_bank_real').array + autocorrelation_bank_imag = array.get_array(root, 'autocorrelation_bank_imag').array + bank.autocorrelation_bank = autocorrelation_bank_real + (0+1j) * autocorrelation_bank_imag + bank.sigmasq = array.get_array(root, 'sigmasq').array + + bank_fragments = [] + + # Read bank fragments + for el in (node for node in root.childNodes if node.tagName == 'LIGO_LW'): + frag = BankFragment.__new__(BankFragment) + + # Read scalar params + frag.start = param.get_pyvalue(el, 'start') + frag.end = param.get_pyvalue(el, 'end') + frag.rate = param.get_pyvalue(el, 'rate') + + # Read arrays + frag.chifacs = array.get_array(el, 'chifacs').array + frag.mix_matrix = array.get_array(el, 'mix_matrix').array + frag.orthogonal_template_bank = array.get_array(el, 'orthogonal_template_bank').array + frag.singular_values = array.get_array(el, 'singular_values').array + frag.sum_of_squares_weights = array.get_array(el, 'sum_of_squares_weights').array + + bank_fragments.append(frag) + + bank.bank_fragments = bank_fragments + return bank