From 214489cd4041e071f97373a1322a30b75969c024 Mon Sep 17 00:00:00 2001
From: ChiWai Chan <chiwai.chan@ligo.org>
Date: Wed, 16 Sep 2020 05:31:28 -0700
Subject: [PATCH] svd_bank.py: moved the clipping processes to the Bank class
 and modified gstlal_svd_bank accordingly.

---
 gstlal-inspiral/bin/gstlal_svd_bank | 10 ++---
 gstlal-inspiral/python/svd_bank.py  | 58 +++++++++++++++--------------
 2 files changed, 35 insertions(+), 33 deletions(-)

diff --git a/gstlal-inspiral/bin/gstlal_svd_bank b/gstlal-inspiral/bin/gstlal_svd_bank
index db1585a134..bea8bf7bd9 100755
--- a/gstlal-inspiral/bin/gstlal_svd_bank
+++ b/gstlal-inspiral/bin/gstlal_svd_bank
@@ -148,19 +148,19 @@ svd_bank.write_bank(
 		options.ortho_gate_fap,
 		inspiral_lr.LnLRDensity.snr_min,
 		options.svd_tolerance,
+		clipleft,
+		clipright,
 		padding = options.padding,
 		identity_transform = options.identity_transform,
 		verbose = options.verbose,
 		autocorrelation_length = options.autocorrelation_length,
 		samples_min = options.samples_min,
 		samples_max_256 = options.samples_max_256,
-		samples_max_64 = options.samples_max_64, 
+		samples_max_64 = options.samples_max_64,
 		samples_max = options.samples_max,
 		bank_id = bank_id,
 		contenthandler = svd_bank.DefaultContentHandler,
 		sample_rate = options.sample_rate
-	) for (template_bank, bank_id) in zip(options.template_bank, options.bank_id)],
-	psd,
-	options.clipleft,
-	options.clipright
+	) for (template_bank, bank_id, clipleft, clipright) in zip(options.template_bank, options.bank_id, options.clipleft, options.clipright)],
+	psd
 )
diff --git a/gstlal-inspiral/python/svd_bank.py b/gstlal-inspiral/python/svd_bank.py
index b0060c947f..4cfb7df545 100644
--- a/gstlal-inspiral/python/svd_bank.py
+++ b/gstlal-inspiral/python/svd_bank.py
@@ -174,7 +174,7 @@ class BankFragment(object):
 
 
 class Bank(object):
-	def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None, fhigh = None):
+	def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, clipleft = None, clipright = None, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None, fhigh = None):
 		# FIXME: remove template_bank_filename when no longer needed
 		# by trigger generator element
 		self.template_bank_filename = None
@@ -231,6 +231,27 @@ class Bank(object):
 		if verbose:
 			print("sum-of-squares threshold for false-alarm probability of %.16g:  %.16g" % (gate_fap, self.gate_threshold), file=sys.stderr)
 
+		# Sanity checks before cliping
+		clipright = len(self.sngl_inspiral_table) - clipright if clipright is not None else None
+		doubled_clipright = clipright * 2 if clipright is not None else None
+		doubled_clipleft = clipleft * 2 if clipleft is not None else None
+
+		# Apply clipping options
+		new_sngl_table = self.sngl_inspiral_table.copy()
+		for row in self.sngl_inspiral_table[clipleft:clipright]:
+			# FIXME need a proper id column
+			row.Gamma1 = int(self.bank_id.split("_")[0])
+			new_sngl_table.append(row)
+		self.sngl_inspiral_table = new_sngl_table
+		self.autocorrelation_bank = self.autocorrelation_bank[clipleft:clipright,:]
+		self.autocorrelation_mask = self.autocorrelation_mask[clipleft:clipright,:]
+		self.sigmasq = self.sigmasq[clipleft:clipright]
+		self.bank_correlation_matrix = self.bank_correlation_matrix[clipleft:clipright,clipleft:clipright]
+		for i, frag in enumerate(self.bank_fragments):
+			if frag.mix_matrix is not None:
+				frag.mix_matrix = frag.mix_matrix[:,doubled_clipleft:doubled_clipright]
+			frag.chifacs = frag.chifacs[doubled_clipleft:doubled_clipright]
+
 	def get_rates(self):
 		return set(bank_fragment.rate for bank_fragment in self.bank_fragments)
 
@@ -241,7 +262,7 @@ class Bank(object):
 
 
 
