Skip to content
Snippets Groups Projects
Commit 366c2430 authored by Heather Fong's avatar Heather Fong Committed by Patrick Godwin
Browse files

make metric overlap calculate in split bank chunks rather than over entire bank

parent 03b1918d
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@ import sys
from gstlal import metric as metric_module
from ligo.lw import ligolw
from ligo.lw import utils as ligolw_utils
from lal.utils import CacheEntry
from ligo.lw import lsctables
import numpy
import argparse
......@@ -16,21 +17,27 @@ parser = argparse.ArgumentParser()
parser.add_argument("--psd-xml-file", help = "provide a psd xml file")
parser.add_argument("--bank-file", help = "provide the bank file for which overlaps will be calculated")
parser.add_argument("--out-h5-file", required = True, help = "provide the output hdf5 file name")
parser.add_argument("--start-row", type=int, help = "The starting row to calculate overlaps")
parser.add_argument("--num-rows", type=int, help = "The number of rows to calculate overlaps")
parser.add_argument("--approximant", default="IMRPhenomD", help = "Waveform model. Default IMRPhenomD.")
parser.add_argument("--f-low", type=float, default=15.0, help = "Lowest frequency component of template. Default 15 Hz.")
parser.add_argument("--f-high", type=float, default=4096.0, help = "Highest frequency component of template Default 4096 Hz.")
parser.add_argument("--split-bank-cache", help = "Cache file containing paths to split banks. Required.")
parser.add_argument("--split-bank-index", type=int, help = "Split bank index that will be looped over.")
parser.add_argument("--overlap-threshold", type=float, default=0.25, help = "Overlap threshold. Default 0.25.")
parser.add_argument("--number-of-templates", type=int, help = "Total number of templates in bank. Required.")
args = parser.parse_args()
g_ij = metric_module.Metric(
args.psd_xml_file,
coord_func = metric_module.x_y_z_zn_func,
duration = 1.0, # FIXME!!!!!
flow = 10,
fhigh = 1024,
approximant = "IMRPhenomD"
args.psd_xml_file,
coord_func = metric_module.x_y_z_zn_func,
duration = 1.0, # FIXME!!!!!
flow = args.f_low,
fhigh = args.f_high,
approximant = args.approximant
)
xmldoc = ligolw_utils.load_filename(args.bank_file, verbose = True, contenthandler = LIGOLWContentHandler)
split_banks = sorted([ce.path for ce in map(CacheEntry, open(args.split_bank_cache, 'r'))])
xmldoc = ligolw_utils.load_filename(split_banks[args.split_bank_index], verbose=True, contenthandler = LIGOLWContentHandler)
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)
def id_x_y_z_zn_from_row(row):
......@@ -41,20 +48,64 @@ def id_x_y_z_zn_from_row(row):
metric_module.zn_from_m1_m2_s1_s2(row.mass1, row.mass2, row.spin1z, row.spin2z)
]
vec1s = numpy.array([id_x_y_z_zn_from_row(row) for row in sngl_inspiral_table[args.start_row:args.start_row+args.num_rows]])
vec2s = numpy.array([id_x_y_z_zn_from_row(row) for row in sngl_inspiral_table])
vec1s = numpy.array([id_x_y_z_zn_from_row(row) for row in sngl_inspiral_table])
output = []
output = numpy.zeros((len(vec1s), int(args.number_of_templates/3)))
id2 = []
for n, vec1, in enumerate(vec1s):
g, det = g_ij(vec1[1:])
def match(vec2, vec1 = vec1, g = g):
return (vec1[0], vec2[0], g_ij.pseudo_match(g, vec1[1:], vec2[1:]))
thisoutput = [row for row in map(match, vec2s) if row[2] > 0.25]
print n, len(thisoutput)
output += thisoutput
g, det = g_ij(vec1[1:])
b_idx = args.split_bank_index
fwd_flag = 1
while b_idx < len(split_banks) and b_idx >= 0:
sbank_str = split_banks[b_idx]
xmldoc2 = ligolw_utils.load_filename(sbank_str, verbose=False, contenthandler = LIGOLWContentHandler)
sngl_inspiral_table2 = lsctables.SnglInspiralTable.get_table(xmldoc2)
vec2s = numpy.array([id_x_y_z_zn_from_row(row) for row in sngl_inspiral_table2])
for t2_id in vec2s[:,0]:
if t2_id not in id2:
id2.append(t2_id)
if fwd_flag:
b_idx += 1
else:
b_idx -= 1
def match(vec2, vec1 = vec1, g = g):
return (vec1[0], vec2[0], g_ij.pseudo_match(g, vec1[1:], vec2[1:]))
thisoutput = numpy.array([[i, row[1], row[2]] for i, row in enumerate(map(match, vec2s))])
xmldoc2.unlink()
maxovrlp = max(thisoutput[:,2])
output[n, thisoutput[:,0].astype(int)] = thisoutput[:,2]
if maxovrlp >= args.overlap_threshold:
second_chance = 0
print "\t Max overlap in %s: %f" %(sbank_str.split('/')[-1],maxovrlp)
if maxovrlp < args.overlap_threshold:
if second_chance < 2:
second_chance += 1
elif second_chance >= 2:
if fwd_flag:
second_chance = 0
fwd_flag = 0
b_idx = args.split_bank_index - 1
else:
print "done"
break
if b_idx > len(split_banks) - 1:
second_chance = 0
fwd_flag = 0
b_idx = args.split_bank_index - 1
mask = (output==0).all(0)
stop = min(numpy.where(mask)[0])
output = numpy.array(output)
h5f = h5py.File(args.out_h5_file, 'w')
h5f.create_dataset('overlaps', data = output)
olapdata = h5f.create_group("%s_metric" %(args.approximant))
dset = olapdata.create_dataset("id", data = vec1s[:,0])
dset = olapdata.create_dataset("id2", data = id2)
dset = olapdata.create_dataset('overlaps', data = output[:,:stop])
h5f.close()
......@@ -13,14 +13,23 @@ class LIGOLWContentHandler(ligolw.LIGOLWContentHandler):
pass
lsctables.use_in(LIGOLWContentHandler)
def file_len(fname):
with open(fname) as f:
for i, l in enumerate(f):
pass
return i+1
parser = argparse.ArgumentParser()
parser.add_argument("--psd-xml-file", help = "provide a psd xml file")
parser.add_argument("--bank-file", help = "provide the bank file for which overlaps will be calculated")
parser.add_argument("--split-bank-cache", help = "Cache file containing paths to split banks.")
args = parser.parse_args()
xmldoc = ligolw_utils.load_filename(args.bank_file, verbose = True, contenthandler = LIGOLWContentHandler)
sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(xmldoc)
number_of_templates = len(sngl_inspiral_table)
try:
os.mkdir("logs")
except:
......@@ -30,12 +39,10 @@ dag = dagparts.DAG("metric_overlap")
overlapJob = dagparts.DAGJob("gstlal_inspiral_metric_overlap", condor_commands = {"want_graceful_removal":"True", "kill_sig":"15", "accounting_group":"ligo.prod.o3.cbc.uber.gstlaloffline"})
addJob = dagparts.DAGJob("gstlal_inspiral_add_metric_overlaps", condor_commands = {"want_graceful_removal":"True", "kill_sig":"15", "accounting_group":"ligo.prod.o3.cbc.uber.gstlaloffline"})
num = 1000
overlapnodes = []
# FIXME dont hardcode 3345408, it comes from number of tiles in TimePhaseSNR
for start in range(0, len(sngl_inspiral_table), num):
stop = start + num
overlapnodes.append(dagparts.DAGNode(overlapJob, dag, parent_nodes = [], opts = {"start-row":str(start), "num-rows":str(num)}, output_files = {"out-h5-file":"%s/metric_overlaps_%d_%d.h5" % (overlapJob.output_path, start, num)}, input_files = {"psd-xml-file": args.psd_xml_file, "bank-file": args.bank_file}))
for i in range(0, file_len(args.split_bank_cache)):
overlapnodes.append(dagparts.DAGNode(overlapJob, dag, parent_nodes = [], opts = {"split-bank-cache":args.split_bank_cache, "split-bank-index": i, "number-of-templates": number_of_templates}, output_files = {"out-h5-file":"%s/metric_overlaps_%d.h5" % (overlapJob.output_path, i)}, input_files = {"psd-xml-file": args.psd_xml_file, "bank-file": args.bank_file}))
addnode = dagparts.DAGNode(addJob, dag, parent_nodes = overlapnodes, output_files = {"out-h5-file": "overlaps.h5"}, input_files = {"": [n.output_files["out-h5-file"] for n in overlapnodes]})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment