From 505fa324a3ad849dce52a6abe8f4030c766fc994 Mon Sep 17 00:00:00 2001
From: Patrick Godwin <patrick.godwin@ligo.org>
Date: Wed, 15 May 2019 07:14:22 -0700
Subject: [PATCH] gstlal_inspiral_dag: factored out dag layers into finer
 pieces

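Split the monolithic rank_and_merge_layer() into single-purpose layers so
that each stage of the analysis DAG is generated by its own function:

- marginalize_layer(): build the per-bin MARG_DIST_STATS marginalization
  nodes and return them keyed by background bin index
- calc_rank_pdf_layer(): build the CALC_RANK_PDFS / CALC_RANK_PDFS_WZL
  nodes from the marginalized dist stats
- likelihood_layer(): assign likelihoods to the triggers, chunking the
  input caches 16 files at a time

The clustering and merging that used to live at the end of
rank_and_merge_layer() now happens in finalize_run_layer(), which is
renamed sql_cluster_and_merge_layer(). compute_fap_layer() is likewise
split into final_marginalize_layer(), which produces the marginalized
ranking-stat PDF files, and compute_far_layer(), which assigns FAPs and
FARs to the final databases.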
---
 gstlal-inspiral/bin/gstlal_inspiral_dag | 212 ++++++++++++++----------
 1 file changed, 127 insertions(+), 85 deletions(-)

diff --git a/gstlal-inspiral/bin/gstlal_inspiral_dag b/gstlal-inspiral/bin/gstlal_inspiral_dag
index 81287a5603..88a8267d4b 100755
--- a/gstlal-inspiral/bin/gstlal_inspiral_dag
+++ b/gstlal-inspiral/bin/gstlal_inspiral_dag
@@ -410,9 +410,9 @@ def cache_to_db(cache, jobs):
 
 def get_rank_file(instruments, boundary_seg, n, basename, job=None):
 	if job:
-		return dagparts.T050017_filename(instruments, '_'.join(['%04d'%n, basename]), boundary_seg, '.xml.gz', path = job.output_path)
+		return dagparts.T050017_filename(instruments, '_'.join([n, basename]), boundary_seg, '.xml.gz', path = job.output_path)
 	else:
-		return dagparts.T050017_filename(instruments, '_'.join(['%04d'%n, basename]), boundary_seg, '.cache')
+		return dagparts.T050017_filename(instruments, '_'.join([n, basename]), boundary_seg, '.cache')
 
 def set_up_jobs(options):
 	jobs = {}
@@ -987,11 +987,7 @@ def merge_cluster_layer(dag, jobs, parent_nodes, db, db_cache, sqlfile, input_fi
 		input_files = {"": db}
 	)
 
-def rank_and_merge_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, options, boundary_seg, instrument_set, model_node, model_file):
-	likelihood_nodes = {}
-	rankpdf_nodes = []
-	rankpdf_zerolag_nodes = []
-	outnodes = {}
+def marginalize_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, options, boundary_seg, instrument_set, model_node, model_file):
 	instruments = "".join(sorted(instrument_set))
 	margnodes = {}
 
@@ -999,10 +995,11 @@ def rank_and_merge_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, op
 	one_ifo_svd_nodes = svd_nodes.values()[0]
 	# Here n counts the bins
 	# first non-injections, which will get skipped if this is an injections-only run
-	for n, (outputs, diststats) in enumerate((lloid_output[None][key], lloid_diststats[key]) for key in sorted(lloid_output[None].keys())):
+	bgbin_indices = sorted(lloid_output[None].keys())
+	for n, (outputs, diststats, bgbin_index) in enumerate((lloid_output[None][key], lloid_diststats[key], key) for key in bgbin_indices):
 		inputs = [o[0] for o in outputs]
 		parents = flatten([o[1] for o in outputs])
-		rankfile = functools.partial(get_rank_file, instruments, boundary_seg, n)
+		rankfile = functools.partial(get_rank_file, instruments, boundary_seg, '%04d'%n)
 
 		# FIXME we keep this here in case we someday want to have a
 		# mass bin dependent prior, but it really doesn't matter for
@@ -1024,44 +1021,69 @@ def rank_and_merge_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, op
 			input_cache_file_name = rankfile('MARG_DIST_STATS')
 		)
 
+		margnodes[bgbin_index] = diststats_per_bin_node
+
+	return margnodes
+
+def calc_rank_pdf_layer(dag, jobs, marg_nodes, options, boundary_seg, instrument_set):
+	rankpdf_nodes = []
+	rankpdf_zerolag_nodes = []
+	instruments = "".join(sorted(instrument_set))
+
+	# Here n counts the bins
+	for n, bgbin_index in enumerate(sorted(marg_nodes.keys())):
+		rankfile = functools.partial(get_rank_file, instruments, boundary_seg, '%04d'%n)
+
 		calcranknode = dagparts.DAGNode(jobs['calcRankPDFs'], dag,
-			parent_nodes = [diststats_per_bin_node],
+			parent_nodes = [marg_nodes[bgbin_index]],
 			opts = {"ranking-stat-samples":options.ranking_stat_samples},
-			input_files = {"":diststats_per_bin_node.output_files["output"]},
+			input_files = {"": marg_nodes[bgbin_index].output_files["output"]},
 			output_files = {"output": rankfile('CALC_RANK_PDFS', job=jobs['calcRankPDFs'])},
 		)
 
 		calcrankzerolagnode = dagparts.DAGNode(jobs['calcRankPDFsWithZerolag'], dag,
-			parent_nodes = [diststats_per_bin_node],
-			opts = {"add-zerolag-to-background":"","ranking-stat-samples":options.ranking_stat_samples},
-			input_files = {"":diststats_per_bin_node.output_files["output"]},
+			parent_nodes = [marg_nodes[bgbin_index]],
+			opts = {"add-zerolag-to-background": "", "ranking-stat-samples": options.ranking_stat_samples},
+			input_files = {"": marg_nodes[bgbin_index].output_files["output"]},
 			output_files = {"output": rankfile('CALC_RANK_PDFS_WZL', job=jobs['calcRankPDFsWithZerolag'])},
 		)
 
-		margnodes['%04d' %(n,)] = diststats_per_bin_node
 		rankpdf_nodes.append(calcranknode)
 		rankpdf_zerolag_nodes.append(calcrankzerolagnode)
 
+	return rankpdf_nodes, rankpdf_zerolag_nodes
+
+def likelihood_layer(dag, jobs, marg_nodes, lloid_output, lloid_diststats, options, boundary_seg, instrument_set):
+	likelihood_nodes = {}
+	instruments = "".join(sorted(instrument_set))
+	chunk_size = 16
+
+	bgbin_indices = sorted(lloid_output[None].keys())
+	for n, (outputs, diststats, bgbin_index) in enumerate((lloid_output[None][key], lloid_diststats[key], key) for key in bgbin_indices):
+		rankfile = functools.partial(get_rank_file, instruments, boundary_seg, '%04d'%n)
+		inputs = [o[0] for o in outputs]
+
 		# Break up the likelihood jobs into chunks to process fewer files, e.g, 16
 		likelihood_nodes.setdefault(None,[]).append(
 			[dagparts.DAGNode(jobs['calcLikelihood'], dag,
-				parent_nodes = [diststats_per_bin_node],
-				opts = {"tmp-space":dagparts.condor_scratch_space()},
-				input_files = {"likelihood-url":diststats_per_bin_node.output_files["output"]},
-				input_cache_files = {"input-cache":chunked_inputs}
-				) for chunked_inputs in chunks(inputs, 16)]
+				parent_nodes = [marg_nodes[bgbin_index]],
+				opts = {"tmp-space": dagparts.condor_scratch_space()},
+				input_files = {"likelihood-url": marg_nodes[bgbin_index].output_files["output"]},
+				input_cache_files = {"input-cache": chunked_inputs}
+				) for chunked_inputs in chunks(inputs, chunk_size)]
 			)
 
 	# then injections
 	for inj in options.injections:
 		lloid_nodes = lloid_output[sim_tag_from_inj_file(inj)]
-		for outputs, diststats, bgbin_index in ((lloid_nodes[key], lloid_diststats[key], key) for key in sorted(lloid_nodes.keys())):
+		bgbin_indices = sorted(lloid_nodes.keys())
+		for n, (outputs, diststats, bgbin_index) in enumerate((lloid_nodes[key], lloid_diststats[key], key) for key in bgbin_indices):
 			if outputs is not None:
 				inputs = [o[0] for o in outputs]
 				parents = flatten([o[1] for o in outputs])
-				if margnodes:
-					parents.append(margnodes[bgbin_index])
-					likelihood_url = margnodes[bgbin_index].output_files["output"]
+				if bgbin_index in marg_nodes:
+					parents.append(marg_nodes[bgbin_index])
+					likelihood_url = marg_nodes[bgbin_index].output_files["output"]
 				else:
 					likelihood_url = diststats[0]
 
@@ -1069,50 +1091,13 @@ def rank_and_merge_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, op
 				likelihood_nodes.setdefault(sim_tag_from_inj_file(inj),[]).append(
 					[dagparts.DAGNode(jobs['calcLikelihoodInj'], dag,
 						parent_nodes = parents,
-						opts = {"tmp-space":dagparts.condor_scratch_space()},
-						input_files = {"likelihood-url":likelihood_url},
-						input_cache_files = {"input-cache":chunked_inputs}
-						) for chunked_inputs in chunks(inputs, 16)]
+						opts = {"tmp-space": dagparts.condor_scratch_space()},
+						input_files = {"likelihood-url": likelihood_url},
+						input_cache_files = {"input-cache": chunked_inputs}
+						) for chunked_inputs in chunks(inputs, chunk_size)]
 					)
 
-
-	# after assigning the likelihoods cluster and merge by sub bank and whether or not it was an injection run
-	files_to_group = 40
-	for subbank, (inj, nodes) in enumerate(likelihood_nodes.items()):
-		if inj is None:
-			outnode_key = None
-			sql_file = options.cluster_sql_file
-		else:
-			outnode_key = sim_tag_from_inj_file(inj)
-			sql_file = options.injection_sql_file
-
-		# Flatten the nodes for this sub bank
-		nodes = flatten(nodes)
-		merge_nodes = []
-		# Flatten the input/output files from calc_likelihood
-		inputs = flatten([node.input_files["input-cache"] for node in nodes])
-
-		# files_to_group at a time irrespective of the sub bank they came from so the jobs take a bit longer to run
-		for input_files in chunks(inputs, files_to_group):
-			merge_nodes.append(dagparts.DAGNode(jobs['lalappsRunSqlite'], dag, parent_nodes = nodes,
-				opts = {"sql-file": sql_file, "tmp-space": dagparts.condor_scratch_space()},
-				input_files = {"": input_files}
-				)
-			)
-			if options.copy_raw_results:
-				merge_nodes[-1].set_pre_script("store_raw.sh")
-				merge_nodes[-1].add_pre_script_arg(" ".join(input_files))
-
-		# Merging all the dbs from the same sub bank
-		for subbank, inputs in enumerate([node.input_files["input-cache"] for node in nodes]):
-			db = inputs_to_db(jobs, inputs)
-			sqlitenode = merge_cluster_layer(dag, jobs, merge_nodes, db, inputs, sql_file)
-			outnodes.setdefault(outnode_key, []).append(sqlitenode)
-
-	# make sure outnodes has a None key, even if its value is an empty list
-	outnodes.setdefault(None, [])
-
-	return rankpdf_nodes, rankpdf_zerolag_nodes, outnodes
+	return likelihood_nodes
 
 def merge_in_bin_layer(dag, jobs, options):
 	rankpdf_nodes = sorted([CacheEntry(line).path for line in open(options.rank_pdf_cache)], key = lambda s: int(os.path.basename(s).split('-')[1].split('_')[0]))
@@ -1169,7 +1154,45 @@ def merge_in_bin_layer(dag, jobs, options):
 
 	return rankpdf_nodes, rankpdf_zerolag_nodes, outnodes
 
-def finalize_run_layer(dag, jobs, innodes, ligolw_add_nodes, options, instruments):
+def sql_cluster_and_merge_layer(dag, jobs, likelihood_nodes, ligolw_add_nodes, options, instruments):
+	innodes = {}
+
+	# after assigning the likelihoods, cluster and merge by sub bank and by whether or not it was an injection run
+	files_to_group = 40
+	for subbank, (inj, nodes) in enumerate(likelihood_nodes.items()):
+		if inj is None:
+			innode_key = None
+			sql_file = options.cluster_sql_file
+		else:
+			innode_key = sim_tag_from_inj_file(inj)
+			sql_file = options.injection_sql_file
+
+		# Flatten the nodes for this sub bank
+		nodes = flatten(nodes)
+		merge_nodes = []
+		# Flatten the input/output files from calc_likelihood
+		inputs = flatten([node.input_files["input-cache"] for node in nodes])
+
+		# process files_to_group files at a time, irrespective of the sub bank they came from, so the jobs take a bit longer to run
+		for input_files in chunks(inputs, files_to_group):
+			merge_nodes.append(dagparts.DAGNode(jobs['lalappsRunSqlite'], dag, parent_nodes = nodes,
+				opts = {"sql-file": sql_file, "tmp-space": dagparts.condor_scratch_space()},
+				input_files = {"": input_files}
+				)
+			)
+			if options.copy_raw_results:
+				merge_nodes[-1].set_pre_script("store_raw.sh")
+				merge_nodes[-1].add_pre_script_arg(" ".join(input_files))
+
+		# Merging all the dbs from the same sub bank
+		for subbank, inputs in enumerate([node.input_files["input-cache"] for node in nodes]):
+			db = inputs_to_db(jobs, inputs)
+			sqlitenode = merge_cluster_layer(dag, jobs, merge_nodes, db, inputs, sql_file)
+			innodes.setdefault(innode_key, []).append(sqlitenode)
+
+	# make sure innodes has a None key, even if its value is an empty list
+	innodes.setdefault(None, [])
+
 	num_chunks = 50
 
 	if options.vetoes is None:
@@ -1307,18 +1330,16 @@ def finalize_run_layer(dag, jobs, innodes, ligolw_add_nodes, options, instrument
 
 	return injdbs, noninjdb, outnodes, dbs_to_delete
 
-def compute_fap_layer(dag, jobs, rankpdf_nodes, rankpdf_zerolag_nodes, injdbs, noninjdb, final_sqlite_nodes):
-	"""compute FAPs and FARs
-	"""
+def final_marginalize_layer(dag, jobs, rankpdf_nodes, rankpdf_zerolag_nodes):
 	ranknodes = [rankpdf_nodes, rankpdf_zerolag_nodes]
 	margjobs = [jobs['marginalize'], jobs['marginalizeWithZerolag']]
 	margfiles = [options.marginalized_likelihood_file, options.marginalized_likelihood_file]
 	filesuffixs = ['', '_with_zerolag']
 
 	margnum = 16
-	outnode = None
 	all_margcache = []
 	all_margnodes = []
+	final_margnodes = []
 	for nodes, job, margfile, filesuffix in zip(ranknodes, margjobs, margfiles, filesuffixs):
 		try:
 			margin = [node.output_files["output"] for node in nodes]
@@ -1346,18 +1367,29 @@ def compute_fap_layer(dag, jobs, rankpdf_nodes, rankpdf_zerolag_nodes, injdbs, n
 		all_margcache.append(margcache)
 		all_margnodes.append(margnodes)
 
-	for nodes, job, margnodes, margcache, margfile, filesuffix in zip(ranknodes, margjobs, all_margnodes, all_margcache, margfiles, filesuffixs):
+	if not options.marginalized_likelihood_file: ### not an injection-only run
+		for nodes, job, margnodes, margcache, margfile, filesuffix in zip(ranknodes, margjobs, all_margnodes, all_margcache, margfiles, filesuffixs):
+			final_margnodes.append(dagparts.DAGNode(job, dag, parent_nodes = margnodes,
+				opts = {"marginalize": "ranking-stat-pdf"},
+				output_files = {"output": "marginalized_likelihood%s.xml.gz"%filesuffix},
+				input_cache_files = {"likelihood-cache": margcache},
+				input_cache_file_name = "marginalized_likelihood%s.cache"%filesuffix
+			))
+
+	return final_margnodes, flatten(all_margcache)
+
+def compute_far_layer(dag, jobs, margnodes, injdbs, noninjdb, final_sqlite_nodes):
+	"""compute FAPs and FARs
+	"""
+	margfiles = [options.marginalized_likelihood_file, options.marginalized_likelihood_file]
+	filesuffixs = ['', '_with_zerolag']
+	outnode = None  # default in case margnodes is empty (injection-only run)
+	for margnode, margfile, filesuffix in zip(margnodes, margfiles, filesuffixs):
 		if options.marginalized_likelihood_file: ### injection-only run
 			parents = final_sqlite_nodes
 			marginalized_likelihood_file = margfile
 
 		else:
-			margnode = dagparts.DAGNode(job, dag, parent_nodes = margnodes,
-				opts = {"marginalize": "ranking-stat-pdf"},
-				output_files = {"output": "marginalized_likelihood%s.xml.gz"%filesuffix},
-				input_cache_files = {"likelihood-cache": margcache},
-				input_cache_file_name = "marginalized_likelihood%s.cache"%filesuffix
-			)
 			parents = [margnode] + final_sqlite_nodes
 			marginalized_likelihood_file = margnode.output_files["output"]
 
@@ -1373,7 +1405,7 @@ def compute_fap_layer(dag, jobs, rankpdf_nodes, rankpdf_zerolag_nodes, injdbs, n
 		if 'zerolag' not in filesuffix:
 			outnode = farnode
 
-	return outnode, flatten(all_margcache)
+	return outnode
 
 def horizon_dist_layer(dag, jobs, psd_nodes, options, boundary_seg, output_dir):
 	"""calculate horizon distance
@@ -1492,18 +1524,28 @@ if __name__ == '__main__':
 		# Inspiral jobs by segment
 		inspiral_nodes, lloid_output, lloid_diststats = inspiral_layer(dag, jobs, svd_nodes, segsdict, options, channel_dict, template_mchirp_dict)
 
-		# Setup likelihood jobs, clustering and/or merging
-		rankpdf_nodes, rankpdf_zerolag_nodes, outnodes = rank_and_merge_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, options, boundary_seg, instrument_set, model_node, model_file)
+		# marginalize jobs
+		marg_nodes = marginalize_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, options, boundary_seg, instrument_set, model_node, model_file)
+		# calc rank PDF and likelihood jobs
+		rankpdf_nodes, rankpdf_zerolag_nodes = calc_rank_pdf_layer(dag, jobs, marg_nodes, options, boundary_seg, instrument_set)
+		likelihood_nodes = likelihood_layer(dag, jobs, marg_nodes, lloid_output, lloid_diststats, options, boundary_seg, instrument_set)
 
 	else:
 		# Merge lloid files into 1 file per bin if not already 1 file per bin
-		rankpdf_nodes, rankpdf_zerolag_nodes, outnodes = merge_in_bin_layer(dag, jobs, options)
+		rankpdf_nodes, rankpdf_zerolag_nodes, likelihood_nodes = merge_in_bin_layer(dag, jobs, options)
+
+	# final marginalization step
+	final_marg_nodes, margfiles_to_delete = final_marginalize_layer(dag, jobs, rankpdf_nodes, rankpdf_zerolag_nodes)
+
+	# likelihood jobs are created above: by likelihood_layer() in the full
+	# analysis branch, or by merge_in_bin_layer() otherwise
 
+	# Setup clustering and/or merging
 	# after all of the likelihood ranking and preclustering is finished put everything into single databases based on the injection file (or lack thereof)
-	injdbs, noninjdb, final_sqlite_nodes, dbs_to_delete = finalize_run_layer(dag, jobs, outnodes, ligolw_add_nodes, options, instruments)
+	injdbs, noninjdb, final_sqlite_nodes, dbs_to_delete = sql_cluster_and_merge_layer(dag, jobs, likelihood_nodes, ligolw_add_nodes, options, instruments)
 
-	# Compute FAP
-	farnode, margfiles_to_delete = compute_fap_layer(dag, jobs, rankpdf_nodes, rankpdf_zerolag_nodes, injdbs, noninjdb, final_sqlite_nodes)
+	# Compute FAR
+	farnode = compute_far_layer(dag, jobs, final_marg_nodes, injdbs, noninjdb, final_sqlite_nodes)
 
 	# make summary plots
 	plotnodes = summary_plot_layer(dag, jobs, farnode, options)
-- 
GitLab