From 86440ca57c710fdeba36e508353f07ced705dbb2 Mon Sep 17 00:00:00 2001
From: Surabhi Sachdev <surabhi.sachdev@ligo.org>
Date: Fri, 9 Oct 2015 19:11:00 -0500
Subject: [PATCH] gstlal_bank_splitter: Fix missing templates

add padding at the edge of chi bins to
take care of the clipping
---
 gstlal-inspiral/bin/gstlal_bank_splitter | 111 +++++++++++------------
 1 file changed, 52 insertions(+), 59 deletions(-)

diff --git a/gstlal-inspiral/bin/gstlal_bank_splitter b/gstlal-inspiral/bin/gstlal_bank_splitter
index 42f5076396..14678cd15a 100755
--- a/gstlal-inspiral/bin/gstlal_bank_splitter
+++ b/gstlal-inspiral/bin/gstlal_bank_splitter
@@ -55,10 +55,8 @@ from gstlal import chirptime
 #	+ `--n` [count] (int): Set the number of templates per output file (required).  It will be rounded to make all sub banks approximately the same size.
 #	+ `--overlap` [count] (int): Overlap the templates in each file by this amount, must be even.
 #	+ `--sort-by` [column]: Select the template sort order column (required).
-#	+ `--add-f-final`: Select whether to add f_final to the bank.
 #	+ `--max-f-final` [max final freq] (float): Max f_final to populate table with; if f_final > max, use max.
 #	+ `--instrument` [ifo]: Override the instrument, required
-#	+ `--bank-program` [name]: Select name of the program used to generate the template bank (default: tmpltbank).
 #	+ `--verbose`: Be verbose.
 #	+ `--approximant` [string]: Must specify an approximant
 #	+ `--f-low` [frequency] (floate): Lower frequency cutoff
@@ -103,7 +101,6 @@ def group_templates(templates, n, overlap = 0):
 			if end >= len(templates):
 				break
 
-
 def parse_command_line():
 	parser = OptionParser()
 	parser.add_option("--output-path", metavar = "path", default = ".", help = "Set the path to the directory where output files will be written.  Default is \".\".")
@@ -111,17 +108,15 @@ def parse_command_line():
 	parser.add_option("--n", metavar = "count", type = "int", help = "Set the number of templates per output file (required). It will be rounded to make all sub banks approximately the same size.")
 	parser.add_option("--overlap", default = 0, metavar = "count", type = "int", help = "overlap the templates in each file by this amount, must be even")
 	parser.add_option("--sort-by", metavar = "column", default="mchirp", help = "Select the template sort column, default mchirp")
-	parser.add_option("--add-f-final", action = "store_true", help = "Select whether to add f_final to the bank.")
 	parser.add_option("--max-f-final", metavar = "float", type="float", help = "Max f_final to populate table with; if f_final over mx, use max.")
 	parser.add_option("--instrument", metavar = "ifo", type="string", help = "override the instrument, required")
-	parser.add_option("--bank-program", metavar = "name", default = "tmpltbank", type="string", help = "Select name of the program used to generate the template bank (default: tmpltbank).")
 	parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
 	parser.add_option("--approximant", type = "string", help = "Must specify an approximant")
-	parser.add_option("--f-low", type = "float", metavar = "frequency", help = "Lower frequency cutoff")
+	parser.add_option("--f-low", type = "float", metavar = "frequency", help = "Lower frequency cutoff. Required")
 	parser.add_option("--group-by-chi", type = "int", metavar = "N", default = 1, help = "group templates into N groups of chi - helps with SVD. Default 1")
 	options, filenames = parser.parse_args()
 
-	required_options = ("n", "instrument", "sort_by", "output_cache", "approximant")
+	required_options = ("n", "instrument", "sort_by", "output_cache", "approximant", "f_low")
 	missing_options = [option for option in required_options if getattr(options, option) is None]
 	if missing_options:
 		raise ValueError, "missing required option(s) %s" % ", ".join("--%s" % option.replace("_", "-") for option in missing_options)
@@ -129,82 +124,80 @@ def parse_command_line():
 	if options.overlap % 2:
 		raise ValueError("overlap must be even")
 
-	return options, filenames
+	if len(filenames) !=1:
+		raise ValueError("Must give exactly one file name")
+
+	return options, filenames[0]
 
-options, filenames = parse_command_line()
+options, filename = parse_command_line()
 output_cache_file = open(options.output_cache, "w")
 bank_count = 0
 
 outputrows = []
 
-for filename in filenames:
-	xmldoc = ligolw_utils.load_filename(filename, verbose = options.verbose, contenthandler = LIGOLWContentHandler)
-	sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)
-
-	if options.add_f_final:
-		if options.f_low is None:
-			flow, = ligolw_process.get_process_params(xmldoc, options.bank_program, "--flow") + ligolw_process.get_process_params(xmldoc, options.bank_program, "--low-frequency-cutoff") + ligolw_process.get_process_params(xmldoc, options.bank_program, "--f-low")
-		else:
-			flow = options.f_low
-		for row in sngl_inspiral_table:
-			# Find the total spin magnitudes
-			spin1, spin2 = (row.spin1x**2 + row.spin1y**2 + row.spin1z**2)**.5, (row.spin2x**2 + row.spin2y**2 + row.spin2z**2)**.5
-			# Chirptime uses SI
-			m1_SI, m2_SI = MSUN_SI * row.mass1, MSUN_SI * row.mass2
+xmldoc = ligolw_utils.load_filename(filename, verbose = options.verbose, contenthandler = LIGOLWContentHandler)
+sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)
 
