From 573db214a5527185d0f64d1a7ec07bfa6c48369e Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 8 Nov 2018 17:21:19 -0600
Subject: [PATCH] make time marginalised likelihood use prior

---
 CHANGELOG.md               |  1 +
 bilby/core/utils.py        |  2 +-
 bilby/gw/likelihood.py     | 34 ++++++++++------
 test/gw_likelihood_test.py | 83 +++++++++++++++++++++++++++-----------
 4 files changed, 83 insertions(+), 37 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 896d484c..fdd2375d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@ Changes currently on master, but not under a tag.
   compatibility were too much. Note, working in only python 2 or 3, we do not
   expect users to encounter issues.
 - Intermediate data products of samples, nested_samples are stored in the h5
+- Time marginalised GravitationalWaveTransient works with arbitrary time priors.
 
 ## [0.3.1] 2018-11-06
 
diff --git a/bilby/core/utils.py b/bilby/core/utils.py
index 89d3a53e..6d68f352 100644
--- a/bilby/core/utils.py
+++ b/bilby/core/utils.py
@@ -711,5 +711,5 @@ else:
             matplotlib.use(backend, warn=False)
             plt.switch_backend(backend)
             break
-        except Exception as e:
+        except Exception:
             print(traceback.format_exc())
diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index e775cd63..eaa6d9da 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -169,8 +169,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
             if self.time_marginalization:
                 matched_filter_snr_squared_tc_array +=\
                     4 / self.waveform_generator.duration * np.fft.fft(
-                        signal_ifo.conjugate()[0:-1] *
-                        interferometer.frequency_domain_strain[0:-1] /
+                        signal_ifo[0:-1] *
+                        interferometer.frequency_domain_strain.conjugate()[0:-1] /
                         interferometer.power_spectral_density_array[0:-1])
 
         if self.time_marginalization:
@@ -181,18 +181,21 @@ class GravitationalWaveTransient(likelihood.Likelihood):
                 if self.phase_marginalization:
                     dist_marged_log_l_tc_array = self._interp_dist_margd_loglikelihood(
                         abs(rho_mf_ref_tc_array), rho_opt_ref)
-                    log_l = logsumexp(dist_marged_log_l_tc_array) + self.tc_log_norm
+                    log_l = logsumexp(dist_marged_log_l_tc_array,
+                                      b=self.time_prior_array)
                 else:
                     dist_marged_log_l_tc_array = self._interp_dist_margd_loglikelihood(
                         rho_mf_ref_tc_array.real, rho_opt_ref)
-                    log_l = logsumexp(dist_marged_log_l_tc_array) + self.tc_log_norm
+                    log_l = logsumexp(dist_marged_log_l_tc_array,
+                                      b=self.time_prior_array)
             elif self.phase_marginalization:
-                log_l = (
-                    logsumexp(self._bessel_function_interped(abs(matched_filter_snr_squared_tc_array))) -
-                    optimal_snr_squared / 2 + self.tc_log_norm)
+                log_l = logsumexp(self._bessel_function_interped(abs(
+                    matched_filter_snr_squared_tc_array)),
+                    b=self.time_prior_array) - optimal_snr_squared / 2
             else:
-                log_l = (logsumexp(matched_filter_snr_squared_tc_array.real) +
-                         self.tc_log_norm - optimal_snr_squared / 2)
+                log_l = logsumexp(
+                    matched_filter_snr_squared_tc_array.real,
+                    b=self.time_prior_array) - optimal_snr_squared / 2
 
         elif self.distance_marginalization:
             rho_mf_ref, rho_opt_ref = self._setup_rho(matched_filter_snr_squared, optimal_snr_squared)
@@ -271,12 +274,19 @@ class GravitationalWaveTransient(likelihood.Likelihood):
 
     def _setup_phase_marginalization(self):
         self._bessel_function_interped = interp1d(
-            np.logspace(-5, 10, int(1e6)), np.log([i0e(snr) for snr in np.logspace(-5, 10, int(1e6))]) +
-            np.logspace(-5, 10, int(1e6)), bounds_error=False, fill_value=(0, np.nan))
+            np.logspace(-5, 10, int(1e6)), np.logspace(-5, 10, int(1e6)) +
+            np.log([i0e(snr) for snr in np.logspace(-5, 10, int(1e6))]),
+            bounds_error=False, fill_value=(0, np.nan))
 
     def _setup_time_marginalization(self):
         delta_tc = 2 / self.waveform_generator.sampling_frequency
-        self.tc_log_norm = np.log(delta_tc / self.waveform_generator.duration)
+        times =\
+            self.interferometers.start_time + np.linspace(
+                0, self.interferometers.duration,
+                int(self.interferometers.duration / 2 *
+                    self.waveform_generator.sampling_frequency) + 1)[1:]
+        self.time_prior_array =\
+            self.prior['geocent_time'].prob(times) * delta_tc
 
 
 class BasicGravitationalWaveTransient(likelihood.Likelihood):
