From b2bb1088d1d75e8f517f27eaba03ce4c04ae6a15 Mon Sep 17 00:00:00 2001
From: Patrick Godwin <patrick.godwin@ligo.org>
Date: Thu, 29 Aug 2019 18:37:28 -0700
Subject: [PATCH] inspiral_pipe.py: add load_analysis_output() utility for
 reranking dag, allow more flexibility in get_bank_params() and
 marginalize_layer()

---
 gstlal-inspiral/python/inspiral_pipe.py | 56 +++++++++++++++++++++++--
 1 file changed, 52 insertions(+), 4 deletions(-)

diff --git a/gstlal-inspiral/python/inspiral_pipe.py b/gstlal-inspiral/python/inspiral_pipe.py
index bbeafafaf0..289955b353 100644
--- a/gstlal-inspiral/python/inspiral_pipe.py
+++ b/gstlal-inspiral/python/inspiral_pipe.py
@@ -601,7 +601,8 @@ def marginalize_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, optio
 	# FIXME, the svd nodes list has to be the same as the sorted keys of
 	# lloid_output.  svd nodes should be made into a dictionary much
 	# earlier in the code to prevent a mishap
-	one_ifo_svd_nodes = dict(("%04d" % n, node) for n, node in enumerate( svd_nodes.values()[0]))
+	if svd_nodes:
+		one_ifo_svd_nodes = dict(("%04d" % n, [node]) for n, node in enumerate( svd_nodes.values()[0]))
 
 	# Here n counts the bins
 	# FIXME - this is broken for injection dags right now because of marg nodes
@@ -613,11 +614,18 @@ def marginalize_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, optio
 		parents = dagparts.flatten([o[1] for o in outputs])
 		rankfile = functools.partial(get_rank_file, instruments, boundary_seg, bin_key)
 
+		if svd_nodes:
+			parent_nodes = [one_ifo_svd_nodes[bin_key]] + model_node
+			svd_file = one_ifo_svd_nodes[bin_key].output_files["write-svd"]
+		else:
+			parent_nodes = model_node
+			svd_file = dagparts.T050017_filename(ifo, '%s_SVD' % bin_key, boundary_arg, '.xml.gz', path = jobs['svd'].output_path)
+
 		# FIXME we keep this here in case we someday want to have a
 		# mass bin dependent prior, but it really doesn't matter for
 		# the time being.
 		priornode = dagparts.DAGNode(jobs['createPriorDistStats'], dag,
-			parent_nodes = [one_ifo_svd_nodes[bin_key]] + model_node or [],
+			parent_nodes = parent_nodes,
 			opts = {
 				"instrument": instrument_set,
 				"background-prior": 1,
@@ -626,7 +634,7 @@ def marginalize_layer(dag, jobs, svd_nodes, lloid_output, lloid_diststats, optio
 				"df": "bandwidth"
 			},
 			input_files = {
-				"svd-file": one_ifo_svd_nodes[bin_key].output_files["write-svd"],
+				"svd-file": svd_file,
 				"mass-model-file": model_file,
 				"dtdphi-file": svd_dtdphi_map[bin_key],
 				"psd-xml": ref_psd
@@ -1023,6 +1031,40 @@ def webserver_url():
 #
 
 
+def load_analysis_output(options):
+	# load triggers
+	bgbin_lloid_map = {}
+	for ce in map(CacheEntry, open(options.lloid_cache)):
+		try:
+			bgbin_idx, _, inj = ce.description.split('_', 2)
+		except:
+			bgbin_idx, _ = ce.description.split('_', 1)
+			inj = None
+		finally:
+			bgbin_lloid_map.setdefault(sim_tag_from_inj_file(inj), []).append(ce.path)
+
+	# load dist stats
+	lloid_diststats = {}
+	boundary_seg = None
+	for ce in map(CacheEntry, open(options.dist_stats_cache)):
+		if 'DIST_STATS' in ce.description:
+			lloid_diststats[ce.description.split("_")[0]] = [ce.path]
+			if not boundary_seg:
+				boundary_seg = ce.segment
+
+	# load svd dtdphi map
+	svd_dtdphi_map = {}
+	bank_cache = load_bank_cache(options)
+	instrument_set = bank_cache.keys()
+	for ifo, list_of_svd_caches in bank_cache.items():
+		bin_offset = 0
+		for j, svd_caches in enumerate(list_of_svd_caches):
+			for i, individual_svd_cache in enumerate(ce.path for ce in map(CacheEntry, open(svd_caches))):
+				svd_dtdphi_map["%04d" % (i+bin_offset)] = options.dtdphi_file[j]
+
+	return bgbin_lloid_map, lloid_diststats, svd_dtdphi_map, instrument_set, boundary_seg
+
+
 def get_threshold_values(template_mchirp_dict, bgbin_indices, svd_bank_strings, options):
 	"""Calculate the appropriate ht-gate-threshold values according to the scale given
 	"""
@@ -1173,7 +1215,7 @@ def sim_tag_from_inj_file(injections):
 	return injections.replace('.xml', '').replace('.gz', '').replace('-','_')
 
 
-def get_bank_params(options, verbose = False):
+def load_bank_cache(options):
 	bank_cache = {}
 	for bank_cache_str in options.bank_cache:
 		for c in bank_cache_str.split(','):
@@ -1181,6 +1223,12 @@ def get_bank_params(options, verbose = False):
 			cache = c.replace(ifo+"=","")
 			bank_cache.setdefault(ifo, []).append(cache)
 
+	return bank_cache
+
+
+def get_bank_params(options, verbose = False):
+	bank_cache = load_bank_cache(options)
+
 	max_time = 0
 	template_mchirp_dict = {}
 	for n, cache in enumerate(bank_cache.values()[0]):
-- 
GitLab