From b00bc6f5cfc9ed1beee22ee8096e188fcd92f24e Mon Sep 17 00:00:00 2001
From: Surabhi Sachdev <surabhi.sachdev@ligo.org>
Date: Thu, 29 Nov 2018 22:34:34 -0800
Subject: [PATCH] To allow for a different highest sampling rate than 2*fhigh

 - gstlal_inspiral_pipe and gstlal_svd_bank take as option
--sample-rate
 - the waveforms that are generated with a lower fhigh
are shifted forward such that all waveforms in a given
sub bank have a common amount of early warning
 - this has the effect that the actual fhigh of waveforms
is less than requested in a given sub bank, since the
waveforms are shifted according to the waveform that takes
longest to go from fhigh to ISCO

Signed-off-by: Surabhi Sachdev <surabhi.sachdev@ligo.org>
---
 gstlal-inspiral/bin/gstlal_inspiral_pipe      |  5 ++
 .../bin/gstlal_inspiral_plot_svd_bank         |  2 +-
 gstlal-inspiral/bin/gstlal_svd_bank           |  9 ++-
 gstlal-inspiral/python/cbc_template_fir.py    | 61 ++++++++++++++++---
 gstlal-inspiral/python/svd_bank.py            | 10 ++-
 gstlal-inspiral/python/templates.py           |  7 ++-
 6 files changed, 78 insertions(+), 16 deletions(-)