diff --git a/test/gw_likelihood_test.py b/test/gw_likelihood_test.py
index 133fadde..989aaa65 100644
--- a/test/gw_likelihood_test.py
+++ b/test/gw_likelihood_test.py
@@ -146,35 +146,27 @@ class TestTimeMarginalization(unittest.TestCase):
         self.parameters = dict(
             mass_1=31., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.0, tilt_2=0.0,
             phi_12=1.7, phi_jl=0.3, luminosity_distance=4000., iota=0.4,
-            psi=2.659, phase=1.3, geocent_time=1126259642.413, ra=1.375,
+            psi=2.659, phase=1.3, geocent_time=1126259640, ra=1.375,
             dec=-1.2108)
 
         self.interferometers = bilby.gw.detector.InterferometerList(['H1'])
         self.interferometers.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=self.sampling_frequency, duration=self.duration)
+            sampling_frequency=self.sampling_frequency, duration=self.duration,
+            start_time=1126259640)
 
         self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
             duration=self.duration, sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-        )
+            start_time=1126259640)
 
         self.prior = bilby.gw.prior.BBHPriorDict()
-        self.prior['geocent_time'] = bilby.prior.Uniform(
-            minimum=self.parameters['geocent_time'] - self.duration / 2,
-            maximum=self.parameters['geocent_time'] + self.duration / 2)
 
         self.likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
             interferometers=self.interferometers,
             waveform_generator=self.waveform_generator, prior=self.prior.copy()
         )
 
-        self.time = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            time_marginalization=True, prior=self.prior.copy()
-        )
-        for like in [self.likelihood, self.time]:
-            like.parameters = self.parameters.copy()
+        self.likelihood.parameters = self.parameters.copy()
 
     def tearDown(self):
         del self.duration
@@ -184,20 +176,62 @@ class TestTimeMarginalization(unittest.TestCase):
         del self.waveform_generator
         del self.prior
         del self.likelihood
-        del self.time
 
-    def test_time_marginalisation(self):
-        """Test time marginalised likelihood matches brute force version"""
-        like = []
-        times = np.linspace(self.prior['geocent_time'].minimum,
-                            self.prior['geocent_time'].maximum, 4097)[:-1]
+    def test_time_marginalisation_full_segment(self):
+        """
+        Test time marginalised likelihood matches brute force version over the
+        whole segment.
+        """
+        likes = []
+        lls = []
+        self.prior['geocent_time'] = bilby.prior.Uniform(
+            minimum=self.waveform_generator.start_time,
+            maximum=self.waveform_generator.start_time + self.duration)
+        self.time = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=self.interferometers,
+            waveform_generator=self.waveform_generator,
+            time_marginalization=True, prior=self.prior.copy()
+        )
+        times = self.waveform_generator.start_time + np.linspace(
+            0, self.duration, 4097)[:-1]
         for time in times:
             self.likelihood.parameters['geocent_time'] = time
-            like.append(np.exp(self.likelihood.log_likelihood_ratio()))
+            lls.append(self.likelihood.log_likelihood_ratio())
+            likes.append(np.exp(lls[-1]))
 
-        marg_like = np.log(np.trapz(like, times)
-                           / self.waveform_generator.duration)
+        marg_like = np.log(np.trapz(
+            likes * self.prior['geocent_time'].prob(times), times))
         self.time.parameters = self.parameters.copy()
+        self.time.parameters['geocent_time'] = self.waveform_generator.start_time
+        self.assertAlmostEqual(marg_like, self.time.log_likelihood_ratio(),
+                               delta=0.5)
+
+    def test_time_marginalisation_partial_segment(self):
+        """
+        Test time marginalised likelihood matches brute force version over the
+        whole segment.
+        """
+        likes = []
+        lls = []
+        self.prior['geocent_time'] = bilby.prior.Uniform(
+            minimum=self.parameters['geocent_time'] + 1 - 0.1,
+            maximum=self.parameters['geocent_time'] + 1 + 0.1)
+        self.time = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=self.interferometers,
+            waveform_generator=self.waveform_generator,
+            time_marginalization=True, prior=self.prior.copy()
+        )
+        times = self.waveform_generator.start_time + np.linspace(
+            0, self.duration, 4097)[:-1]
+        for time in times:
+            self.likelihood.parameters['geocent_time'] = time
+            lls.append(self.likelihood.log_likelihood_ratio())
+            likes.append(np.exp(lls[-1]))
+
+        marg_like = np.log(np.trapz(
+            likes * self.prior['geocent_time'].prob(times), times))
+        self.time.parameters = self.parameters.copy()
+        self.time.parameters['geocent_time'] = self.waveform_generator.start_time
         self.assertAlmostEqual(marg_like, self.time.log_likelihood_ratio(),
                                delta=0.5)
 
@@ -343,12 +377,13 @@ class TestTimePhaseMarginalization(unittest.TestCase):
 
         self.interferometers = bilby.gw.detector.InterferometerList(['H1'])
         self.interferometers.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=self.sampling_frequency, duration=self.duration)
+            sampling_frequency=self.sampling_frequency, duration=self.duration,
+            start_time=1126259640)
 
         self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
             duration=self.duration, sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-        )
+            start_time=1126259640)
 
         self.prior = bilby.gw.prior.BBHPriorDict()
         self.prior['geocent_time'] = bilby.prior.Uniform(
-- 
GitLab