From 3001751f5347ff50d187609ac5de77c92e0305ed Mon Sep 17 00:00:00 2001
From: Chad Hanna <crh184@psu.edu>
Date: Tue, 15 Aug 2017 06:39:18 -0400
Subject: [PATCH] treebank: reuse the metric more often

---
 gstlal-ugly/bin/gstlal_inspiral_treebank |  2 +-
 gstlal-ugly/python/tree.py               | 18 +++++++++---------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/gstlal-ugly/bin/gstlal_inspiral_treebank b/gstlal-ugly/bin/gstlal_inspiral_treebank
index ac98b580a6..d1984ee7ea 100755
--- a/gstlal-ugly/bin/gstlal_inspiral_treebank
+++ b/gstlal-ugly/bin/gstlal_inspiral_treebank
@@ -245,7 +245,7 @@ for n, c in enumerate(nodes):
 			row.mass1 = row.mass2
 			row.mass2 = mass1
 		row.process_id = process.process_id
-		if coord_limits[0][0] * 0.98 <= row.mass1 <= coord_limits[0][1] * 1.1 and coord_limits[1][0] * 0.98 <= row.mass2 <= coord_limits[1][1] * 1.1 and (row.mass1+row.mass2 < args.max_mtotal) and (row.mass1 / row.mass2 < args.max_q):
+		if coord_limits[0][0] * 0.97 <= row.mass1 <= coord_limits[0][1] * 1.1 and coord_limits[1][0] * 0.97 <= row.mass2 <= coord_limits[1][1] * 1.1 and (row.mass1+row.mass2 < args.max_mtotal) and (row.mass1 / row.mass2 < args.max_q):
 		#if (row.mass1+row.mass2 < args.max_mtotal) and (row.mass1 / row.mass2 < args.max_q):
 			tbl.append(row)
 			previous_tiles.append(t)
diff --git a/gstlal-ugly/python/tree.py b/gstlal-ugly/python/tree.py
index 685f16bb42..2fa185babb 100644
--- a/gstlal-ugly/python/tree.py
+++ b/gstlal-ugly/python/tree.py
@@ -68,7 +68,7 @@ def packing_density(n):
 	# this packing density puts two in a cell, we split if there is more
 	# than this expected in a cell
 	# From: http://mathworld.wolfram.com/HyperspherePacking.html
-	prefactor = 1./2**.5
+	prefactor = 1.0
 	if n==1:
 		return prefactor
 	if n==2:
@@ -308,23 +308,23 @@ class Node(object):
 			numtmps = self.cube.num_templates(mismatch)
 
 
-			#metric_diff = self.cube.metric_tensor - self.sibling.cube.metric_tensor
-			#metric_diff = numpy.linalg.norm(metric_diff) / numpy.linalg.norm(self.cube.metric_tensor)**.5 / numpy.linalg.norm(self.sibling.cube.metric_tensor)**.5
-			#metric_diff2 = self.cube.metric_tensor - self.parent.cube.metric_tensor
-			#metric_diff2 = numpy.linalg.norm(metric_diff2) / numpy.linalg.norm(self.cube.metric_tensor)**.5 / numpy.linalg.norm(self.parent.cube.metric_tensor)**.5
-			#metric_diff = max(metric_diff, metric_diff2)
+			metric_diff = self.cube.metric_tensor - self.sibling.cube.metric_tensor
+			metric_diff = numpy.linalg.norm(metric_diff) / numpy.linalg.norm(self.cube.metric_tensor)**.5 / numpy.linalg.norm(self.sibling.cube.metric_tensor)**.5
+			metric_diff2 = self.cube.metric_tensor - self.parent.cube.metric_tensor
+			metric_diff2 = numpy.linalg.norm(metric_diff2) / numpy.linalg.norm(self.cube.metric_tensor)**.5 / numpy.linalg.norm(self.parent.cube.metric_tensor)**.5
+			metric_diff = max(metric_diff, metric_diff2)
 
 			#metric_cond = (not self.cube.metric_is_valid) or (metric_diff > metric_tol) or (sib_numtmps + numtmps > (1.0 + metric_tol) * par_numtmps) or (numtmps > (1.0 + metric_tol) * sib_numtmps)
 
-			metric_diff = max(abs(self.sibling.cube.eigv - self.cube.eigv) / (self.sibling.cube.eigv + self.cube.eigv) / 2.)
+			#metric_diff = max(abs(self.sibling.cube.eigv - self.cube.eigv) / (self.sibling.cube.eigv + self.cube.eigv) / 2.)
 			# take the bigger of self, sibling and parent
 			numtmps = max(max(numtmps, par_numtmps/2.0), sib_numtmps)# * aspect_factor
 
 		#if self.cube.constraint_func(self.cube.vertices + [self.cube.center]) and ((numtmps >= split_num_templates) or (numtmps >= split_num_templates/2.0 and metric_cond)):
-		if self.cube.constraint_func(self.cube.vertices + [self.cube.center]) and ((numtmps >= split_num_templates) or (False > mismatch and numtmps > split_num_templates/2.0)) or bifurcation < 2:
+		if self.cube.constraint_func(self.cube.vertices + [self.cube.center]) and ((numtmps >= split_num_templates) or (metric_diff > 0.05 and numtmps > split_num_templates/2.0**.5)) or bifurcation < 2:
 			bifurcation += 1
 			#if False:# (self.cube.num_templates(0.02) < len(size)**2/2. or numtmps < 2 * split_num_templates) and metric_diff < 0.1:
-			if (numtmps < 2**len(size) * self.cube.num_templates(0.003)) and metric_diff < 0.003:
+			if metric_diff <= 0.05:
 			#if self.cube.metric_is_valid:# and aspect_factor <= 1.0:
 			#if not metric_cond:
 			#if metric_diff <= metric_tol and self.cube.metric_is_valid:# and aspect_factor <= 1.0:
-- 
GitLab