-def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None, sample_rate = None, instrument_override = None):
+def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, clipleft = None, clipright = None, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None, sample_rate = None, instrument_override = None):
 	"""!
 	Return an instance of a Bank class.
 
@@ -251,6 +272,8 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
 	@param ortho_gate_fap The FAP threshold for the sum of squares threshold, see http://arxiv.org/abs/1101.0584
 	@param snr_threshold The SNR threshold for the search
 	@param svd_tolerance The target SNR loss of the SVD, see http://arxiv.org/abs/1005.0012
+	@param clipleft The number of N poorly reconstructed templates from the left edge of each sub-bank to be removed
+	@param cliptright The number of N poorly reconstructed templates from the right edge of each sub-bank to be removed
 	@param padding The padding from Nyquist for any template time slice, e.g., if a time slice has a Nyquist of 256 Hz and the padding is set to 2, only allow the template frequency to extend to 128 Hz.
 	@param identity_transform Don't do the SVD, just do time slices and keep the raw waveforms
 	@param verbose Be verbose
@@ -300,6 +323,8 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
 		gate_fap = ortho_gate_fap,
 		snr_threshold = snr_threshold,
 		tolerance = svd_tolerance,
+		clipleft = clipleft,
+		clipright = clipright,
 		flow = flow,
 		autocorrelation_length = autocorrelation_length,	# samples
 		identity_transform = identity_transform,
@@ -314,31 +339,19 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
 	return bank
 
 
-def write_bank(filename, banks, psd_input, cliplefts = None, cliprights = None, verbose = False):
+def write_bank(filename, banks, psd_input, verbose = False):
 	"""Write SVD banks to a LIGO_LW xml file."""
 
 	# Create new document
 	xmldoc = ligolw.Document()
 	lw = xmldoc.appendChild(ligolw.LIGO_LW())
 
-	for bank, clipleft, clipright in zip(banks, cliplefts, cliprights):
+	for bank in banks:
 		# set up root for this sub bank
 		root = lw.appendChild(ligolw.LIGO_LW(Attributes({u"Name": u"gstlal_svd_bank_Bank"})))
 
-		# FIXME FIXME FIXME move this clipping stuff to the Bank class
-		# set the right clipping index
-		clipright = len(bank.sngl_inspiral_table) - clipright
-
-		# Apply clipping option to sngl inspiral table
-		# put the bank table into the output document
-		new_sngl_table = bank.sngl_inspiral_table.copy()
-		for row in bank.sngl_inspiral_table[clipleft:clipright]:
-			# FIXME need a proper id column
-			row.Gamma1 = int(bank.bank_id.split("_")[0])
-			new_sngl_table.append(row)
-
 		# put the possibly clipped table into the file
-		root.appendChild(new_sngl_table)
+		root.appendChild(bank.sngl_inspiral_table)
 
 		# Add root-level scalar params
 		root.appendChild(ligolw_param.Param.from_pyvalue('filter_length', bank.filter_length))
@@ -353,12 +366,6 @@ def write_bank(filename, banks, psd_input, cliplefts = None, cliprights = None,
 		root.appendChild(ligolw_param.Param.from_pyvalue('sample_rate_max', int(bank.sample_rate_max)))
 		root.appendChild(ligolw_param.Param.from_pyvalue('gstlal_fir_whiten', os.environ['GSTLAL_FIR_WHITEN']))
 
-		# apply clipping to autocorrelations and sigmasq
-		bank.autocorrelation_bank = bank.autocorrelation_bank[clipleft:clipright,:]
-		bank.autocorrelation_mask = bank.autocorrelation_mask[clipleft:clipright,:]
-		bank.sigmasq = bank.sigmasq[clipleft:clipright]
-		bank.bank_correlation_matrix = bank.bank_correlation_matrix[clipleft:clipright,clipleft:clipright]
-
 		# Add root-level arrays
 		# FIXME:  ligolw format now supports complex-valued data
 		root.appendChild(ligolw_array.Array.build('autocorrelation_bank_real', bank.autocorrelation_bank.real))
@@ -373,11 +380,6 @@ def write_bank(filename, banks, psd_input, cliplefts = None, cliprights = None,
 			# Start new bank fragment container
 			el = root.appendChild(ligolw.LIGO_LW())
 
-			# Apply clipping option
-			if frag.mix_matrix is not None:
-				frag.mix_matrix = frag.mix_matrix[:,clipleft*2:clipright*2]
-			frag.chifacs = frag.chifacs[clipleft*2:clipright*2]
-
 			# Add scalar params
 			el.appendChild(ligolw_param.Param.from_pyvalue('rate', int(frag.rate)))
 			el.appendChild(ligolw_param.Param.from_pyvalue('start', frag.start))
-- 
GitLab