Skip to content
Snippets Groups Projects
Commit 86440ca5 authored by Surabhi Sachdev's avatar Surabhi Sachdev
Browse files

gstlal_bank_splitter: Fix missing templates

add padding at the edge of chi bins to
take care of the clipping
parent 9de1996e
No related branches found
No related tags found
No related merge requests found
......@@ -55,10 +55,8 @@ from gstlal import chirptime
# + `--n` [count] (int): Set the number of templates per output file (required). It will be rounded to make all sub banks approximately the same size.
# + `--overlap` [count] (int): Overlap the templates in each file by this amount, must be even.
# + `--sort-by` [column]: Select the template sort order column (required).
# + `--add-f-final`: Select whether to add f_final to the bank.
# + `--max-f-final` [max final freq] (float): Max f_final to populate table with; if f_final > max, use max.
# + `--instrument` [ifo]: Override the instrument, required
# + `--bank-program` [name]: Select name of the program used to generate the template bank (default: tmpltbank).
# + `--verbose`: Be verbose.
# + `--approximant` [string]: Must specify an approximant
# + `--f-low` [frequency] (floate): Lower frequency cutoff
......@@ -103,7 +101,6 @@ def group_templates(templates, n, overlap = 0):
if end >= len(templates):
break
def parse_command_line():
parser = OptionParser()
parser.add_option("--output-path", metavar = "path", default = ".", help = "Set the path to the directory where output files will be written. Default is \".\".")
......@@ -111,17 +108,15 @@ def parse_command_line():
parser.add_option("--n", metavar = "count", type = "int", help = "Set the number of templates per output file (required). It will be rounded to make all sub banks approximately the same size.")
parser.add_option("--overlap", default = 0, metavar = "count", type = "int", help = "overlap the templates in each file by this amount, must be even")
parser.add_option("--sort-by", metavar = "column", default="mchirp", help = "Select the template sort column, default mchirp")
parser.add_option("--add-f-final", action = "store_true", help = "Select whether to add f_final to the bank.")
parser.add_option("--max-f-final", metavar = "float", type="float", help = "Max f_final to populate table with; if f_final over mx, use max.")
parser.add_option("--instrument", metavar = "ifo", type="string", help = "override the instrument, required")
parser.add_option("--bank-program", metavar = "name", default = "tmpltbank", type="string", help = "Select name of the program used to generate the template bank (default: tmpltbank).")
parser.add_option("-v", "--verbose", action = "store_true", help = "Be verbose.")
parser.add_option("--approximant", type = "string", help = "Must specify an approximant")
parser.add_option("--f-low", type = "float", metavar = "frequency", help = "Lower frequency cutoff")
parser.add_option("--f-low", type = "float", metavar = "frequency", help = "Lower frequency cutoff. Required")
parser.add_option("--group-by-chi", type = "int", metavar = "N", default = 1, help = "group templates into N groups of chi - helps with SVD. Default 1")
options, filenames = parser.parse_args()
required_options = ("n", "instrument", "sort_by", "output_cache", "approximant")
required_options = ("n", "instrument", "sort_by", "output_cache", "approximant", "f_low")
missing_options = [option for option in required_options if getattr(options, option) is None]
if missing_options:
raise ValueError, "missing required option(s) %s" % ", ".join("--%s" % option.replace("_", "-") for option in missing_options)
......@@ -129,82 +124,80 @@ def parse_command_line():
if options.overlap % 2:
raise ValueError("overlap must be even")
return options, filenames
if len(filenames) !=1:
raise ValueError("Must give exactly one file name")
return options, filenames[0]
options, filenames = parse_command_line()
options, filename = parse_command_line()
output_cache_file = open(options.output_cache, "w")
bank_count = 0
outputrows = []
for filename in filenames:
xmldoc = ligolw_utils.load_filename(filename, verbose = options.verbose, contenthandler = LIGOLWContentHandler)
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)
if options.add_f_final:
if options.f_low is None:
flow, = ligolw_process.get_process_params(xmldoc, options.bank_program, "--flow") + ligolw_process.get_process_params(xmldoc, options.bank_program, "--low-frequency-cutoff") + ligolw_process.get_process_params(xmldoc, options.bank_program, "--f-low")
else:
flow = options.f_low
for row in sngl_inspiral_table:
# Find the total spin magnitudes
spin1, spin2 = (row.spin1x**2 + row.spin1y**2 + row.spin1z**2)**.5, (row.spin2x**2 + row.spin2y**2 + row.spin2z**2)**.5
# Chirptime uses SI
m1_SI, m2_SI = MSUN_SI * row.mass1, MSUN_SI * row.mass2
xmldoc = ligolw_utils.load_filename(filename, verbose = options.verbose, contenthandler = LIGOLWContentHandler)
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)
if options.approximant in templates.gstlal_IMR_approximants:
# make sure to go a factor of 2 above the ringdown frequency for safety
row.f_final = 2 * chirptime.ringf(m1_SI + m2_SI, chirptime.overestimate_j_from_chi(max(spin1, spin2)))
else:
# otherwise choose a suitable high frequency
# NOTE not SI
row.f_final = spawaveform.ffinal(row.mass1, row.mass2, 'bkl_isco')
# Override the high frequency with the max if appropriate
if options.max_f_final and (row.f_final > options.max_f_final):
row.f_final = options.max_f_final
# Record the conservative template duration
row.template_duration = chirptime.imr_time(flow, m1_SI, m2_SI, spin1, spin2, f_max = row.f_final)
for row in sngl_inspiral_table:
# Find the total spin magnitudes
spin1, spin2 = (row.spin1x**2 + row.spin1y**2 + row.spin1z**2)**.5, (row.spin2x**2 + row.spin2y**2 + row.spin2z**2)**.5
# Chirptime uses SI
m1_SI, m2_SI = MSUN_SI * row.mass1, MSUN_SI * row.mass2
if options.approximant in templates.gstlal_IMR_approximants:
# make sure to go a factor of 2 above the ringdown frequency for safety
row.f_final = 2 * chirptime.ringf(m1_SI + m2_SI, chirptime.overestimate_j_from_chi(max(spin1, spin2)))
else:
# otherwise choose a suitable high frequency
# NOTE not SI
row.f_final = spawaveform.ffinal(row.mass1, row.mass2, 'bkl_isco')
for row in sngl_inspiral_table:
row.ifo = options.instrument
# Override the high frequency with the max if appropriate
if options.max_f_final and (row.f_final > options.max_f_final):
row.f_final = options.max_f_final
# just to make sure it is set
for row in sngl_inspiral_table:
row.mtotal = row.mass1 + row.mass2
# Bin by Chi, has no effect if option is not specified, i.e. there is only one bin.
chidict = {}
templates_by_chi = [tmp[1] for tmp in sorted([(spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z), row) for row in sngl_inspiral_table])]
for i, rows in enumerate(group_templates(templates_by_chi, len(templates_by_chi) / options.group_by_chi, overlap = 0)):
chidict[i] = rows
# Record the conservative template duration
row.template_duration = chirptime.imr_time(options.f_low, m1_SI, m2_SI, spin1, spin2, f_max = row.f_final)
for chi in chidict:
chirows = chidict[chi]
for row in sngl_inspiral_table:
row.ifo = options.instrument
# store the process params
process = ligolw_process.register_to_xmldoc(xmldoc, program = "gstlal_bank_splitter", paramdict = options.__dict__, comment = "Split bank into smaller banks for SVD")
# just to make sure it is set
for row in sngl_inspiral_table:
row.mtotal = row.mass1 + row.mass2
def sort_func(row, column = options.sort_by):
return getattr(row, column)
# Bin by Chi
sngl_inspiral_table.sort(key = lambda row: spawaveform.computechi(row.mass1, row.mass2, row.spin1z, row.spin2z))
for chirows in group_templates(sngl_inspiral_table, len(sngl_inspiral_table) / options.group_by_chi, overlap = 0):
chirows.sort(key=sort_func)
# store the process params
process = ligolw_process.register_to_xmldoc(xmldoc, program = "gstlal_bank_splitter", paramdict = options.__dict__, comment = "Split bank into smaller banks for SVD")
def sort_func(row, column = options.sort_by):
return getattr(row, column)
for rows in group_templates(chirows, options.n, options.overlap):
assert len(rows) >= options.n/2, "There are too few templates in this chi interval. Requested %d: have %d" % (options.n, len(rows))
outputrows.append((rows[0], rows))
chirows.sort(key=sort_func)
for numrow, rows in enumerate(group_templates(chirows, options.n, options.overlap)):
assert len(rows) >= options.n/2, "There are too few templates in this chi interval. Requested %d: have %d" % (options.n, len(rows))
# Pad the first group with an extra overlap / 2 templates
if numrow == 0:
rows = rows[:options.overlap/2] + rows
outputrows.append((rows[0], rows))
# Pad the last group with an extra overlap / 2 templates
outputrows[-1] = (rows[0], rows + rows[-options.overlap/2:])
# One last sort now that the templates have been grouped
# A sort of the groups of templates so that the sub banks are ordered.
def sort_func((row, rows), column = options.sort_by):
return getattr(row, column)
outputrows.sort(key=sort_func)
for bank_count, (row, rows) in enumerate(outputrows):
if bank_count == 0:
rows = rows[options.overlap/2:]
if bank_count == len(outputrows) - 1:
rows = rows[:-options.overlap/2]
sngl_inspiral_table[:] = rows
output = inspiral_pipe.T050017_filename(options.instrument, "GSTLAL_SPLIT_BANK_%04d" % bank_count, 0, 0, ".xml.gz", path = options.output_path)
output_cache_file.write("%s\n" % gluelal.CacheEntry.from_T050017("file://localhost%s" % os.path.abspath(output)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment