From 13df713f07a8f39b81542a7dcda5b79be0b2e8c8 Mon Sep 17 00:00:00 2001
From: Soichiro Morisaki <soichiro.morisaki@ligo.org>
Date: Fri, 4 Feb 2022 15:38:51 +0000
Subject: [PATCH] Implement time marginalization into ROQ likelihood

---
 bilby/gw/likelihood/base.py |   4 +-
 bilby/gw/likelihood/roq.py  | 179 ++++++++++++++++++++++++++---------
 test/gw/likelihood_test.py  | 182 +++++++++++++++++++++++++++++++++++-
 3 files changed, 315 insertions(+), 50 deletions(-)

diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py
index d6c9f6f14..a75dded56 100644
--- a/bilby/gw/likelihood/base.py
+++ b/bilby/gw/likelihood/base.py
@@ -378,9 +378,7 @@ class GravitationalWaveTransient(Likelihood):
         elif self.time_marginalization:
             if self.jitter_time:
                 self.parameters['geocent_time'] += self.parameters['time_jitter']
-            d_inner_h_array = np.zeros(
-                len(self.interferometers.frequency_array[0:-1]),
-                dtype=np.complex128)
+            d_inner_h_array = np.zeros(len(self._times), dtype=np.complex128)
 
         elif self.calibration_marginalization:
             d_inner_h_array = np.zeros(self.number_of_response_curves, dtype=np.complex128)
diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py
index 5d459082a..ce3e679f5 100644
--- a/bilby/gw/likelihood/roq.py
+++ b/bilby/gw/likelihood/roq.py
@@ -44,6 +44,20 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
         A dictionary of priors containing at least the geocent_time prior
         Warning: when using marginalisation the dict is overwritten which will change the
         the dict you are passing in. If this behaviour is undesired, pass `priors.copy()`.
+    time_marginalization: bool, optional
+        If true, marginalize over time in the likelihood.
+        The spacing of time samples can be specified through delta_tc.
+        If using time marginalisation and jitter_time is True a "jitter"
+        parameter is added to the prior which modifies the position of the
+        grid of times.
+    jitter_time: bool, optional
+        Whether to introduce a `time_jitter` parameter. This avoids either
+        missing the likelihood peak, or introducing biases in the
+        reconstructed time posterior due to an insufficient sampling frequency.
+        Default is False, however using this parameter is strongly encouraged.
+    delta_tc: float, optional
+        The spacing of time samples for time marginalization. If not specified,
+        it is determined based on the signal-to-noise ratio of signal.
     distance_marginalization_lookup_table: (dict, str), optional
         If a dict, dictionary containing the lookup_table, distance_array,
         (distance) prior_array, and reference_distance used to construct
@@ -71,18 +85,20 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
             weights=None, linear_matrix=None, quadratic_matrix=None,
             roq_params=None, roq_params_check=True, roq_scale_factor=1,
             distance_marginalization=False, phase_marginalization=False,
+            time_marginalization=False, jitter_time=True, delta_tc=None,
             distance_marginalization_lookup_table=None,
             reference_frame="sky", time_reference="geocenter"
 
     ):
+        self._delta_tc = delta_tc
         super(ROQGravitationalWaveTransient, self).__init__(
             interferometers=interferometers,
             waveform_generator=waveform_generator, priors=priors,
             distance_marginalization=distance_marginalization,
             phase_marginalization=phase_marginalization,
-            time_marginalization=False,
+            time_marginalization=time_marginalization,
             distance_marginalization_lookup_table=distance_marginalization_lookup_table,
-            jitter_time=False,
+            jitter_time=jitter_time,
             reference_frame=reference_frame,
             time_reference=time_reference
         )
@@ -117,6 +133,19 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
         self.frequency_nodes_quadratic = \
             waveform_generator.waveform_arguments['frequency_nodes_quadratic']
 
+    def _setup_time_marginalization(self):
+        if self._delta_tc is None:
+            self._delta_tc = self._get_time_resolution()
+        tcmin = self.priors['geocent_time'].minimum
+        tcmax = self.priors['geocent_time'].maximum
+        number_of_time_samples = int(np.ceil((tcmax - tcmin) / self._delta_tc))
+        # adjust delta tc so that the last time sample has an equal weight
+        self._delta_tc = (tcmax - tcmin) / number_of_time_samples
+        logger.info(
+            "delta tc for time marginalization = {} seconds.".format(self._delta_tc))
+        self._times = tcmin + self._delta_tc / 2. + np.arange(number_of_time_samples) * self._delta_tc
+        self._beam_pattern_reference_time = (tcmin + tcmax) / 2.
+
     def calculate_snrs(self, waveform_polarizations, interferometer):
         """
         Compute the snrs for ROQ
@@ -128,18 +157,21 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
 
         """
 
-        f_plus = interferometer.antenna_response(
-            self.parameters['ra'], self.parameters['dec'],
-            self.parameters['geocent_time'], self.parameters['psi'], 'plus')
-        f_cross = interferometer.antenna_response(
-            self.parameters['ra'], self.parameters['dec'],
-            self.parameters['geocent_time'], self.parameters['psi'], 'cross')
-
-        dt = interferometer.time_delay_from_geocenter(
-            self.parameters['ra'], self.parameters['dec'],
-            self.parameters['geocent_time'])
-        dt_geocent = self.parameters['geocent_time'] - interferometer.strain_data.start_time
-        ifo_time = dt_geocent + dt
+        if self.time_marginalization:
+            time_ref = self._beam_pattern_reference_time
+        else:
+            time_ref = self.parameters['geocent_time']
+
+        h_linear = np.zeros(len(self.frequency_nodes_linear), dtype=complex)
+        h_quadratic = np.zeros(len(self.frequency_nodes_quadratic), dtype=complex)
+        for mode in waveform_polarizations['linear']:
+            response = interferometer.antenna_response(
+                self.parameters['ra'], self.parameters['dec'],
+                self.parameters['geocent_time'], self.parameters['psi'],
+                mode
+            )
+            h_linear += waveform_polarizations['linear'][mode] * response
+            h_quadratic += waveform_polarizations['quadratic'][mode] * response
 
         calib_linear = interferometer.calibration_model.get_calibration_factor(
             self.frequency_nodes_linear,
@@ -148,46 +180,56 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
             self.frequency_nodes_quadratic,
             prefix='recalib_{}_'.format(interferometer.name), **self.parameters)
 
-        h_plus_linear = f_plus * waveform_polarizations['linear']['plus'] * calib_linear
-        h_cross_linear = f_cross * waveform_polarizations['linear']['cross'] * calib_linear
-        h_plus_quadratic = (
-            f_plus * waveform_polarizations['quadratic']['plus'] * calib_quadratic
-        )
-        h_cross_quadratic = (
-            f_cross * waveform_polarizations['quadratic']['cross'] * calib_quadratic
-        )
+        h_linear *= calib_linear
+        h_quadratic *= calib_quadratic
 
-        indices, in_bounds = self._closest_time_indices(
-            ifo_time, self.weights['time_samples'])
-        if not in_bounds:
-            logger.debug("SNR calculation error: requested time at edge of ROQ time samples")
-            return self._CalculatedSNRs(
-                d_inner_h=np.nan_to_num(-np.inf), optimal_snr_squared=0,
-                complex_matched_filter_snr=np.nan_to_num(-np.inf),
-                d_inner_h_squared_tc_array=None,
-                d_inner_h_array=None,
-                optimal_snr_squared_array=None)
+        optimal_snr_squared = \
+            np.vdot(np.abs(h_quadratic)**2, self.weights[interferometer.name + '_quadratic'])
+
+        dt = interferometer.time_delay_from_geocenter(
+            self.parameters['ra'], self.parameters['dec'], time_ref)
 
-        d_inner_h_tc_array = np.einsum(
-            'i,ji->j', np.conjugate(h_plus_linear + h_cross_linear),
-            self.weights[interferometer.name + '_linear'][indices])
+        if not self.time_marginalization:
+            dt_geocent = self.parameters['geocent_time'] - interferometer.strain_data.start_time
+            ifo_time = dt_geocent + dt
 
-        d_inner_h = self._interp_five_samples(
-            self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time)
+            indices, in_bounds = self._closest_time_indices(
+                ifo_time, self.weights['time_samples'])
+            if not in_bounds:
+                logger.debug("SNR calculation error: requested time at edge of ROQ time samples")
+                return self._CalculatedSNRs(
+                    d_inner_h=np.nan_to_num(-np.inf), optimal_snr_squared=0,
+                    complex_matched_filter_snr=np.nan_to_num(-np.inf),
+                    d_inner_h_squared_tc_array=None,
+                    d_inner_h_array=None,
+                    optimal_snr_squared_array=None)
 
-        optimal_snr_squared = \
-            np.vdot(np.abs(h_plus_quadratic + h_cross_quadratic)**2,
-                    self.weights[interferometer.name + '_quadratic'])
+            d_inner_h_tc_array = np.einsum(
+                'i,ji->j', np.conjugate(h_linear),
+                self.weights[interferometer.name + '_linear'][indices])
+
+            d_inner_h = self._interp_five_samples(
+                self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time)
+
+            with np.errstate(invalid="ignore"):
+                complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5)
+
+            d_inner_h_array = None
 
-        with np.errstate(invalid="ignore"):
-            complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5)
-        d_inner_h_squared_tc_array = None
+        else:
+            ifo_times = self._times - interferometer.strain_data.start_time + dt
+            if self.jitter_time:
+                ifo_times += self.parameters['time_jitter']
+            d_inner_h_array = self._calculate_d_inner_h_array(ifo_times, h_linear, interferometer.name)
+
+            d_inner_h = 0.
+            complex_matched_filter_snr = 0.
 
         return self._CalculatedSNRs(
             d_inner_h=d_inner_h, optimal_snr_squared=optimal_snr_squared,
             complex_matched_filter_snr=complex_matched_filter_snr,
-            d_inner_h_squared_tc_array=d_inner_h_squared_tc_array,
-            d_inner_h_array=None,
+            d_inner_h_squared_tc_array=None,
+            d_inner_h_array=d_inner_h_array,
             optimal_snr_squared_array=None)
 
     @staticmethod
@@ -242,6 +284,53 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
         d = (b**3. - b) / 6.
         return a * values[2] + b * values[3] + c * r1 + d * r2
 
+    def _calculate_d_inner_h_array(self, times, h_linear, ifo_name):
+        """
+        Calculate d_inner_h at regularly-spaced time samples. Each value is
+        interpolated from the nearest 5 samples with the algorithm explained in
+        https://dcc.ligo.org/T2100224.
+
+        Parameters
+        ==========
+        times: array-like
+            Regularly-spaced time samples at which d_inner_h are calculated.
+        h_linear: array-like
+            Waveforms at linear frequency nodes
+        ifo_name: str
+
+        Returns
+        =======
+        d_inner_h_array: array-like
+        """
+        roq_time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0]
+        times_per_roq_time_space = (times - self.weights['time_samples'][0]) / roq_time_space
+        closest_idxs = np.floor(times_per_roq_time_space).astype(int)
+        # Get the nearest 5 samples of d_inner_h. Calculate only the required d_inner_h values if the time
+        # spacing is larger than 5 times the ROQ time spacing.
+        weights_linear = self.weights[ifo_name + '_linear']
+        h_linear_conj = np.conjugate(h_linear)
+        if (times[1] - times[0]) / roq_time_space > 5:
+            d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj)
+            d_inner_h_m1 = np.dot(weights_linear[closest_idxs - 1], h_linear_conj)
+            d_inner_h_0 = np.dot(weights_linear[closest_idxs], h_linear_conj)
+            d_inner_h_p1 = np.dot(weights_linear[closest_idxs + 1], h_linear_conj)
+            d_inner_h_p2 = np.dot(weights_linear[closest_idxs + 2], h_linear_conj)
+        else:
+            d_inner_h_at_roq_time_samples = np.dot(weights_linear, h_linear_conj)
+            d_inner_h_m2 = d_inner_h_at_roq_time_samples[closest_idxs - 2]
+            d_inner_h_m1 = d_inner_h_at_roq_time_samples[closest_idxs - 1]
+            d_inner_h_0 = d_inner_h_at_roq_time_samples[closest_idxs]
+            d_inner_h_p1 = d_inner_h_at_roq_time_samples[closest_idxs + 1]
+            d_inner_h_p2 = d_inner_h_at_roq_time_samples[closest_idxs + 2]
+        # quantities required for spline interpolation
+        b = times_per_roq_time_space - closest_idxs
+        a = 1. - b
+        c = (a**3. - a) / 6.
+        d = (b**3. - b) / 6.
+        r1 = (-d_inner_h_m2 + 8. * d_inner_h_m1 - 14. * d_inner_h_0 + 8. * d_inner_h_p1 - d_inner_h_p2) / 4.
+        r2 = d_inner_h_0 - 2. * d_inner_h_p1 + d_inner_h_p2
+        return a * d_inner_h_0 + b * d_inner_h_p1 + c * r1 + d * r2
+
     def perform_roq_params_check(self, ifo=None):
         """ Perform checking that the prior and data are valid for the ROQ
 
diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py
index f2c6e7ec5..6e34f5872 100644
--- a/test/gw/likelihood_test.py
+++ b/test/gw/likelihood_test.py
@@ -428,6 +428,9 @@ class TestMarginalizations(unittest.TestCase):
     The `time_jitter` parameter makes this a weaker dependence during sampling.
     """
 
+    lookup_phase = "distance_lookup_phase.npz"
+    lookup_no_phase = "distance_lookup_no_phase.npz"
+
     def setUp(self):
         np.random.seed(500)
         self.duration = 4
@@ -482,6 +485,13 @@ class TestMarginalizations(unittest.TestCase):
         del self.waveform_generator
         del self.priors
 
+    @classmethod
+    def tearDownClass(cls):
+        # remove lookup tables so that they are not used accidentally in subsequent tests
+        for filename in [cls.lookup_phase, cls.lookup_no_phase]:
+            if os.path.exists(filename):
+                os.remove(filename)
+
     def get_likelihood(
         self,
         time_marginalization=False,
@@ -492,9 +502,9 @@ class TestMarginalizations(unittest.TestCase):
         if priors is None:
             priors = self.priors.copy()
         if distance_marginalization and phase_marginalization:
-            lookup = "distance_lookup_phase.npz"
+            lookup = TestMarginalizations.lookup_phase
         elif distance_marginalization:
-            lookup = "distance_lookup_no_phase.npz"
+            lookup = TestMarginalizations.lookup_no_phase
         else:
             lookup = None
         like = bilby.gw.likelihood.GravitationalWaveTransient(
@@ -644,6 +654,174 @@ class TestMarginalizations(unittest.TestCase):
         )
 
 
+class TestMarginalizationsROQ(TestMarginalizations):
+
+    lookup_phase = "distance_lookup_phase.npz"
+    lookup_no_phase = "distance_lookup_no_phase.npz"
+    path_to_roq_weights = "weights.npz"
+
+    def setUp(self):
+        np.random.seed(500)
+        self.duration = 4
+        self.sampling_frequency = 2048
+        self.parameters = dict(
+            mass_1=31.0,
+            mass_2=29.0,
+            a_1=0.4,
+            a_2=0.3,
+            tilt_1=0.0,
+            tilt_2=0.0,
+            phi_12=1.7,
+            phi_jl=0.3,
+            luminosity_distance=4000.0,
+            theta_jn=0.4,
+            psi=2.659,
+            phase=1.3,
+            geocent_time=1126259642.413,
+            ra=1.375,
+            dec=-1.2108,
+            time_jitter=0,
+        )
+
+        self.interferometers = bilby.gw.detector.InterferometerList(["H1"])
+        self.interferometers.set_strain_data_from_power_spectral_densities(
+            sampling_frequency=self.sampling_frequency,
+            duration=self.duration,
+            start_time=1126259640,
+        )
+
+        waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
+            duration=self.duration,
+            sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
+            start_time=1126259640,
+            waveform_arguments=dict(
+                reference_frequency=20.0,
+                minimum_frequency=20.0,
+                approximant="IMRPhenomPv2"
+            )
+        )
+        self.interferometers.inject_signal(
+            parameters=self.parameters, waveform_generator=waveform_generator
+        )
+
+        self.priors = bilby.gw.prior.BBHPriorDict()
+        # prior range should be a part of segment since ROQ likelihood can not
+        # calculate values at samples close to edges
+        self.priors["geocent_time"] = bilby.prior.Uniform(
+            minimum=self.parameters["geocent_time"] - 0.1,
+            maximum=self.parameters["geocent_time"] + 0.1
+        )
+
+        # Possible locations for the ROQ: in the docker image, local, or on CIT
+        trial_roq_paths = [
+            "/roq_basis",
+            os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
+            "/home/cbc/ROQ_data/IMRPhenomPv2/4s",
+        ]
+        roq_dir = None
+        for path in trial_roq_paths:
+            if os.path.isdir(path):
+                roq_dir = path
+                break
+        if roq_dir is None:
+            raise Exception("Unable to load ROQ basis: cannot proceed with tests")
+
+        self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
+            duration=self.duration,
+            sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
+            start_time=1126259640,
+            waveform_arguments=dict(
+                reference_frequency=20.0,
+                minimum_frequency=20.0,
+                approximant="IMRPhenomPv2",
+                frequency_nodes_linear=np.load("{}/fnodes_linear.npy".format(roq_dir)),
+                frequency_nodes_quadratic=np.load("{}/fnodes_quadratic.npy".format(roq_dir)),
+            )
+        )
+        self.roq_linear_matrix_file = "{}/B_linear.npy".format(roq_dir)
+        self.roq_quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir)
+
+    @classmethod
+    def tearDownClass(cls):
+        for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]:
+            if os.path.exists(filename):
+                os.remove(filename)
+
+    def get_likelihood(
+        self,
+        time_marginalization=False,
+        phase_marginalization=False,
+        distance_marginalization=False,
+        priors=None
+    ):
+        if priors is None:
+            priors = self.priors.copy()
+        if distance_marginalization and phase_marginalization:
+            lookup = TestMarginalizationsROQ.lookup_phase
+        elif distance_marginalization:
+            lookup = TestMarginalizationsROQ.lookup_no_phase
+        else:
+            lookup = None
+        kwargs = dict(
+            interferometers=self.interferometers,
+            waveform_generator=self.waveform_generator,
+            distance_marginalization=distance_marginalization,
+            phase_marginalization=phase_marginalization,
+            time_marginalization=time_marginalization,
+            distance_marginalization_lookup_table=lookup,
+            priors=priors
+        )
+        if os.path.exists(TestMarginalizationsROQ.path_to_roq_weights):
+            kwargs.update(dict(weights=TestMarginalizationsROQ.path_to_roq_weights))
+            like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs)
+        else:
+            kwargs.update(
+                dict(
+                    linear_matrix=self.roq_linear_matrix_file,
+                    quadratic_matrix=self.roq_quadratic_matrix_file
+                )
+            )
+            like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs)
+            like.save_weights(TestMarginalizationsROQ.path_to_roq_weights)
+        like.parameters = self.parameters.copy()
+        if time_marginalization:
+            like.parameters["geocent_time"] = self.interferometers.start_time
+        return like
+
+    def test_time_marginalisation(self):
+        self._template(
+            self.get_likelihood(time_marginalization=True),
+            self.get_likelihood(),
+            key="geocent_time",
+        )
+
+    def test_time_distance_marginalisation(self):
+        self._template(
+            self.get_likelihood(time_marginalization=True, distance_marginalization=True),
+            self.get_likelihood(distance_marginalization=True),
+            key="geocent_time",
+        )
+
+    def test_time_phase_marginalisation(self):
+        self._template(
+            self.get_likelihood(time_marginalization=True, phase_marginalization=True),
+            self.get_likelihood(phase_marginalization=True),
+            key="geocent_time",
+        )
+
+    def test_time_distance_phase_marginalisation(self):
+        self._template(
+            self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True),
+            self.get_likelihood(phase_marginalization=True, distance_marginalization=True),
+            key="geocent_time",
+        )
+
+    def test_time_marginalisation_partial_segment(self):
+        pass
+
+
 class TestROQLikelihood(unittest.TestCase):
     def setUp(self):
         self.duration = 4
-- 
GitLab