-			if options.approximant in templates.gstlal_IMR_approximants:
-				# make sure to go a factor of 2 above the ringdown frequency for safety
-				row.f_final = 2 * chirptime.ringf(m1_SI + m2_SI, chirptime.overestimate_j_from_chi(max(spin1, spin2)))
-			else:
-				# otherwise choose a suitable high frequency
-				# NOTE not SI
-				row.f_final = spawaveform.ffinal(row.mass1, row.mass2, 'bkl_isco')
-
-			# Override the high frequency with the max if appropriate
-			if options.max_f_final and (row.f_final > options.max_f_final):
-				row.f_final = options.max_f_final
-
-			# Record the conservative template duration
-			row.template_duration = chirptime.imr_time(flow, m1_SI, m2_SI, spin1, spin2, f_max = row.f_final)
+for row in sngl_inspiral_table:
+	# Find the total spin magnitudes
+	spin1, spin2 = (row.spin1x**2 + row.spin1y**2 + row.spin1z**2)**.5, (row.spin2x**2 + row.spin2y**2 + row.spin2z**2)**.5
+	# Chirptime uses SI
+	m1_SI, m2_SI = MSUN_SI * row.mass1, MSUN_SI * row.mass2
 
+	if options.approximant in templates.gstlal_IMR_approximants:
+		# make sure to go a factor of 2 above the ringdown frequency for safety
+		row.f_final = 2 * chirptime.ringf(m1_SI + m2_SI, chirptime.overestimate_j_from_chi(max(spin1, spin2)))
+	else:
+		# otherwise choose a suitable high frequency
+		# NOTE not SI
+		row.f_final = spawaveform.ffinal(row.mass1, row.mass2, 'bkl_isco')
 
-	for row in sngl_inspiral_table:
-		row.ifo = options.instrument
+	# Override the high frequency with the max if appropriate
+	if options.max_f_final and (row.f_final > options.max_f_final):
+		row.f_final = options.max_f_final
 
-	# just to make sure it is set
-	for row in sngl_inspiral_table:
-		row.mtotal = row.mass1 + row.mass2
-	
-	# Bin by Chi, has no effect if option is not specified, i.e. there is only one bin.
-	chidict = {}
-	templates_by_chi = [tmp[1] for tmp in sorted([(spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z), row) for row in sngl_inspiral_table])]
-	for i, rows in enumerate(group_templates(templates_by_chi, len(templates_by_chi) / options.group_by_chi, overlap = 0)):
-		chidict[i] = rows
+	# Record the conservative template duration
+	row.template_duration = chirptime.imr_time(options.f_low, m1_SI, m2_SI, spin1, spin2, f_max = row.f_final)
 
-	for chi in chidict:
-		chirows = chidict[chi]
+for row in sngl_inspiral_table:
+	row.ifo = options.instrument
 
-		# store the process params
-		process = ligolw_process.register_to_xmldoc(xmldoc, program = "gstlal_bank_splitter", paramdict = options.__dict__, comment = "Split bank into smaller banks for SVD")
+# just to make sure it is set
+for row in sngl_inspiral_table:
+	row.mtotal = row.mass1 + row.mass2
 
-		def sort_func(row, column = options.sort_by):
-			return getattr(row, column)
+# Bin by Chi
+sngl_inspiral_table.sort(key = lambda row: spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z))
+for chirows in group_templates(sngl_inspiral_table, len(sngl_inspiral_table) / options.group_by_chi, overlap = 0):
 
-		chirows.sort(key=sort_func)
+	# store the process params
+	process = ligolw_process.register_to_xmldoc(xmldoc, program = "gstlal_bank_splitter", paramdict = options.__dict__, comment = "Split bank into smaller banks for SVD")
 
+	def sort_func(row, column = options.sort_by):
+		return getattr(row, column)
 
-		for rows in group_templates(chirows, options.n, options.overlap):
-			assert len(rows) >= options.n/2, "There are too few templates in this chi interval.  Requested %d: have %d" % (options.n, len(rows))
-			outputrows.append((rows[0], rows))
+	chirows.sort(key=sort_func)
 
+	for numrow, rows in enumerate(group_templates(chirows, options.n, options.overlap)):
+		assert len(rows) >= options.n/2, "There are too few templates in this chi interval.  Requested %d: have %d" % (options.n, len(rows))
+		# Pad the first group with an extra overlap / 2 templates
+		if numrow == 0:
+			rows = rows[:options.overlap/2] + rows
+		outputrows.append((rows[0], rows))
+	# Pad the last group with an extra overlap / 2 templates
+	outputrows[-1] = (rows[0], rows + rows[-options.overlap/2:])
 
-# One last sort now that the templates have been grouped
+# A sort of the groups of templates so that the sub banks are ordered.
 def sort_func((row, rows), column = options.sort_by):
 	return getattr(row, column)
 
 outputrows.sort(key=sort_func)
 
 for bank_count, (row, rows) in enumerate(outputrows):
+	if bank_count == 0:
+		rows = rows[options.overlap/2:]
+	if bank_count == len(outputrows) - 1:
+		rows = rows[:-options.overlap/2]
 	sngl_inspiral_table[:] = rows
 	output = inspiral_pipe.T050017_filename(options.instrument, "GSTLAL_SPLIT_BANK_%04d" % bank_count, 0, 0, ".xml.gz", path = options.output_path)
 	output_cache_file.write("%s\n" % gluelal.CacheEntry.from_T050017("file://localhost%s" % os.path.abspath(output)))
-- 
GitLab