Skip to content
Snippets Groups Projects

Make building distance marginalization lookup table faster

Merged Colm Talbot requested to merge accelerate-distance-marginalization-lookup into master
1 file
+ 20
15
Compare changes
  • Side-by-side
  • Inline
+ 20
15
@@ -912,24 +912,29 @@ class GravitationalWaveTransient(Likelihood):
def _create_lookup_table(self):
""" Make the lookup table """
from tqdm.auto import tqdm
logger.info('Building lookup table for distance marginalisation.')
self._dist_margd_loglikelihood_array = np.zeros((400, 800))
for ii, optimal_snr_squared_ref in enumerate(self._optimal_snr_squared_ref_array):
optimal_snr_squared_array = (
optimal_snr_squared_ref * self._ref_dist ** 2. /
self._distance_array ** 2)
for jj, d_inner_h_ref in enumerate(self._d_inner_h_ref_array):
d_inner_h_array = (
d_inner_h_ref * self._ref_dist / self._distance_array)
if self.phase_marginalization:
d_inner_h_array =\
self._bessel_function_interped(abs(d_inner_h_array))
self._dist_margd_loglikelihood_array[ii][jj] = \
logsumexp(d_inner_h_array - optimal_snr_squared_array / 2,
b=self.distance_prior_array * self._delta_distance)
log_norm = logsumexp(0. / self._distance_array,
b=self.distance_prior_array * self._delta_distance)
scaling = self._ref_dist / self._distance_array
d_inner_h_array_full = np.outer(self._d_inner_h_ref_array, scaling)
h_inner_h_array_full = np.outer(self._optimal_snr_squared_ref_array, scaling ** 2)
if self.phase_marginalization:
d_inner_h_array_full = self._bessel_function_interped(abs(
d_inner_h_array_full
))
prior_term = self.distance_prior_array * self._delta_distance
for ii, optimal_snr_squared_array in tqdm(
enumerate(h_inner_h_array_full), total=len(self._optimal_snr_squared_ref_array)
):
for jj, d_inner_h_array in enumerate(d_inner_h_array_full):
self._dist_margd_loglikelihood_array[ii][jj] = logsumexp(
d_inner_h_array - optimal_snr_squared_array / 2,
b=prior_term
)
log_norm = logsumexp(
0 / self._distance_array, b=self.distance_prior_array * self._delta_distance
)
self._dist_margd_loglikelihood_array -= log_norm
self.cache_lookup_table()
Loading