From 485af24774dedb27bd5623a2318c60d87906a424 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Mon, 1 Jul 2019 00:06:16 -0500
Subject: [PATCH] Fix reconstruction

---
 bilby/gw/likelihood.py | 70 +++++++++++++++++++++++-------------------
 1 file changed, 39 insertions(+), 31 deletions(-)

diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index c0eba1b1..ab78a15b 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -18,7 +18,8 @@ from scipy.special import i0e
 from ..core import likelihood
 from ..core.utils import (
     logger, UnsortedInterp2d, BilbyJsonEncoder, decode_bilby_json,
-    create_frequency_series, speed_of_light, radius_of_earth)
+    create_frequency_series, create_time_series, speed_of_light,
+    radius_of_earth)
 from ..core.prior import Interped, Prior, Uniform
 from .detector import InterferometerList
 from .prior import BBHPriorDict
@@ -115,7 +116,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
             priors['geocent_time'] = float(self.interferometers.start_time)
             if self.jitter_time:
                 priors['time_jitter'] = Uniform(
-                    minimum=- self._delta_tc / 2, maximum=self._delta_tc / 2)
+                    minimum=- self._delta_tc / 2, maximum=self._delta_tc / 2,
+                    boundary='periodic')
         elif self.jitter_time:
             logger.info(
                 "Time jittering requested with non-time-marginalised "
@@ -267,15 +269,11 @@ class GravitationalWaveTransient(likelihood.Likelihood):
                 d_inner_h_tc_array += per_detector_snr.d_inner_h_squared_tc_array
 
         if self.time_marginalization:
-            if self.jitter_time:
-                times = self._times + self.parameters['time_jitter']
-                self.parameters['geocent_time'] -= self.parameters['time_jitter']
-            else:
-                times = self._times
-            self.time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc
             log_l = self.time_marginalized_likelihood(
                 d_inner_h_tc_array=d_inner_h_tc_array,
                 h_inner_h=optimal_snr_squared)
+            if self.jitter_time:
+                self.parameters['geocent_time'] -= self.parameters['time_jitter']
 
         elif self.distance_marginalization:
             log_l = self.distance_marginalized_likelihood(
@@ -353,41 +351,47 @@ class GravitationalWaveTransient(likelihood.Likelihood):
         if signal_polarizations is None:
             signal_polarizations = \
                 self.waveform_generator.frequency_domain_strain(self.parameters)
-        d_inner_h = 0.
-        h_inner_h = 0.
-        complex_matched_filter_snr = 0.
-        d_inner_h_tc_array = np.zeros(
-            self.interferometers.frequency_array[0:-1].shape,
-            dtype=np.complex128)
-
-        for interferometer in self.interferometers:
-            per_detector_snr = self.calculate_snrs(
-                signal_polarizations, interferometer)
 
-            d_inner_h += per_detector_snr.d_inner_h
-            h_inner_h += per_detector_snr.optimal_snr_squared
-            complex_matched_filter_snr += per_detector_snr.complex_matched_filter_snr
-
-            if self.time_marginalization:
-                d_inner_h_tc_array += per_detector_snr.d_inner_h_squared_tc_array
+        n_time_steps = int(self.waveform_generator.duration * 16384)
+        d_inner_h = np.zeros(n_time_steps, dtype=np.complex)
+        psd = np.ones(n_time_steps)
+        signal_long = np.zeros(n_time_steps, dtype=np.complex)
+        data = np.zeros(n_time_steps, dtype=np.complex)
+        h_inner_h = np.zeros(1)
+        for ifo in self.interferometers:
+            ifo_length = len(ifo.frequency_domain_strain)
+            signal = ifo.get_detector_response(
+                signal_polarizations, self.parameters)
+            signal_long[:ifo_length] = signal
+            data[:ifo_length] = np.conj(ifo.frequency_domain_strain)
+            psd[:ifo_length] = ifo.power_spectral_density_array
+            d_inner_h += np.fft.fft(signal_long * data / psd)
+            h_inner_h += ifo.optimal_snr_squared(signal=signal).real
 
         if self.distance_marginalization:
             time_log_like = self.distance_marginalized_likelihood(
                 d_inner_h, h_inner_h)
         elif self.phase_marginalization:
-            time_log_like = (
-                self._bessel_function_interped(abs(d_inner_h_tc_array)) -
-                h_inner_h.real / 2)
+            time_log_like = (self._bessel_function_interped(abs(d_inner_h)) -
+                             h_inner_h.real / 2)
         else:
-            time_log_like = (d_inner_h_tc_array.real - h_inner_h.real / 2)
+            time_log_like = (d_inner_h.real - h_inner_h.real / 2)
 
-        if self.jitter_time:
-            times = self._times + self.parameters['time_jitter']
+        times = create_time_series(
+            sampling_frequency=16384,
+            starting_time=self.parameters['geocent_time'] - self.waveform_generator.start_time,
+            duration=self.waveform_generator.duration)
+        times = times % self.waveform_generator.duration
+        times += self.waveform_generator.start_time
 
         time_prior_array = self.priors['geocent_time'].prob(times)
         time_post = (
             np.exp(time_log_like - max(time_log_like)) * time_prior_array)
 
+        keep = (time_post > max(time_post) / 1000)
+        time_post = time_post[keep]
+        times = times[keep]
+
         new_time = Interped(times, time_post).sample()
         return new_time
 
@@ -513,7 +517,11 @@ class GravitationalWaveTransient(likelihood.Likelihood):
                 h_inner_h=h_inner_h)
         else:
             log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h / 2
-        return logsumexp(log_l_tc_array, b=self.time_prior_array)
+        times = self._times
+        if self.jitter_time:
+            times = self._times + self.parameters['time_jitter']
+        time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc
+        return logsumexp(log_l_tc_array, b=time_prior_array)
 
     def _setup_rho(self, d_inner_h, optimal_snr_squared):
         optimal_snr_squared_ref = (optimal_snr_squared.real *
-- 
GitLab