From 4820ecc0d86c618fbc25b52136d7c946419353b0 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Wed, 28 Apr 2021 14:51:29 +0000
Subject: [PATCH] Make building distance marginalization lookup table faster

---
 bilby/gw/likelihood.py | 35 ++++++++++++++++++++---------------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index e0558c04d..9e704fd62 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -912,24 +912,29 @@ class GravitationalWaveTransient(Likelihood):
 
     def _create_lookup_table(self):
         """ Make the lookup table """
+        from tqdm.auto import tqdm
         logger.info('Building lookup table for distance marginalisation.')
 
         self._dist_margd_loglikelihood_array = np.zeros((400, 800))
-        for ii, optimal_snr_squared_ref in enumerate(self._optimal_snr_squared_ref_array):
-            optimal_snr_squared_array = (
-                optimal_snr_squared_ref * self._ref_dist ** 2. /
-                self._distance_array ** 2)
-            for jj, d_inner_h_ref in enumerate(self._d_inner_h_ref_array):
-                d_inner_h_array = (
-                    d_inner_h_ref * self._ref_dist / self._distance_array)
-                if self.phase_marginalization:
-                    d_inner_h_array =\
-                        self._bessel_function_interped(abs(d_inner_h_array))
-                self._dist_margd_loglikelihood_array[ii][jj] = \
-                    logsumexp(d_inner_h_array - optimal_snr_squared_array / 2,
-                              b=self.distance_prior_array * self._delta_distance)
-        log_norm = logsumexp(0. / self._distance_array,
-                             b=self.distance_prior_array * self._delta_distance)
+        scaling = self._ref_dist / self._distance_array
+        d_inner_h_array_full = np.outer(self._d_inner_h_ref_array, scaling)
+        h_inner_h_array_full = np.outer(self._optimal_snr_squared_ref_array, scaling ** 2)
+        if self.phase_marginalization:
+            d_inner_h_array_full = self._bessel_function_interped(abs(
+                d_inner_h_array_full
+            ))
+        prior_term = self.distance_prior_array * self._delta_distance
+        for ii, optimal_snr_squared_array in tqdm(
+            enumerate(h_inner_h_array_full), total=len(self._optimal_snr_squared_ref_array)
+        ):
+            for jj, d_inner_h_array in enumerate(d_inner_h_array_full):
+                self._dist_margd_loglikelihood_array[ii][jj] = logsumexp(
+                    d_inner_h_array - optimal_snr_squared_array / 2,
+                    b=prior_term
+                )
+        log_norm = logsumexp(
+            0 / self._distance_array, b=self.distance_prior_array * self._delta_distance
+        )
         self._dist_margd_loglikelihood_array -= log_norm
         self.cache_lookup_table()
 
-- 
GitLab