diff --git a/gstlal-inspiral/bin/gstlal_inspiral_pipe b/gstlal-inspiral/bin/gstlal_inspiral_pipe
index 8f98bd4a3c..9b5d4d3598 100755
--- a/gstlal-inspiral/bin/gstlal_inspiral_pipe
+++ b/gstlal-inspiral/bin/gstlal_inspiral_pipe
@@ -207,6 +207,7 @@ def svd_node_gen(svdJob, dag, parent_nodes, psd, bank_cache, options, seg, templ
 					parent_nodes = parent_nodes,
 					opts = {"svd-tolerance":options.tolerance,
 						"flow":options.flow,
+						"sample-rate":options.sample_rate,
 						"clipleft":clipleft,
 						"clipright":clipright,
 						"samples-min":options.samples_min[j],
@@ -1025,6 +1026,7 @@ def parse_command_line():
 	parser.add_option("--bank-cache", metavar = "filenames", action = "append", help = "Set the bank cache files in format H1=H1.cache,H2=H2.cache, etc.. (can be given multiple times)")
 	parser.add_option("--tolerance", metavar = "float", type = "float", default = 0.9999, help = "set the SVD tolerance, default 0.9999")
 	parser.add_option("--flow", metavar = "num", type = "float", default = 40, help = "set the low frequency cutoff, default 40 (Hz)")
+	parser.add_option("--sample-rate", metavar = "Hz", type = "int", help = "Set the sample rate.  If not set, the sample rate will be based on the template frequency.  The sample rate must be at least twice the highest frequency in the templates. If provided it must be a power of two")
 	parser.add_option("--identity-transform", action = "store_true", help = "Use identity transform, i.e. no SVD")
 	
 	# trigger generation options
@@ -1086,6 +1088,9 @@ def parse_command_line():
 	if options.num_banks:
 		options.num_banks = [int(v) for v in options.num_banks.split(",")]
 
+	if options.sample_rate is not None and (not numpy.log2(options.sample_rate) == int(numpy.log2(options.sample_rate))):
+		raise ValueError("--sample-rate must be a power of two")
+
 	if not options.samples_min and not options.svd_bank_cache:
 		options.samples_min = [1024]*len(options.bank_cache)
 
diff --git a/gstlal-inspiral/bin/gstlal_inspiral_plot_svd_bank b/gstlal-inspiral/bin/gstlal_inspiral_plot_svd_bank
index c633702448..b78feac0b6 100755
--- a/gstlal-inspiral/bin/gstlal_inspiral_plot_svd_bank
+++ b/gstlal-inspiral/bin/gstlal_inspiral_plot_svd_bank
@@ -79,7 +79,7 @@ for bank in svd_bank.read_banks(form.getlist("url")[0], svd_bank.DefaultContentH
 			maxrate = max(maxrate, frag.rate)
 			mint = min(t[0], mint)
 		pyplot.xlabel('time (s) mc:%.2f s1z:%.2f s2z:%.2f' % (mc, s1z, s2z))
-		pyplot.xlim([-.25, 0])
+		#pyplot.xlim([-.25, 0])
 
 		pyplot.subplot(212)
 		pyplot.xlabel('samples')
diff --git a/gstlal-inspiral/bin/gstlal_svd_bank b/gstlal-inspiral/bin/gstlal_svd_bank
index b654dad966..061467616e 100755
--- a/gstlal-inspiral/bin/gstlal_svd_bank
+++ b/gstlal-inspiral/bin/gstlal_svd_bank
@@ -30,6 +30,7 @@
 
 
 from optparse import OptionParser
+import numpy
 
 
 import lal.series
@@ -78,6 +79,7 @@ from gstlal.stats import inspiral_lr
 
 parser = OptionParser(description = __doc__)
 parser.add_option("--flow", metavar = "Hz", type = "float", default = 40.0, help = "Set the template low-frequency cut-off (default = 40.0).")
+parser.add_option("--sample-rate", metavar = "Hz", type = "int", help = "Set the sample rate.  If not set, the sample rate will be based on the template frequency.  The sample rate must be at least twice the highest frequency in the templates. If provided it must be a power of two")
 parser.add_option("--identity-transform", action = "store_true", default = False, help = "Do not perform an SVD; instead, use the original templates as the analyzing templates.")
 parser.add_option("--padding", metavar = "pad", type = "float", default = 1.5, help = "Fractional amount to pad time slices.")
 parser.add_option("--svd-tolerance", metavar = "match", type = "float", default = 0.9995, help = "Set the SVD reconstruction tolerance (default = 0.9995).")
@@ -113,6 +115,10 @@ if not (len(options.template_bank) == len(options.clipleft) == len(options.clipr
 if not options.autocorrelation_length % 2:
 	raise ValueError("--autocorrelation-length must be odd")
 
+if options.sample_rate is not None and (not numpy.log2(options.sample_rate) == int(numpy.log2(options.sample_rate))):
+	raise ValueError("--sample-rate must be a power of two")
+
+
 
 #
 #
@@ -145,7 +151,8 @@ svd_bank.write_bank(
 		samples_max_64 = options.samples_max_64, 
 		samples_max = options.samples_max,
 		bank_id = bank_id,
-		contenthandler = svd_bank.DefaultContentHandler
+		contenthandler = svd_bank.DefaultContentHandler,
+		sample_rate = options.sample_rate
 	) for (template_bank, bank_id) in zip(options.template_bank, options.bank_id)],
 	options.clipleft,
 	options.clipright
diff --git a/gstlal-inspiral/python/cbc_template_fir.py b/gstlal-inspiral/python/cbc_template_fir.py
index 6dafb47c6b..8526c6b34e 100644
--- a/gstlal-inspiral/python/cbc_template_fir.py
+++ b/gstlal-inspiral/python/cbc_template_fir.py
@@ -110,6 +110,7 @@ def generate_template(template_bank_row, approximant, sample_rate, duration, f_l
 	"""
 	if approximant not in templates.gstlal_approximants:
 		raise ValueError("Unsupported approximant given %s" % approximant)
+	assert f_high <= sample_rate // 2
 
 	# FIXME use hcross somday?
 	# We don't here because it is not guaranteed to be orthogonal
@@ -138,8 +139,21 @@ def generate_template(template_bank_row, approximant, sample_rate, duration, f_l
 	parameters['approximant'] = lalsim.GetApproximantFromString(str(approximant))
 
 	hplus, hcross = lalsim.SimInspiralFD(**parameters)
-	# NOTE assumes fhigh is the Nyquist frequency!!!
-	assert len(hplus.data.data) == int(round(sample_rate * duration))//2 +1
+	assert len(hplus.data.data) == int(round(f_high * duration)) +1
+	# pad the output vector if the sample rate was higher than the
+	# requested final frequency
+	if f_high < sample_rate / 2:
+		fseries = lal.CreateCOMPLEX16FrequencySeries(
+			name = hplus.name,
+			epoch = hplus.epoch,
+			f0 = hplus.f0,
+			deltaF = hplus.deltaF,
+			length = int(round(sample_rate * duration))//2 +1,
+			sampleUnits = hplus.sampleUnits
+		)
+		fseries.data.data = numpy.zeros(fseries.data.length)
+		fseries.data.data[:hplus.data.length] = hplus.data.data[:]
+		hplus = fseries
 	return hplus
 
 def condition_imr_template(approximant, data, epoch_time, sample_rate_max, max_ringtime):
@@ -162,6 +176,17 @@ def condition_imr_template(approximant, data, epoch_time, sample_rate_max, max_r
 	# done
 	return data, target_index
 
+def condition_ear_warn_template(approximant, data, epoch_time, sample_rate_max, max_shift_time):
+	assert -len(data) / sample_rate_max <= epoch_time < 0.0, "Epoch returned follows a different convention"
+	# find the index for the peak sample using the epoch returned by
+	# the waveform generator
+	epoch_index = -int(epoch_time*sample_rate_max) - 1
+	# move the early warning waveforms forward according to the waveform
+	# that spends the longest in going from fhigh to ISCO in a given
+	# split bank. This effectively ends some waveforms at f < fhigh
+	target_index = int(sample_rate_max * max_shift_time)
+	data = numpy.roll(data, target_index-epoch_index)
+	return data, target_index
 
 def compute_autocorrelation_mask( autocorrelation ):
 	'''
@@ -251,7 +276,7 @@ def condition_psd(psd, newdeltaF, minfs = (35.0, 40.0), maxfs = (1800., 2048.),
 	return psd
 
 
-def generate_templates(template_table, approximant, psd, f_low, time_slices, autocorrelation_length = None, verbose = False):
+def generate_templates(template_table, approximant, psd, f_low, time_slices, autocorrelation_length = None, fhigh = None, verbose = False):
 	"""!
 	Generate a bank of templates, which are
 	1. broken up into time slice,
@@ -374,18 +399,27 @@ def generate_templates(template_table, approximant, psd, f_low, time_slices, aut
 	# to get back the original waveform.
 	sigmasq = []
 
-	# Generate each template, downsampling as we go to save memory
-	max_ringtime = max([chirptime.ringtime(row.mass1*lal.MSUN_SI + row.mass2*lal.MSUN_SI, chirptime.overestimate_j_from_chi(max(row.spin1z, row.spin2z))) for row in template_table])
-	for i, row in enumerate(template_table):
-		if verbose:
-			print >>sys.stderr, "generating template %d/%d:  m1 = %g, m2 = %g, s1x = %g, s1y = %g, s1z = %g, s2x = %g, s2y = %g, s2z = %g" % (i + 1, len(template_table), row.mass1, row.mass2, row.spin1x, row.spin1y, row.spin1z, row.spin2x, row.spin2y, row.spin2z)
+	if approximant in templates.gstlal_IMR_approximants:
+		max_ringtime = max([chirptime.ringtime(row.mass1*lal.MSUN_SI + row.mass2*lal.MSUN_SI, chirptime.overestimate_j_from_chi(max(row.spin1z, row.spin2z))) for row in template_table])
+
+	else:
+		if sample_rate_max>2.*fhigh:
+		# Calculate the maximum time we need to shift the early warning
+		# waveforms forward by, calculated by the 3.5 approximation from
+		# fhigh to ISCO.
+			max_shift_time = max([spawaveform.chirptime(row.mass1, row.mass2, 7, fhigh, 0., spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z)) for row in template_table])
 
 		#
+		# Generate each template, downsampling as we go to save memory
 		# generate "cosine" component of frequency-domain template.
 		# waveform is generated for a canonical distance of 1 Mpc.
 		#
 
-		fseries = generate_template(row, approximant, sample_rate_max, working_duration, f_low, sample_rate_max / 2., fwdplan = fwdplan, fworkspace = fworkspace)
+	for i, row in enumerate(template_table):
+		if verbose:
+			print >>sys.stderr, "generating template %d/%d:  m1 = %g, m2 = %g, s1x = %g, s1y = %g, s1z = %g, s2x = %g, s2y = %g, s2z = %g" % (i + 1, len(template_table), row.mass1, row.mass2, row.spin1x, row.spin1y, row.spin1z, row.spin2x, row.spin2y, row.spin2z)
+
+		fseries = generate_template(row, approximant, sample_rate_max, working_duration, f_low, fhigh, fwdplan = fwdplan, fworkspace = fworkspace)
 
 		if FIR_WHITENER:
 			#
@@ -435,8 +469,15 @@ def generate_templates(template_table, approximant, psd, f_low, time_slices, aut
 			data, target_index = condition_imr_template(approximant, data, epoch_time, sample_rate_max, max_ringtime)
 			# record the new end times for the waveforms (since we performed the shifts)
 			row.end = LIGOTimeGPS(float(target_index-(len(data) - 1.))/sample_rate_max)
+
 		else:
-			data *= tukeywindow(data, samps = 32)
+			if sample_rate_max > fhigh*2.:
+				data, target_index = condition_ear_warn_template(approximant, data, epoch_time, sample_rate_max, max_shift_time)
+				data *= tukeywindow(data, samps = 32)
+				# record the new end times for the waveforms (since we performed the shifts)
+				row.end = LIGOTimeGPS(float(target_index-(len(data) - 1.))/sample_rate_max)
+			else:
+				data *= tukeywindow(data, samps = 32)
 
 		data = data[-length_max:]
 		#
diff --git a/gstlal-inspiral/python/svd_bank.py b/gstlal-inspiral/python/svd_bank.py
index 7f23041920..589de25f85 100644
--- a/gstlal-inspiral/python/svd_bank.py
+++ b/gstlal-inspiral/python/svd_bank.py
@@ -146,7 +146,7 @@ class BankFragment(object):
 
 
 class Bank(object):
-	def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None):
+	def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None, fhigh = None):
 		# FIXME: remove template_bank_filename when no longer needed
 		# by trigger generator element
 		self.template_bank_filename = None
@@ -165,6 +165,7 @@ class Bank(object):
 			flow,
 			time_slices,
 			autocorrelation_length = autocorrelation_length,
+			fhigh = fhigh,
 			verbose = verbose)
 
 		# Include signal inspiral table
@@ -198,7 +199,7 @@ class Bank(object):
 
 
 
-def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None):
+def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None, sample_rate = None):
 	"""!
 	Return an instance of a Bank class.
 
@@ -236,8 +237,10 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
 		samples_max_256 = samples_max_256,
 		samples_max_64 = samples_max_64,
 		samples_max = samples_max,
+		sample_rate = sample_rate,
 		verbose=verbose)
 
+	fhigh=check_ffinal_and_find_max_ffinal(bank_xmldoc)
 	# Generate templates, perform SVD, get orthogonal basis
 	# and store as Bank object
 	bank = Bank(
@@ -251,7 +254,8 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
 		autocorrelation_length = autocorrelation_length,	# samples
 		identity_transform = identity_transform,
 		verbose = verbose,
-		bank_id = bank_id
+		bank_id = bank_id,
+		fhigh = fhigh
 	)
 
 	# FIXME: remove this when no longer needed
diff --git a/gstlal-inspiral/python/templates.py b/gstlal-inspiral/python/templates.py
index 9778f22bb2..8fc3622b5d 100644
--- a/gstlal-inspiral/python/templates.py
+++ b/gstlal-inspiral/python/templates.py
@@ -244,6 +244,7 @@ def time_slices(
 	samples_max_256 = 1024,
 	samples_max_64 = 2048,
 	samples_max = 4096,
+	sample_rate = None,
 	verbose = False
 ):
 	"""
@@ -277,7 +278,11 @@ def time_slices(
 
 	# Remove too-small and too-big sample rates base on input params.
 	sample_rate_min = ceil_pow_2( 2 * padding * flow )
-	sample_rate_max = ceil_pow_2( 2 * fhigh )
+	if sample_rate is None:
+		sample_rate_max = ceil_pow_2( 2 * fhigh )
+	else:
+		# note that sample rate is ensured to be a power of 2 in gstlal_svd_bank
+		sample_rate_max = sample_rate
 	allowed_rates = [rate for rate in allowed_rates if sample_rate_min <= rate <= sample_rate_max]
 
 	#
-- 
GitLab