From 6071737e307ed7d16d216d68c61cd59da2fab416 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Tue, 8 Mar 2022 02:29:41 +0000
Subject: [PATCH] Refactor marginalized likelihood tests

---
 test/gw/likelihood_test.py | 384 ++++++++++---------------------------
 1 file changed, 105 insertions(+), 279 deletions(-)

diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py
index 6e34f5872..8ad6cac41 100644
--- a/test/gw/likelihood_test.py
+++ b/test/gw/likelihood_test.py
@@ -1,6 +1,9 @@
+import os
+import pytest
 import unittest
 from copy import deepcopy
-import os
+from itertools import product
+from parameterized import parameterized
 
 import numpy as np
 import bilby
@@ -427,9 +430,17 @@ class TestMarginalizations(unittest.TestCase):
     For time, this is strongly dependent on the specific time grid used.
     The `time_jitter` parameter makes this a weaker dependence during sampling.
     """
+    _parameters = product(
+        ["regular", "roq"],
+        ["luminosity_distance", "geocent_time", "phase"],
+        [True, False],
+        [True, False],
+        [True, False],
+    )
 
     lookup_phase = "distance_lookup_phase.npz"
     lookup_no_phase = "distance_lookup_no_phase.npz"
+    path_to_roq_weights = "weights.npz"
 
     def setUp(self):
         np.random.seed(500)
@@ -466,6 +477,11 @@ class TestMarginalizations(unittest.TestCase):
             sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
             start_time=1126259640,
+            waveform_arguments=dict(
+                reference_frequency=20.0,
+                minimum_frequency=20.0,
+                approximant="IMRPhenomPv2",
+            )
         )
         self.interferometers.inject_signal(
             parameters=self.parameters, waveform_generator=self.waveform_generator
@@ -473,9 +489,38 @@ class TestMarginalizations(unittest.TestCase):
 
         self.priors = bilby.gw.prior.BBHPriorDict()
         self.priors["geocent_time"] = bilby.prior.Uniform(
-            minimum=self.interferometers.start_time,
-            maximum=self.interferometers.start_time + self.interferometers.duration,
+            minimum=self.parameters["geocent_time"] - 0.1,
+            maximum=self.parameters["geocent_time"] + 0.1
+        )
+
+        trial_roq_paths = [
+            "/roq_basis",
+            os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
+            "/home/cbc/ROQ_data/IMRPhenomPv2/4s",
+        ]
+        roq_dir = None
+        for path in trial_roq_paths:
+            if os.path.isdir(path):
+                roq_dir = path
+                break
+        if roq_dir is None:
+            raise Exception("Unable to load ROQ basis: cannot proceed with tests")
+
+        self.roq_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
+            duration=self.duration,
+            sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
+            start_time=1126259640,
+            waveform_arguments=dict(
+                reference_frequency=20.0,
+                minimum_frequency=20.0,
+                approximant="IMRPhenomPv2",
+                frequency_nodes_linear=np.load(f"{roq_dir}/fnodes_linear.npy"),
+                frequency_nodes_quadratic=np.load(f"{roq_dir}/fnodes_quadratic.npy"),
+            )
         )
+        self.roq_linear_matrix_file = f"{roq_dir}/B_linear.npy"
+        self.roq_quadratic_matrix_file = f"{roq_dir}/B_quadratic.npy"
 
     def tearDown(self):
         del self.duration
@@ -483,22 +528,17 @@ class TestMarginalizations(unittest.TestCase):
         del self.parameters
         del self.interferometers
         del self.waveform_generator
+        del self.roq_waveform_generator
         del self.priors
 
     @classmethod
     def tearDownClass(cls):
-        # remove lookup tables so that they are not used accidentally in subsequent tests
-        for filename in [cls.lookup_phase, cls.lookup_no_phase]:
+        """remove lookup tables so that they are not used accidentally in subsequent tests"""
+        for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]:
             if os.path.exists(filename):
                 os.remove(filename)
 
-    def get_likelihood(
-        self,
-        time_marginalization=False,
-        phase_marginalization=False,
-        distance_marginalization=False,
-        priors=None
-    ):
+    def likelihood_kwargs(self, kind, time_marginalization, phase_marginalization, distance_marginalization, priors):
         if priors is None:
             priors = self.priors.copy()
         if distance_marginalization and phase_marginalization:
@@ -507,7 +547,7 @@ class TestMarginalizations(unittest.TestCase):
             lookup = TestMarginalizations.lookup_no_phase
         else:
             lookup = None
-        like = bilby.gw.likelihood.GravitationalWaveTransient(
+        kwargs = dict(
             interferometers=self.interferometers,
             waveform_generator=self.waveform_generator,
             distance_marginalization=distance_marginalization,
@@ -516,6 +556,36 @@ class TestMarginalizations(unittest.TestCase):
             distance_marginalization_lookup_table=lookup,
             priors=priors,
         )
+        if kind == "roq":
+            kwargs.update(dict(
+                linear_matrix=self.roq_linear_matrix_file,
+                quadratic_matrix=self.roq_quadratic_matrix_file,
+                waveform_generator=self.roq_waveform_generator,
+            ))
+            if os.path.exists(self.__class__.path_to_roq_weights):
+                kwargs["weights"] = self.__class__.path_to_roq_weights
+        return kwargs
+
+    def get_likelihood(
+        self,
+        kind,
+        time_marginalization=False,
+        phase_marginalization=False,
+        distance_marginalization=False,
+        priors=None
+    ):
+        kwargs = self.likelihood_kwargs(
+            kind, time_marginalization, phase_marginalization, distance_marginalization, priors
+        )
+        if kind == "regular":
+            cls_ = bilby.gw.likelihood.GravitationalWaveTransient
+        elif kind == "roq":
+            cls_ = bilby.gw.likelihood.ROQGravitationalWaveTransient
+        else:
+            raise ValueError(f"kind {kind} not understood")
+        like = cls_(**kwargs)
+        if kind == "roq" and not os.path.exists(self.__class__.path_to_roq_weights):
+            like.save_weights(self.__class__.path_to_roq_weights)
         like.parameters = self.parameters.copy()
         if time_marginalization:
             like.parameters["geocent_time"] = self.interferometers.start_time
@@ -538,290 +608,46 @@ class TestMarginalizations(unittest.TestCase):
             marg_like, marginalized.log_likelihood_ratio(), delta=0.5
         )
 
-    def test_distance_marginalisation(self):
+    @parameterized.expand(_parameters)
+    def test_marginalisation(self, kind, key, distance, time, phase):
+        if all([distance, time, phase]):
+            pytest.skip()
+        tested_args = dict(
+            distance_marginalization=distance,
+            time_marginalization=time,
+            phase_marginalization=phase,
+        )
+        marg_key = f"{key.split('_')[-1]}_marginalization"
+        if tested_args[marg_key]:
+            pytest.skip()
+        reference_args = tested_args.copy()
+        reference_args[marg_key] = True
         self._template(
-            self.get_likelihood(distance_marginalization=True),
-            self.get_likelihood(),
-            key="luminosity_distance",
+            self.get_likelihood(kind, **reference_args),
+            self.get_likelihood(kind, **tested_args),
+            key=key,
         )
 
-    def test_distance_phase_marginalisation(self):
-        self._template(
-            self.get_likelihood(distance_marginalization=True, phase_marginalization=True),
-            self.get_likelihood(phase_marginalization=True),
-            key="luminosity_distance",
-        )
-
-    def test_distance_time_marginalisation(self):
-        self._template(
-            self.get_likelihood(distance_marginalization=True, time_marginalization=True),
-            self.get_likelihood(time_marginalization=True),
-            key="luminosity_distance",
-        )
-
-    def test_distance_phase_time_marginalisation(self):
-        """
-        Test phase marginalised likelihood matches brute force version when
-        also marginalising over time.
-        """
-        self._template(
-            self.get_likelihood(distance_marginalization=True, phase_marginalization=True, time_marginalization=True),
-            self.get_likelihood(phase_marginalization=True, time_marginalization=True),
-            key="luminosity_distance",
-        )
-
-    def test_phase_marginalisation(self):
-        self._template(
-            self.get_likelihood(phase_marginalization=True),
-            self.get_likelihood(),
-            key="phase",
-        )
-
-    def test_phase_distance_marginalisation(self):
-        self._template(
-            self.get_likelihood(distance_marginalization=True, phase_marginalization=True),
-            self.get_likelihood(distance_marginalization=True),
-            key="phase",
-        )
-
-    def test_phase_time_marginalisation(self):
-        self._template(
-            self.get_likelihood(time_marginalization=True, phase_marginalization=True),
-            self.get_likelihood(time_marginalization=True),
-            key="phase",
-        )
-
-    def test_phase_distance_time_marginalisation(self):
-        self._template(
-            self.get_likelihood(time_marginalization=True, distance_marginalization=True, phase_marginalization=True),
-            self.get_likelihood(time_marginalization=True, distance_marginalization=True),
-            key="phase",
-        )
-
-    def test_time_marginalisation(self):
-        times = self.waveform_generator.time_array
-        self._template(
-            self.get_likelihood(time_marginalization=True),
-            self.get_likelihood(),
-            key="geocent_time",
-            values=times,
-        )
-
-    def test_time_distance_marginalisation(self):
-        times = self.waveform_generator.time_array
-        self._template(
-            self.get_likelihood(time_marginalization=True, distance_marginalization=True),
-            self.get_likelihood(distance_marginalization=True),
-            key="geocent_time",
-            values=times
-        )
-
-    def test_time_phase_marginalisation(self):
-        times = self.waveform_generator.time_array
-        self._template(
-            self.get_likelihood(time_marginalization=True, phase_marginalization=True),
-            self.get_likelihood(phase_marginalization=True),
-            key="geocent_time",
-            values=times
-        )
-
-    def test_time_distance_phase_marginalisation(self):
-        times = self.waveform_generator.time_array
-        self._template(
-            self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True),
-            self.get_likelihood(phase_marginalization=True, distance_marginalization=True),
-            key="geocent_time",
-            values=times
-        )
-
-    def test_time_marginalisation_partial_segment(self):
+    def test_time_marginalisation_full_segment(self):
         """
         Test time marginalised likelihood matches brute force version over
         just part of a segment.
         """
         priors = self.priors.copy()
         prior = bilby.prior.Uniform(
-            minimum=self.parameters["geocent_time"] - 0.1,
-            maximum=self.parameters["geocent_time"] + 0.1,
+            minimum=self.interferometers.start_time,
+            maximum=self.interferometers.start_time + self.interferometers.duration,
         )
         priors["geocent_time"] = prior
         self._template(
-            self.get_likelihood(time_marginalization=True, priors=priors.copy()),
-            self.get_likelihood(priors=priors.copy()),
+            self.get_likelihood("regular", time_marginalization=True, priors=priors.copy()),
+            self.get_likelihood("regular", priors=priors.copy()),
             key="geocent_time",
             values=self.waveform_generator.time_array,
             prior=prior,
         )
 
 
-class TestMarginalizationsROQ(TestMarginalizations):
-
-    lookup_phase = "distance_lookup_phase.npz"
-    lookup_no_phase = "distance_lookup_no_phase.npz"
-    path_to_roq_weights = "weights.npz"
-
-    def setUp(self):
-        np.random.seed(500)
-        self.duration = 4
-        self.sampling_frequency = 2048
-        self.parameters = dict(
-            mass_1=31.0,
-            mass_2=29.0,
-            a_1=0.4,
-            a_2=0.3,
-            tilt_1=0.0,
-            tilt_2=0.0,
-            phi_12=1.7,
-            phi_jl=0.3,
-            luminosity_distance=4000.0,
-            theta_jn=0.4,
-            psi=2.659,
-            phase=1.3,
-            geocent_time=1126259642.413,
-            ra=1.375,
-            dec=-1.2108,
-            time_jitter=0,
-        )
-
-        self.interferometers = bilby.gw.detector.InterferometerList(["H1"])
-        self.interferometers.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=self.sampling_frequency,
-            duration=self.duration,
-            start_time=1126259640,
-        )
-
-        waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-            duration=self.duration,
-            sampling_frequency=self.sampling_frequency,
-            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-            start_time=1126259640,
-            waveform_arguments=dict(
-                reference_frequency=20.0,
-                minimum_frequency=20.0,
-                approximant="IMRPhenomPv2"
-            )
-        )
-        self.interferometers.inject_signal(
-            parameters=self.parameters, waveform_generator=waveform_generator
-        )
-
-        self.priors = bilby.gw.prior.BBHPriorDict()
-        # prior range should be a part of segment since ROQ likelihood can not
-        # calculate values at samples close to edges
-        self.priors["geocent_time"] = bilby.prior.Uniform(
-            minimum=self.parameters["geocent_time"] - 0.1,
-            maximum=self.parameters["geocent_time"] + 0.1
-        )
-
-        # Possible locations for the ROQ: in the docker image, local, or on CIT
-        trial_roq_paths = [
-            "/roq_basis",
-            os.path.join(os.path.expanduser("~"), "ROQ_data/IMRPhenomPv2/4s"),
-            "/home/cbc/ROQ_data/IMRPhenomPv2/4s",
-        ]
-        roq_dir = None
-        for path in trial_roq_paths:
-            if os.path.isdir(path):
-                roq_dir = path
-                break
-        if roq_dir is None:
-            raise Exception("Unable to load ROQ basis: cannot proceed with tests")
-
-        self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-            duration=self.duration,
-            sampling_frequency=self.sampling_frequency,
-            frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
-            start_time=1126259640,
-            waveform_arguments=dict(
-                reference_frequency=20.0,
-                minimum_frequency=20.0,
-                approximant="IMRPhenomPv2",
-                frequency_nodes_linear=np.load("{}/fnodes_linear.npy".format(roq_dir)),
-                frequency_nodes_quadratic=np.load("{}/fnodes_quadratic.npy".format(roq_dir)),
-            )
-        )
-        self.roq_linear_matrix_file = "{}/B_linear.npy".format(roq_dir)
-        self.roq_quadratic_matrix_file = "{}/B_quadratic.npy".format(roq_dir)
-
-    @classmethod
-    def tearDownClass(cls):
-        for filename in [cls.lookup_phase, cls.lookup_no_phase, cls.path_to_roq_weights]:
-            if os.path.exists(filename):
-                os.remove(filename)
-
-    def get_likelihood(
-        self,
-        time_marginalization=False,
-        phase_marginalization=False,
-        distance_marginalization=False,
-        priors=None
-    ):
-        if priors is None:
-            priors = self.priors.copy()
-        if distance_marginalization and phase_marginalization:
-            lookup = TestMarginalizationsROQ.lookup_phase
-        elif distance_marginalization:
-            lookup = TestMarginalizationsROQ.lookup_no_phase
-        else:
-            lookup = None
-        kwargs = dict(
-            interferometers=self.interferometers,
-            waveform_generator=self.waveform_generator,
-            distance_marginalization=distance_marginalization,
-            phase_marginalization=phase_marginalization,
-            time_marginalization=time_marginalization,
-            distance_marginalization_lookup_table=lookup,
-            priors=priors
-        )
-        if os.path.exists(TestMarginalizationsROQ.path_to_roq_weights):
-            kwargs.update(dict(weights=TestMarginalizationsROQ.path_to_roq_weights))
-            like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs)
-        else:
-            kwargs.update(
-                dict(
-                    linear_matrix=self.roq_linear_matrix_file,
-                    quadratic_matrix=self.roq_quadratic_matrix_file
-                )
-            )
-            like = bilby.gw.likelihood.ROQGravitationalWaveTransient(**kwargs)
-            like.save_weights(TestMarginalizationsROQ.path_to_roq_weights)
-        like.parameters = self.parameters.copy()
-        if time_marginalization:
-            like.parameters["geocent_time"] = self.interferometers.start_time
-        return like
-
-    def test_time_marginalisation(self):
-        self._template(
-            self.get_likelihood(time_marginalization=True),
-            self.get_likelihood(),
-            key="geocent_time",
-        )
-
-    def test_time_distance_marginalisation(self):
-        self._template(
-            self.get_likelihood(time_marginalization=True, distance_marginalization=True),
-            self.get_likelihood(distance_marginalization=True),
-            key="geocent_time",
-        )
-
-    def test_time_phase_marginalisation(self):
-        self._template(
-            self.get_likelihood(time_marginalization=True, phase_marginalization=True),
-            self.get_likelihood(phase_marginalization=True),
-            key="geocent_time",
-        )
-
-    def test_time_distance_phase_marginalisation(self):
-        self._template(
-            self.get_likelihood(time_marginalization=True, phase_marginalization=True, distance_marginalization=True),
-            self.get_likelihood(phase_marginalization=True, distance_marginalization=True),
-            key="geocent_time",
-        )
-
-    def test_time_marginalisation_partial_segment(self):
-        pass
-
-
 class TestROQLikelihood(unittest.TestCase):
     def setUp(self):
         self.duration = 4
-- 
GitLab