From 65ee1c9111d4c7cd9dad289dd260a116a2468af1 Mon Sep 17 00:00:00 2001
From: Leo Tsukada <leo.tsukada@ligo.org>
Date: Mon, 3 Apr 2023 13:58:58 -0700
Subject: [PATCH] python/dags/layers/inspiral.py : move functions and some
 parts to gstlal_inspiral_set_bin_option_svdmanifest

---
 .../python/dags/layers/inspiral.py            | 100 +++---------------
 1 file changed, 17 insertions(+), 83 deletions(-)

diff --git a/gstlal-inspiral/python/dags/layers/inspiral.py b/gstlal-inspiral/python/dags/layers/inspiral.py
index df16babf17..c8e2a609a3 100644
--- a/gstlal-inspiral/python/dags/layers/inspiral.py
+++ b/gstlal-inspiral/python/dags/layers/inspiral.py
@@ -20,11 +20,9 @@ import itertools
 import os
 import math
 import sys
-from typing import Iterable
 
 import numpy
 
-from lal import rate
 from lal.utils import CacheEntry
 from ligo.segments import segment
 
@@ -99,9 +97,6 @@ def svd_bank_layer(config, dag, median_psd_cache, split_bank_cache=None):
 		transfer_files=config.condor.transfer_files,
 	)
 
-	# set up autocorrelation mapping
-	mchirp_to_ac_length = autocorrelation_length_map(config.svd.autocorrelation_length)
-
 	svd_cache = DataCache.generate(
 		DataType.SVD_BANK,
 		config.ifos,
@@ -120,8 +115,6 @@ def svd_bank_layer(config, dag, median_psd_cache, split_bank_cache=None):
 		else:
 			svd_config = config.svd
 
-		bin_mchirp = config.svd.stats.bins[svd_bin]["mean_mchirp"]
-
 		arguments = [
 			Option("instrument-override", ifo),
 			Option("flow", svd_config.f_low),
@@ -130,7 +123,7 @@ def svd_bank_layer(config, dag, median_psd_cache, split_bank_cache=None):
 			Option("samples-max-256", svd_config.samples_max_256),
 			Option("samples-max", svd_config.samples_max),
 			Option("svd-tolerance", svd_config.tolerance),
-			Option("autocorrelation-length", mchirp_to_ac_length(bin_mchirp)),
+			Option("autocorrelation-length", config.svd.stats.bins[svd_bin]["ac_length"]),
 		]
 		if "max_duration" in svd_config:
 			arguments.append(Option("max-duration", svd_config.max_duration))
@@ -275,7 +268,7 @@ def filter_layer(config, dag, ref_psd_cache, svd_bank_cache):
 		for trigger_group in triggers.chunked(num_per_group):
 			svd_bins = trigger_group.groupby("bin").keys()
 
-			thresholds = [calc_gate_threshold(config, svd_bin) for svd_bin in svd_bins]
+			thresholds = [config.svd.stats.bins[svd_bin]["ht_gate_threshold"] for svd_bin in svd_bins]
 			these_opts = [Option("ht-gate-threshold", thresholds), *filter_opts]
 
 			svd_bank_files = dagutil.flatten(
@@ -395,7 +388,7 @@ def filter_injections_layer(config, dag, ref_psd_cache, svd_bank_cache):
 		for trigger_group in triggers.chunked(num_per_group):
 			svd_bins = trigger_group.groupby("bin").keys()
 
-			thresholds = [calc_gate_threshold(config, svd_bin) for svd_bin in svd_bins]
+			thresholds = [config.svd.stats.bins[svd_bin]["ht_gate_threshold"] for svd_bin in svd_bins]
 			these_opts = [Option("ht-gate-threshold", thresholds), *filter_opts]
 
 			svd_bank_files = dagutil.flatten(
@@ -527,6 +520,7 @@ def create_prior_layer(config, dag, svd_bank_cache, median_psd_cache, dist_stat_
 					]
 		if prior_df == "bandwidth":
 			inputs += [Option("psd-xml", median_psd_cache.files)]
+
 		layer += Node(
 			arguments = [
 				Option("df", prior_df),
@@ -1694,7 +1688,7 @@ def filter_online_layer(config, dag, svd_bank_cache, dist_stat_cache, zerolag_pd
 		job_tag = f"{int(svd_bin):04d}_noninj"
 		filter_opts = [
 			Option("job-tag", job_tag),
-			Option("ht-gate-threshold", calc_gate_threshold(config, svd_bin)),
+			Option("ht-gate-threshold", config.svd.stats.bins[svd_bin]["ht_gate_threshold"]),
 		]
 		filter_opts.extend(common_opts)
 		filter_opts.extend(datasource_opts)
@@ -1826,7 +1820,7 @@ def filter_injections_online_layer(config, dag, svd_bank_cache, dist_stat_cache,
 	zerolag_pdfs = zerolag_pdf_cache.groupby("bin")
 	for svd_bin, svd_banks in svd_bank_cache.groupby("bin").items():
 		filter_opts = [
-			Option("ht-gate-threshold", calc_gate_threshold(config, svd_bin)),
+			Option("ht-gate-threshold", config.svd.stats.bins[svd_bin]["ht_gate_threshold"]),
 		]
 		filter_opts.extend(common_opts)
 		filter_opts.extend(datasource_opts)
@@ -2242,79 +2236,19 @@ def add_ranking_stat_file_options(config, svd_bin=None, transfer_only=False):
 	else:
 		kwargs = {}
 
-	inputs = [Option("mass-model-file", config.prior.mass_model, **kwargs)]
-
-	if config.prior.idq_timeseries:
-		inputs.append(Option("idq-file", config.prior.idq_timeseries, **kwargs))
-
-	if config.prior.dtdphi:
-		if isinstance(config.prior.dtdphi, Mapping):
-			if svd_bin is None:
-				dtdphi_files = list(config.prior.dtdphi.values())
-				inputs.append(Option("dtdphi-file", dtdphi_files, **kwargs))
-			else:
-				sub_bank = config.svd.stats.bins[svd_bin]["bank_name"]
-				inputs.append(Option("dtdphi-file", config.prior.dtdphi[sub_bank], **kwargs))
-		else:
-			inputs.append(Option("dtdphi-file", config.prior.dtdphi, **kwargs))
-
-	return inputs
-
-
-def calc_gate_threshold(config, svd_bin, aggregate="max"):
-	"""
-	Given a configuration, svd bin and aggregate, this calculates
-	the h(t) gate threshold used for a given svd bin.
-	"""
-	if isinstance(config.filter.ht_gate_threshold, str):
-		bank_mchirp = config.svd.stats["bins"][svd_bin][f"{aggregate}_mchirp"]
-		min_mchirp, min_threshold, max_mchirp, max_threshold = [
-			float(y) for x in config.filter.ht_gate_threshold.split("-") for y in x.split(":")
-		]
-		gate_mchirp_ratio = (max_threshold - min_threshold) / (max_mchirp - min_mchirp)
-		return round(gate_mchirp_ratio * (bank_mchirp - min_mchirp) + min_threshold, 3)
-	else: # uniform threshold
-		return config.filter.ht_gate_threshold
-
-
-def autocorrelation_length_map(ac_length_range):
-	"""
-	Given autocorrelation length ranges (e.g. 0:15:701)
-	or a single autocorrelation value, returns a function that
-	maps a given chirp mass to an autocorrelation length.
-	"""
-	if isinstance(ac_length_range, str):
-		ac_length_range = [ac_length_range]
-
-	# handle case with AC length ranges
-	if isinstance(ac_length_range, Iterable):
-		ac_lengths = []
-		min_mchirps = []
-		max_mchirps = []
-		for this_range in ac_length_range:
-			min_mchirp, max_mchirp, ac_length = this_range.split(":")
-			min_mchirps.append(float(min_mchirp))
-			max_mchirps.append(float(max_mchirp))
-			ac_lengths.append(int(ac_length))
-
-		# sanity check inputs
-		for bound1, bound2 in zip(min_mchirps[1:], max_mchirps[:-1]):
-			assert bound1 == bound2, "gaps not allowed in autocorrelation length ranges"
-
-		# convert to binning
-		bins = rate.IrregularBins([min_mchirps[0]] + max_mchirps)
-
-	# handle single value case
+	if svd_bin is None:
+		inputs = [Option("mass-model-file%d" % i, mass_model_file, **kwargs) for i, mass_model_file in enumerate(set(stats_bin["mass_model_file"] for stats_bin in config.svd.stats.bins.values()))]
+		if config.prior.idq_timeseries:
+			inputs.extend(Option("idq-file%d" % i, idq_file, **kwargs) for i, idq_file in enumerate(set(stats_bin["idq_file"] for stats_bin in config.svd.stats.bins.values())))
+		inputs.extend(Option("dtdphi-file%d" % i, dtdphi_file, **kwargs) for i, dtdphi_file in enumerate(set(stats_bin["dtdphi_file"] for stats_bin in config.svd.stats.bins.values())))
 	else:
-		ac_lengths = [ac_length_range]
-		bins = rate.IrregularBins([0., numpy.inf])
+		stats_bin = config.svd.stats.bins[svd_bin]
+		inputs = [Option("mass-model-file", stats_bin["mass_model_file"], **kwargs)]
+		if config.prior.idq_timeseries:
+			inputs.extend([Option("idq-file", stats_bin["idq_file"], **kwargs)])
+		inputs.extend([Option("dtdphi-file", stats_bin["dtdphi_file"], **kwargs)])
 
-	# create mapping
-	def mchirp_to_ac_length(mchirp):
-		idx = bins[mchirp]
-		return ac_lengths[idx]
-
-	return mchirp_to_ac_length
+	return inputs
 
 
 def mchirp_range_to_bins(min_mchirp, max_mchirp, svd_metadata):
-- 
GitLab