From 258b8206005e3c4a64b650be5529de0970804d13 Mon Sep 17 00:00:00 2001
From: Chad Hanna <chad.hanna@ligo.org>
Date: Mon, 31 Dec 2018 12:31:20 -0500
Subject: [PATCH] gstlal_inspiral, gstlal_inspiral_create_prior_diststats,
 far.py, inspiral_lr.py, svd_bank.py: add support for per template corrections
 to the horizon distance used in the likelihood ratio using the sigma values
 computed during the SVD decomposition

---
 gstlal-inspiral/bin/gstlal_inspiral           |  8 +++--
 .../gstlal_inspiral_create_prior_diststats    |  8 +++--
 gstlal-inspiral/python/far.py                 |  6 ++--
 gstlal-inspiral/python/stats/inspiral_lr.py   | 25 +++++++++++++-
 gstlal-inspiral/python/svd_bank.py            | 34 +++++++++++--------
 5 files changed, 58 insertions(+), 23 deletions(-)

diff --git a/gstlal-inspiral/bin/gstlal_inspiral b/gstlal-inspiral/bin/gstlal_inspiral
index 516c9f5977..84631b6966 100755
--- a/gstlal-inspiral/bin/gstlal_inspiral
+++ b/gstlal-inspiral/bin/gstlal_inspiral
@@ -691,8 +691,10 @@ for output_file_number, (svd_bank_url_dict, output_url, ranking_stat_output_url,
 	# assume all instruments have the same templates, just extract them
 	# from one of the instruments at random
 	sngl_inspiral_table = banks.values()[0][0].sngl_inspiral_table.copy()
+	horizon_factors = {}
 	for bank in banks.values()[0]:
 		sngl_inspiral_table.extend(bank.sngl_inspiral_table)
+		horizon_factors.update(bank.horizon_factors)
 	@bottle.route("/template_bank.xml.gz")
 	def get_template_bank_xml(sngl_inspiral_table = sngl_inspiral_table):
 		xmldoc = ligolw.Document()
@@ -704,6 +706,7 @@ for output_file_number, (svd_bank_url_dict, output_url, ranking_stat_output_url,
 		output.close()
 		return outstr
 	template_ids = frozenset(row.template_id for row in sngl_inspiral_table)
+	assert set(horizon_factors) == template_ids, "horizon factors are not assigned for each template id"
 	assert len(template_ids) == len(sngl_inspiral_table), "template IDs are not unique within the template bank"
 	@bottle.route("/template_ids.txt")
 	def get_template_ids(template_ids = sorted(template_ids)):
@@ -783,8 +786,9 @@ for output_file_number, (svd_bank_url_dict, output_url, ranking_stat_output_url,
 				rankingstat.template_ids = template_ids
 			elif rankingstat.template_ids != template_ids:
 				raise ValueError("\"%s\" is for the wrong templates" % options.ranking_stat_input)
+			rankingstat.numerator.set_horizon_factors(horizon_factors)
 	if rankingstat is None:
-		rankingstat = far.RankingStat(template_ids = template_ids, instruments = all_instruments, delta_t = options.coincidence_threshold, min_instruments = options.min_instruments)
+		rankingstat = far.RankingStat(template_ids = template_ids, instruments = all_instruments, delta_t = options.coincidence_threshold, min_instruments = options.min_instruments, horizon_factors = horizon_factors)
 		rankingstat.numerator.add_signal_model(df = 40)
 
 
@@ -811,7 +815,7 @@ for output_file_number, (svd_bank_url_dict, output_url, ranking_stat_output_url,
 			verbose = options.verbose
 		),
 		rankingstat = rankingstat,
-		horizon_distance_func = svd_bank.make_horizon_distance_func(banks),
+		horizon_distance_func = banks.values()[0][0].horizon_distance_func,# they should all be the same
 		gracedbwrapper = inspiral.GracedBWrapper(
 			instruments = rankingstat.instruments,
 			far_threshold = options.gracedb_far_threshold,
diff --git a/gstlal-inspiral/bin/gstlal_inspiral_create_prior_diststats b/gstlal-inspiral/bin/gstlal_inspiral_create_prior_diststats
index 625874e8f2..42476d3b63 100755
--- a/gstlal-inspiral/bin/gstlal_inspiral_create_prior_diststats
+++ b/gstlal-inspiral/bin/gstlal_inspiral_create_prior_diststats
@@ -94,10 +94,12 @@ def parse_command_line():
 		raise ValueError("unrecognized arguments after options: %s" % " ".join(filenames))
 
 	template_ids = []
+	horizon_factors = {}
 	for n, bank in enumerate(svd_bank.read_banks(options.svd_file, contenthandler = LIGOLWContentHandler, verbose = options.verbose)):
 		template_ids += [row.template_id for row in bank.sngl_inspiral_table]
+		horizon_factors.update(bank.horizon_factors)
 
-	return options, process_params, filenames, template_ids
+	return options, process_params, filenames, template_ids, horizon_factors
 
 
 #
@@ -114,7 +116,7 @@ def parse_command_line():
 #
 
 
-options, process_params, filenames, template_ids = parse_command_line()
+options, process_params, filenames, template_ids, horizon_factors = parse_command_line()
 
 
 #
@@ -132,7 +134,7 @@ process = ligolw_process.register_to_xmldoc(xmldoc, u"gstlal_inspiral_create_pri
 #
 
 
-rankingstat = far.RankingStat(template_ids = template_ids, instruments = options.instrument, delta_t = options.coincidence_threshold, min_instruments = options.min_instruments, population_model_file = options.mass_model_file)
+rankingstat = far.RankingStat(template_ids = template_ids, instruments = options.instrument, delta_t = options.coincidence_threshold, min_instruments = options.min_instruments, population_model_file = options.mass_model_file, horizon_factors = horizon_factors)
 
 if options.background_prior > 0:
 	rankingstat.denominator.add_noise_model(number_of_events = options.background_prior)
diff --git a/gstlal-inspiral/python/far.py b/gstlal-inspiral/python/far.py
index b2f883d948..391accb21f 100644
--- a/gstlal-inspiral/python/far.py
+++ b/gstlal-inspiral/python/far.py
@@ -129,10 +129,10 @@ class RankingStat(snglcoinc.LnLikelihoodRatioMixin):
 		pass
 
 	# network SNR threshold
-	network_snrsq_threshold = 6.0**2.
+	network_snrsq_threshold = 5.0**2.
 
-	def __init__(self, template_ids = None, instruments = frozenset(("H1", "L1", "V1")), population_model_file = None, min_instruments = 1, delta_t = 0.005):
-		self.numerator = inspiral_lr.LnSignalDensity(template_ids = template_ids, instruments = instruments, delta_t = delta_t, population_model_file = population_model_file, min_instruments = min_instruments)
+	def __init__(self, template_ids = None, instruments = frozenset(("H1", "L1", "V1")), population_model_file = None, min_instruments = 1, delta_t = 0.005, horizon_factors = None):
+		self.numerator = inspiral_lr.LnSignalDensity(template_ids = template_ids, instruments = instruments, delta_t = delta_t, population_model_file = population_model_file, min_instruments = min_instruments, horizon_factors = horizon_factors)
 		self.denominator = inspiral_lr.LnNoiseDensity(template_ids = template_ids, instruments = instruments, delta_t = delta_t, min_instruments = min_instruments)
 		self.zerolag = inspiral_lr.LnLRDensity(template_ids = template_ids, instruments = instruments, delta_t = delta_t, min_instruments = min_instruments)
 
diff --git a/gstlal-inspiral/python/stats/inspiral_lr.py b/gstlal-inspiral/python/stats/inspiral_lr.py
index b9b6caaa8b..47096aaa70 100644
--- a/gstlal-inspiral/python/stats/inspiral_lr.py
+++ b/gstlal-inspiral/python/stats/inspiral_lr.py
@@ -39,6 +39,7 @@ from scipy import interpolate
 from scipy import stats
 import sys
 import warnings
+import json
 
 
 from glue.ligolw import ligolw
@@ -292,6 +293,7 @@ class LnLRDensity(snglcoinc.LnLRDensity):
 class LnSignalDensity(LnLRDensity):
 	def __init__(self, *args, **kwargs):
 		population_model_file = kwargs.pop("population_model_file", None)
+		self.horizon_factors = kwargs.pop("horizon_factors", None)
 		super(LnSignalDensity, self).__init__(*args, **kwargs)
 
 		# install SNR, chi^2 PDF (one for all instruments)
@@ -309,6 +311,9 @@ class LnSignalDensity(LnLRDensity):
 		self.population_model = inspiral_intrinsics.SourcePopulationModel(self.template_ids, filename = self.population_model_file)
 		self.InspiralExtrinsics = inspiral_extrinsics.InspiralExtrinsics(self.min_instruments)
 
+	def set_horizon_factors(self, horizon_factors):
+		self.horizon_factors = horizon_factors
+
 	def __call__(self, segments, snrs, chi2s_over_snr2s, phase, dt, template_id):
 		assert frozenset(segments) == self.instruments
 		if len(snrs) < self.min_instruments:
@@ -341,7 +346,12 @@ class LnSignalDensity(LnLRDensity):
 		horizon = sorted(horizons.values())[-self.min_instruments] / TYPICAL_HORIZON_DISTANCE
 		if not horizon:
 			return NegInf
-		lnP = 3. * math.log(horizon) + math.log(len(self.template_ids))
+		# horizon factors adjusts the computed horizon factor according
+		# to the sigma values computed at the time of the SVD. This
+		# gives a good approximation to the horizon distance for each
+		# template without computing them each explicitly. Only one
+		# template has its horizon calculated explicitly.
+		lnP = 3. * math.log(horizon * self.horizon_factors[template_id]) + math.log(len(self.template_ids))
 
 		# Add P(instruments | horizon distances)
 		try:
@@ -378,6 +388,10 @@ class LnSignalDensity(LnLRDensity):
 			raise ValueError("incompatible mass model file names")
 		if self.population_model_file is None and other.population_model_file is not None:
 			self.population_model_file = other.population_model_file
+		if self.horizon_factors is not None and other.horizon_factors is not None and other.horizon_factors != self.horizon_factors:
+			raise ValueError("incompatible horizon_factors")
+		if self.horizon_factors is None and other.horizon_factors is not None:
+			self.horizon_factors = other.horizon_factors
 
 		return self
 
@@ -391,6 +405,7 @@ class LnSignalDensity(LnLRDensity):
 		# okay to use references because read-only data
 		new.population_model = self.population_model
 		new.InspiralExtrinsics = self.InspiralExtrinsics
+		new.horizon_factors = self.horizon_factors
 		return new
 
 	def local_mean_horizon_distance(self, gps, window = segments.segment(-32., +2.)):
@@ -534,6 +549,7 @@ class LnSignalDensity(LnLRDensity):
 		xml = super(LnSignalDensity, self).to_xml(name)
 		xml.appendChild(self.horizon_history.to_xml(u"horizon_history"))
 		xml.appendChild(ligolw_param.Param.from_pyvalue(u"population_model_file", self.population_model_file))
+		xml.appendChild(ligolw_param.Param.from_pyvalue(u"horizon_factors", json.dumps(self.horizon_factors) if self.horizon_factors is not None else None))
 		return xml
 
 	@classmethod
@@ -542,6 +558,13 @@ class LnSignalDensity(LnLRDensity):
 		self = super(LnSignalDensity, cls).from_xml(xml, name)
 		self.horizon_history = horizonhistory.HorizonHistories.from_xml(xml, u"horizon_history")
 		self.population_model_file = ligolw_param.get_pyvalue(xml, u"population_model_file")
+		self.horizon_factors = ligolw_param.get_pyvalue(xml, u"horizon_factors")
+		if self.horizon_factors is not None:
+			# FIXME, how do we properly decode the json, I assume something in ligolw can do this?
+			self.horizon_factors = self.horizon_factors.replace("\\","").replace('\\"','"')
+			self.horizon_factors = json.loads(self.horizon_factors)
+			self.horizon_factors = dict((int(k), v) for k, v in self.horizon_factors.items())
+			assert set(self.template_ids) == set(self.horizon_factors)
 		self.population_model = inspiral_intrinsics.SourcePopulationModel(self.template_ids, filename = self.population_model_file)
 		self.InspiralExtrinsics = inspiral_extrinsics.InspiralExtrinsics(self.min_instruments)
 		return self
diff --git a/gstlal-inspiral/python/svd_bank.py b/gstlal-inspiral/python/svd_bank.py
index 2887c09775..0c39adce8b 100644
--- a/gstlal-inspiral/python/svd_bank.py
+++ b/gstlal-inspiral/python/svd_bank.py
@@ -388,6 +388,9 @@ def read_banks(filename, contenthandler, verbose = False):
 		bank.autocorrelation_mask = ligolw_array.get_array(root, 'autocorrelation_mask').array
 		bank.sigmasq = ligolw_array.get_array(root, 'sigmasq').array
 
+		# prepare the horizon distance factors
+		bank.horizon_factors = dict((row.template_id, sigmasq**.5) for row, sigmasq in zip(bank.sngl_inspiral_table, bank.sigmasq))
+
 		# attach a reference to the psd
 		bank.processed_psd = processed_psd
 
@@ -418,6 +421,15 @@ def read_banks(filename, contenthandler, verbose = False):
 			bank.bank_fragments.append(frag)
 
 		banks.append(bank)
+	template_id, func = horizon_distance_func(banks)
+	horizon_norm = None
+	for bank in banks:
+		if template_id in bank.horizon_factors:
+			assert horizon_norm is None
+			horizon_norm = bank.horizon_factors[template_id]
+	for bank in banks:
+		bank.horizon_distance_func = func
+		bank.horizon_factors = dict((tid, f / horizon_norm) for (tid, f) in bank.horizon_factors.items())
 	xmldoc.unlink()
 	return banks
 
@@ -449,27 +461,21 @@ def svdbank_templates_mapping(filenames, contenthandler, verbose = False):
 	return mapping
 
 
-def make_horizon_distance_func(banks):
+def horizon_distance_func(banks):
 	"""
 	Takes a dictionary of objects returned by read_banks keyed by instrument
 	"""
 	# span is [15 Hz, 0.85 * Nyquist frequency]
 	# find the Nyquist frequency for the PSD to be used for each
 	# instrument.  require them to all match
-	sngl_inspiral_table = banks.values()[0][0].sngl_inspiral_table.copy()
-	for bank in banks.values()[0]:
-		sngl_inspiral_table.extend(bank.sngl_inspiral_table)
-	nyquists = set(max(rate for bank in banklist for rate in bank.get_rates()) // 2 for instrument, banklist in banks.items())
+	nyquists = set((max(bank.get_rates()) for bank in banks))
 	assert len(nyquists) == 1, "all banks must have the same Nyquist frequency to define a consistent horizon distance function (got %s)" % ", ".join("%g" % rate for rate in sorted(nyquists))
-	# assume default 32 s PSD.  this is not required to be correct, but
+	# assume default 4 s PSD.  this is not required to be correct, but
 	# for best accuracy it should not be larger than the true value and
 	# for best performance it should not be smaller than the true
 	# value.
-	deltaF = 1. / 32.
-	# FIXME (from Chad) What is the 5/3 for???
-	# pick (m1, m2) from the median template ranked by Mchirp^(5/3)
-	# to provide the canonical waveform model.  See Maggiore equation
-	# (4.3).
-	assert len(sngl_inspiral_table) > 0, "no templates:  must have templates to define horizon distance function"
-	median_row = sorted(sngl_inspiral_table, key = lambda row: row.mchirp**(5./3.))[len(sngl_inspiral_table) // 2]
-	return reference_psd.HorizonDistance(15.0, 0.85 * max(nyquists), deltaF, median_row.mass1, median_row.mass2)
+	deltaF = 1. / 4.
+	# use the minimum template id as the cannonical horizon function
+	template_id, m1, m2 = min((row.template_id, row.mass1, row.mass2) for bank in banks for row in bank.sngl_inspiral_table)
+
+	return template_id, reference_psd.HorizonDistance(15.0, 0.85 * max(nyquists), deltaF, m1, m2)
-- 
GitLab