Skip to content
Snippets Groups Projects
Commit b00bc6f5 authored by Surabhi Sachdev's avatar Surabhi Sachdev
Browse files

To allow for a different highest sampling rate than 2*fhigh


 - gstlal_inspiral_pipe and gstlal_svd_bank take as option
--sample-rate
 - the waveforms that are generated with a lower fhigh
are shifted forward such that all waveforms in a given
sub bank have a common amount of early warning
 - this has the effect that the actual fhigh of waveforms
is less than requested in a given sub bank, since the
waveforms are shifted according to the waveform that takes
longest to go from fhigh to ISCO

Signed-off-by: default avatarSurabhi Sachdev <surabhi.sachdev@ligo.org>
parent a90e4cb7
No related branches found
No related tags found
No related merge requests found
......@@ -207,6 +207,7 @@ def svd_node_gen(svdJob, dag, parent_nodes, psd, bank_cache, options, seg, templ
parent_nodes = parent_nodes,
opts = {"svd-tolerance":options.tolerance,
"flow":options.flow,
"sample-rate":options.sample_rate,
"clipleft":clipleft,
"clipright":clipright,
"samples-min":options.samples_min[j],
......@@ -1025,6 +1026,7 @@ def parse_command_line():
parser.add_option("--bank-cache", metavar = "filenames", action = "append", help = "Set the bank cache files in format H1=H1.cache,H2=H2.cache, etc.. (can be given multiple times)")
parser.add_option("--tolerance", metavar = "float", type = "float", default = 0.9999, help = "set the SVD tolerance, default 0.9999")
parser.add_option("--flow", metavar = "num", type = "float", default = 40, help = "set the low frequency cutoff, default 40 (Hz)")
parser.add_option("--sample-rate", metavar = "Hz", type = "int", help = "Set the sample rate. If not set, the sample rate will be based on the template frequency. The sample rate must be at least twice the highest frequency in the templates. If provided it must be a power of two")
parser.add_option("--identity-transform", action = "store_true", help = "Use identity transform, i.e. no SVD")
# trigger generation options
......@@ -1086,6 +1088,9 @@ def parse_command_line():
if options.num_banks:
options.num_banks = [int(v) for v in options.num_banks.split(",")]
if options.sample_rate is not None and (not numpy.log2(options.sample_rate) == int(numpy.log2(options.sample_rate))):
raise ValueError("--sample-rate must be a power of two")
if not options.samples_min and not options.svd_bank_cache:
options.samples_min = [1024]*len(options.bank_cache)
......
......@@ -79,7 +79,7 @@ for bank in svd_bank.read_banks(form.getlist("url")[0], svd_bank.DefaultContentH
maxrate = max(maxrate, frag.rate)
mint = min(t[0], mint)
pyplot.xlabel('time (s) mc:%.2f s1z:%.2f s2z:%.2f' % (mc, s1z, s2z))
pyplot.xlim([-.25, 0])
#pyplot.xlim([-.25, 0])
pyplot.subplot(212)
pyplot.xlabel('samples')
......
......@@ -30,6 +30,7 @@
from optparse import OptionParser
import numpy
import lal.series
......@@ -78,6 +79,7 @@ from gstlal.stats import inspiral_lr
parser = OptionParser(description = __doc__)
parser.add_option("--flow", metavar = "Hz", type = "float", default = 40.0, help = "Set the template low-frequency cut-off (default = 40.0).")
parser.add_option("--sample-rate", metavar = "Hz", type = "int", help = "Set the sample rate. If not set, the sample rate will be based on the template frequency. The sample rate must be at least twice the highest frequency in the templates. If provided it must be a power of two")
parser.add_option("--identity-transform", action = "store_true", default = False, help = "Do not perform an SVD; instead, use the original templates as the analyzing templates.")
parser.add_option("--padding", metavar = "pad", type = "float", default = 1.5, help = "Fractional amount to pad time slices.")
parser.add_option("--svd-tolerance", metavar = "match", type = "float", default = 0.9995, help = "Set the SVD reconstruction tolerance (default = 0.9995).")
......@@ -113,6 +115,10 @@ if not (len(options.template_bank) == len(options.clipleft) == len(options.clipr
if not options.autocorrelation_length % 2:
raise ValueError("--autocorrelation-length must be odd")
if options.sample_rate is not None and (not numpy.log2(options.sample_rate) == int(numpy.log2(options.sample_rate))):
raise ValueError("--sample-rate must be a power of two")
#
#
......@@ -145,7 +151,8 @@ svd_bank.write_bank(
samples_max_64 = options.samples_max_64,
samples_max = options.samples_max,
bank_id = bank_id,
contenthandler = svd_bank.DefaultContentHandler
contenthandler = svd_bank.DefaultContentHandler,
sample_rate = options.sample_rate
) for (template_bank, bank_id) in zip(options.template_bank, options.bank_id)],
options.clipleft,
options.clipright
......
......@@ -110,6 +110,7 @@ def generate_template(template_bank_row, approximant, sample_rate, duration, f_l
"""
if approximant not in templates.gstlal_approximants:
raise ValueError("Unsupported approximant given %s" % approximant)
assert f_high <= sample_rate // 2
# FIXME use hcross somday?
# We don't here because it is not guaranteed to be orthogonal
......@@ -138,8 +139,21 @@ def generate_template(template_bank_row, approximant, sample_rate, duration, f_l
parameters['approximant'] = lalsim.GetApproximantFromString(str(approximant))
hplus, hcross = lalsim.SimInspiralFD(**parameters)
# NOTE assumes fhigh is the Nyquist frequency!!!
assert len(hplus.data.data) == int(round(sample_rate * duration))//2 +1
assert len(hplus.data.data) == int(round(f_high * duration)) +1
# pad the output vector if the sample rate was higher than the
# requested final frequency
if f_high < sample_rate / 2:
fseries = lal.CreateCOMPLEX16FrequencySeries(
name = hplus.name,
epoch = hplus.epoch,
f0 = hplus.f0,
deltaF = hplus.deltaF,
length = int(round(sample_rate * duration))//2 +1,
sampleUnits = hplus.sampleUnits
)
fseries.data.data = numpy.zeros(fseries.data.length)
fseries.data.data[:hplus.data.length] = hplus.data.data[:]
hplus = fseries
return hplus
def condition_imr_template(approximant, data, epoch_time, sample_rate_max, max_ringtime):
......@@ -162,6 +176,17 @@ def condition_imr_template(approximant, data, epoch_time, sample_rate_max, max_r
# done
return data, target_index
def condition_ear_warn_template(approximant, data, epoch_time, sample_rate_max, max_shift_time):
assert -len(data) / sample_rate_max <= epoch_time < 0.0, "Epoch returned follows a different convention"
# find the index for the peak sample using the epoch returned by
# the waveform generator
epoch_index = -int(epoch_time*sample_rate_max) - 1
# move the early warning waveforms forward according to the waveform
# that spends the longest in going from fhigh to ISCO in a given
# split bank. This effectively ends some waveforms at f < fhigh
target_index = int(sample_rate_max * max_shift_time)
data = numpy.roll(data, target_index-epoch_index)
return data, target_index
def compute_autocorrelation_mask( autocorrelation ):
'''
......@@ -251,7 +276,7 @@ def condition_psd(psd, newdeltaF, minfs = (35.0, 40.0), maxfs = (1800., 2048.),
return psd
def generate_templates(template_table, approximant, psd, f_low, time_slices, autocorrelation_length = None, verbose = False):
def generate_templates(template_table, approximant, psd, f_low, time_slices, autocorrelation_length = None, fhigh = None, verbose = False):
"""!
Generate a bank of templates, which are
1. broken up into time slice,
......@@ -374,18 +399,27 @@ def generate_templates(template_table, approximant, psd, f_low, time_slices, aut
# to get back the original waveform.
sigmasq = []
# Generate each template, downsampling as we go to save memory
max_ringtime = max([chirptime.ringtime(row.mass1*lal.MSUN_SI + row.mass2*lal.MSUN_SI, chirptime.overestimate_j_from_chi(max(row.spin1z, row.spin2z))) for row in template_table])
for i, row in enumerate(template_table):
if verbose:
print >>sys.stderr, "generating template %d/%d: m1 = %g, m2 = %g, s1x = %g, s1y = %g, s1z = %g, s2x = %g, s2y = %g, s2z = %g" % (i + 1, len(template_table), row.mass1, row.mass2, row.spin1x, row.spin1y, row.spin1z, row.spin2x, row.spin2y, row.spin2z)
if approximant in templates.gstlal_IMR_approximants:
max_ringtime = max([chirptime.ringtime(row.mass1*lal.MSUN_SI + row.mass2*lal.MSUN_SI, chirptime.overestimate_j_from_chi(max(row.spin1z, row.spin2z))) for row in template_table])
else:
if sample_rate_max>2.*fhigh:
# Calculate the maximum time we need to shift the early warning
# waveforms forward by, calculated by the 3.5 approximation from
# fhigh to ISCO.
max_shift_time = max([spawaveform.chirptime(row.mass1, row.mass2, 7, fhigh, 0., spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z)) for row in template_table])
#
# Generate each template, downsampling as we go to save memory
# generate "cosine" component of frequency-domain template.
# waveform is generated for a canonical distance of 1 Mpc.
#
fseries = generate_template(row, approximant, sample_rate_max, working_duration, f_low, sample_rate_max / 2., fwdplan = fwdplan, fworkspace = fworkspace)
for i, row in enumerate(template_table):
if verbose:
print >>sys.stderr, "generating template %d/%d: m1 = %g, m2 = %g, s1x = %g, s1y = %g, s1z = %g, s2x = %g, s2y = %g, s2z = %g" % (i + 1, len(template_table), row.mass1, row.mass2, row.spin1x, row.spin1y, row.spin1z, row.spin2x, row.spin2y, row.spin2z)
fseries = generate_template(row, approximant, sample_rate_max, working_duration, f_low, fhigh, fwdplan = fwdplan, fworkspace = fworkspace)
if FIR_WHITENER:
#
......@@ -435,8 +469,15 @@ def generate_templates(template_table, approximant, psd, f_low, time_slices, aut
data, target_index = condition_imr_template(approximant, data, epoch_time, sample_rate_max, max_ringtime)
# record the new end times for the waveforms (since we performed the shifts)
row.end = LIGOTimeGPS(float(target_index-(len(data) - 1.))/sample_rate_max)
else:
data *= tukeywindow(data, samps = 32)
if sample_rate_max > fhigh*2.:
data, target_index = condition_ear_warn_template(approximant, data, epoch_time, sample_rate_max, max_shift_time)
data *= tukeywindow(data, samps = 32)
# record the new end times for the waveforms (since we performed the shifts)
row.end = LIGOTimeGPS(float(target_index-(len(data) - 1.))/sample_rate_max)
else:
data *= tukeywindow(data, samps = 32)
data = data[-length_max:]
#
......
......@@ -146,7 +146,7 @@ class BankFragment(object):
class Bank(object):
def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None):
def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, verbose = False, bank_id = None, fhigh = None):
# FIXME: remove template_bank_filename when no longer needed
# by trigger generator element
self.template_bank_filename = None
......@@ -165,6 +165,7 @@ class Bank(object):
flow,
time_slices,
autocorrelation_length = autocorrelation_length,
fhigh = fhigh,
verbose = verbose)
# Include signal inspiral table
......@@ -198,7 +199,7 @@ class Bank(object):
def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None):
def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_tolerance, padding = 1.5, identity_transform = False, verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None, sample_rate = None):
"""!
Return an instance of a Bank class.
......@@ -236,8 +237,10 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
samples_max_256 = samples_max_256,
samples_max_64 = samples_max_64,
samples_max = samples_max,
sample_rate = sample_rate,
verbose=verbose)
fhigh=check_ffinal_and_find_max_ffinal(bank_xmldoc)
# Generate templates, perform SVD, get orthogonal basis
# and store as Bank object
bank = Bank(
......@@ -251,7 +254,8 @@ def build_bank(template_bank_url, psd, flow, ortho_gate_fap, snr_threshold, svd_
autocorrelation_length = autocorrelation_length, # samples
identity_transform = identity_transform,
verbose = verbose,
bank_id = bank_id
bank_id = bank_id,
fhigh = fhigh
)
# FIXME: remove this when no longer needed
......
......@@ -244,6 +244,7 @@ def time_slices(
samples_max_256 = 1024,
samples_max_64 = 2048,
samples_max = 4096,
sample_rate = None,
verbose = False
):
"""
......@@ -277,7 +278,11 @@ def time_slices(
# Remove too-small and too-big sample rates base on input params.
sample_rate_min = ceil_pow_2( 2 * padding * flow )
sample_rate_max = ceil_pow_2( 2 * fhigh )
if sample_rate is None:
sample_rate_max = ceil_pow_2( 2 * fhigh )
else:
# note that sample rate is ensured to be a power of 2 in gstlal_svd_bank
sample_rate_max = sample_rate
allowed_rates = [rate for rate in allowed_rates if sample_rate_min <= rate <= sample_rate_max]
#
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment