From aad4e48b45004e5684d46488f57f9e10bf1060f6 Mon Sep 17 00:00:00 2001
From: Soichiro Morisaki <soichiro.morisaki@ligo.org>
Date: Tue, 8 Nov 2022 19:38:20 +0000
Subject: [PATCH] Determine reference_chirp_mass for
 MBGravitationalWaveTransient from prior when it is not specified

---
 bilby/gw/likelihood/multiband.py |  34 ++++-
 test/gw/likelihood_test.py       | 250 +++++++++++++++++--------------
 2 files changed, 164 insertions(+), 120 deletions(-)

diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py
index 9d0b7e1d1..88e7234a2 100644
--- a/bilby/gw/likelihood/multiband.py
+++ b/bilby/gw/likelihood/multiband.py
@@ -8,6 +8,7 @@ from ...core.utils import (
     logger, speed_of_light, solar_mass, radius_of_earth,
     gravitational_constant, round_up_to_power_of_two
 )
+from ..prior import CBCPriorDict
 
 
 class MBGravitationalWaveTransient(GravitationalWaveTransient):
@@ -21,8 +22,9 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         A list of `bilby.detector.Interferometer` instances - contains the detector data and power spectral densities
     waveform_generator: `bilby.waveform_generator.WaveformGenerator`
         An object which computes the frequency-domain strain of the signal, given some set of parameters
-    reference_chirp_mass: float
-        A reference chirp mass for determining the frequency banding
+    reference_chirp_mass: float, optional
+        A reference chirp mass for determining the frequency banding. This is set to prior minimum of chirp mass if
+        not specified. Hence a CBCPriorDict object needs to be passed to priors when this parameter is not specified.
     highest_mode: int, optional
         The maximum magnetic number of gravitational-wave moments. Default is 2
     linear_interpolation: bool, optional
@@ -72,10 +74,11 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
 
     """
     def __init__(
-            self, interferometers, waveform_generator, reference_chirp_mass, highest_mode=2, linear_interpolation=True,
-            accuracy_factor=5, time_offset=None, delta_f_end=None, maximum_banding_frequency=None,
-            minimum_banding_duration=0., distance_marginalization=False, phase_marginalization=False, priors=None,
-            distance_marginalization_lookup_table=None, reference_frame="sky", time_reference="geocenter"
+            self, interferometers, waveform_generator, reference_chirp_mass=None, highest_mode=2,
+            linear_interpolation=True, accuracy_factor=5, time_offset=None, delta_f_end=None,
+            maximum_banding_frequency=None, minimum_banding_duration=0., distance_marginalization=False,
+            phase_marginalization=False, priors=None, distance_marginalization_lookup_table=None,
+            reference_frame="sky", time_reference="geocenter"
     ):
         super(MBGravitationalWaveTransient, self).__init__(
             interferometers=interferometers, waveform_generator=waveform_generator, priors=priors,
@@ -108,7 +111,24 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         if isinstance(reference_chirp_mass, int) or isinstance(reference_chirp_mass, float):
             self._reference_chirp_mass = reference_chirp_mass
         else:
-            raise TypeError("reference_chirp_mass must be a number")
+            logger.info(
+                "No int or float number has been passed to reference_chirp_mass. "
+                "Checking prior minimum of chirp mass ..."
+            )
+            if not isinstance(self.priors, CBCPriorDict):
+                raise TypeError(
+                    f"priors: {self.priors} is not CBCPriorDict. Prior minimum of chirp mass can not be obtained."
+                )
+            self._reference_chirp_mass = self.priors.minimum_chirp_mass
+            if self._reference_chirp_mass is None:
+                raise Exception(
+                    "Prior minimum of chirp mass can not be determined as priors does not contain necessary mass "
+                    "parameters."
+                )
+            logger.info(
+                "reference_chirp_mass is automatically set to prior minimum of chirp mass: "
+                f"{self._reference_chirp_mass}."
+            )
 
     @property
     def highest_mode(self):
diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py
index d404804e9..29962de2b 100644
--- a/test/gw/likelihood_test.py
+++ b/test/gw/likelihood_test.py
@@ -2,7 +2,6 @@ import itertools
 import os
 import pytest
 import unittest
-from copy import deepcopy
 from itertools import product
 from parameterized import parameterized
 
@@ -1571,9 +1570,9 @@ class TestBBHLikelihoodSetUp(unittest.TestCase):
 
 class TestMBLikelihood(unittest.TestCase):
     def setUp(self):
-        duration = 16
-        fmin = 20.
-        sampling_frequency = 2048.
+        self.duration = 16
+        self.fmin = 20.
+        self.sampling_frequency = 2048.
         self.test_parameters = dict(
             chirp_mass=6.0,
             mass_ratio=0.5,
@@ -1592,18 +1591,18 @@ class TestMBLikelihood(unittest.TestCase):
             dec=-1.2
         )  # Network SNR is ~50
 
-        ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
+        self.ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
         np.random.seed(170817)
-        ifos.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=sampling_frequency, duration=duration,
-            start_time=self.test_parameters['geocent_time'] - duration + 2.
+        self.ifos.set_strain_data_from_power_spectral_densities(
+            sampling_frequency=self.sampling_frequency, duration=self.duration,
+            start_time=self.test_parameters['geocent_time'] - self.duration + 2.
         )
-        for ifo in ifos:
-            ifo.minimum_frequency = fmin
+        for ifo in self.ifos:
+            ifo.minimum_frequency = self.fmin
 
         spline_calibration_nodes = 10
         self.calibration_parameters = {}
-        for ifo in ifos:
+        for ifo in self.ifos:
             ifo.calibration_model = bilby.gw.calibration.CubicSpline(
                 prefix=f"recalib_{ifo.name}_",
                 minimum_frequency=ifo.minimum_frequency,
@@ -1619,143 +1618,168 @@ class TestMBLikelihood(unittest.TestCase):
                 self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = \
                     np.random.normal(loc=0, scale=5 * np.pi / 180)
 
-        priors = bilby.gw.prior.BBHPriorDict()
-        priors.pop("mass_1")
-        priors.pop("mass_2")
-        priors["chirp_mass"] = bilby.core.prior.Uniform(5.5, 6.5)
-        priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1)
-        priors["geocent_time"] = bilby.core.prior.Uniform(
+        self.priors = bilby.gw.prior.BBHPriorDict()
+        self.priors.pop("mass_1")
+        self.priors.pop("mass_2")
+        self.priors["chirp_mass"] = bilby.core.prior.Uniform(5.5, 6.5)
+        self.priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1)
+        self.priors["geocent_time"] = bilby.core.prior.Uniform(
             self.test_parameters['geocent_time'] - 0.1,
             self.test_parameters['geocent_time'] + 0.1)
 
-        approximant_22 = "IMRPhenomD"
-        approximant_homs = "IMRPhenomHM"
-        non_mb_wfg_22 = bilby.gw.WaveformGenerator(
-            duration=duration, sampling_frequency=sampling_frequency,
+    def tearDown(self):
+        del (
+            self.ifos,
+            self.priors
+        )
+
+    @parameterized.expand([
+        ("IMRPhenomD", True, 2, False, 1.5e-2),
+        ("IMRPhenomD", True, 2, True, 1.5e-2),
+        ("IMRPhenomD", False, 2, False, 5e-3),
+        ("IMRPhenomD", False, 2, True, 6e-3),
+        ("IMRPhenomHM", False, 4, False, 8e-4),
+        ("IMRPhenomHM", False, 4, True, 1e-3)
+    ])
+    def test_matches_original_likelihood(
+        self, approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
+    ):
+        """
+        Check if multi-band likelihood values match original likelihood values
+        """
+        wfg = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
             waveform_arguments=dict(
-                reference_frequency=fmin, minimum_frequency=fmin, approximant=approximant_22)
+                reference_frequency=self.fmin, approximant=approximant
+            )
         )
-        mb_wfg_22 = bilby.gw.waveform_generator.WaveformGenerator(
-            duration=duration, sampling_frequency=sampling_frequency,
+        self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
+
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
             waveform_arguments=dict(
-                reference_frequency=fmin, approximant=approximant_22)
-        )
-        non_mb_wfg_homs = bilby.gw.WaveformGenerator(
-            duration=duration, sampling_frequency=sampling_frequency,
-            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-            waveform_arguments=dict(
-                reference_frequency=fmin, minimum_frequency=fmin, approximant=approximant_homs)
+                reference_frequency=self.fmin, approximant=approximant
+            )
         )
-        mb_wfg_homs = bilby.gw.waveform_generator.WaveformGenerator(
-            duration=duration, sampling_frequency=sampling_frequency,
-            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
-            waveform_arguments=dict(
-                reference_frequency=fmin, approximant=approximant_homs)
+        likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg
         )
-
-        ifos_22 = deepcopy(ifos)
-        ifos_22.inject_signal(
-            parameters=self.test_parameters, waveform_generator=non_mb_wfg_22
+        likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
+            reference_chirp_mass=self.test_parameters['chirp_mass'],
+            priors=self.priors.copy(), linear_interpolation=linear_interpolation,
+            highest_mode=highest_mode
         )
-        ifos_homs = deepcopy(ifos)
-        ifos_homs.inject_signal(
-            parameters=self.test_parameters, waveform_generator=non_mb_wfg_homs
+        likelihood.parameters.update(self.test_parameters)
+        likelihood_mb.parameters.update(self.test_parameters)
+        if add_cal_errors:
+            likelihood.parameters.update(self.calibration_parameters)
+            likelihood_mb.parameters.update(self.calibration_parameters)
+        self.assertLess(
+            abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()),
+            tolerance
         )
 
-        self.non_mb_22 = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=ifos_22, waveform_generator=non_mb_wfg_22
-        )
-        self.non_mb_homs = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=ifos_homs, waveform_generator=non_mb_wfg_homs
+    def test_large_accuracy_factor(self):
+        """
+        Check if larger accuracy factor increases the accuracy.
+        """
+        approximant = "IMRPhenomD"
+        wfg = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
         )
+        self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
 
-        self.mb_22 = bilby.gw.likelihood.MBGravitationalWaveTransient(
-            interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22),
-            reference_chirp_mass=self.test_parameters['chirp_mass'],
-            priors=priors.copy()
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
         )
-        self.mb_ifftfft_22 = bilby.gw.likelihood.MBGravitationalWaveTransient(
-            interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22),
-            reference_chirp_mass=self.test_parameters['chirp_mass'],
-            priors=priors.copy(), linear_interpolation=False
+        likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg
         )
-        self.mb_homs = bilby.gw.likelihood.MBGravitationalWaveTransient(
-            interferometers=ifos_homs, waveform_generator=deepcopy(mb_wfg_homs),
+        likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
             reference_chirp_mass=self.test_parameters['chirp_mass'],
-            priors=priors.copy(), linear_interpolation=False, highest_mode=4
+            priors=self.priors.copy(), accuracy_factor=5
         )
-        self.mb_more_accurate = bilby.gw.likelihood.MBGravitationalWaveTransient(
-            interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22),
+        likelihood_mb_more_accurate = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
             reference_chirp_mass=self.test_parameters['chirp_mass'],
-            priors=priors.copy(), accuracy_factor=50
-        )
-
-    def tearDown(self):
-        del (
-            self.non_mb_22,
-            self.non_mb_homs,
-            self.mb_22,
-            self.mb_ifftfft_22,
-            self.mb_homs,
-            self.mb_more_accurate
+            priors=self.priors.copy(), accuracy_factor=50
         )
-
-    @parameterized.expand([(False, ), (True, )])
-    def test_matches_non_mb(self, add_cal_errors):
-        self.non_mb_22.parameters.update(self.test_parameters)
-        self.mb_22.parameters.update(self.test_parameters)
-        if add_cal_errors:
-            self.non_mb_22.parameters.update(self.calibration_parameters)
-            self.mb_22.parameters.update(self.calibration_parameters)
+        likelihood.parameters.update(self.test_parameters)
+        likelihood_mb.parameters.update(self.test_parameters)
+        likelihood_mb_more_accurate.parameters.update(self.test_parameters)
         self.assertLess(
-            abs(self.non_mb_22.log_likelihood_ratio() - self.mb_22.log_likelihood_ratio()),
-            1.5e-2
+            abs(likelihood.log_likelihood_ratio() - likelihood_mb_more_accurate.log_likelihood_ratio()),
+            abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()) / 2
         )
 
-    @parameterized.expand([(False, ), (True, )])
-    def test_ifft_fft(self, add_cal_errors):
+    def test_reference_chirp_mass_from_prior(self):
         """
-        Check if multi-banding likelihood with (h, h) computed with the
-        IFFT-FFT algorithm matches the original likelihood.
+        Check if reference chirp mass is automatically determined from prior if no number has been passed
         """
-        self.non_mb_22.parameters.update(self.test_parameters)
-        self.mb_ifftfft_22.parameters.update(self.test_parameters)
-        if add_cal_errors:
-            self.non_mb_22.parameters.update(self.calibration_parameters)
-            self.mb_ifftfft_22.parameters.update(self.calibration_parameters)
-        self.assertLess(
-            abs(self.non_mb_22.log_likelihood_ratio() - self.mb_ifftfft_22.log_likelihood_ratio()),
-            6e-3
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant="IMRPhenomD"
+            )
+        )
+        likelihood1 = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
+            reference_chirp_mass=self.priors["chirp_mass"].minimum,
+            priors=self.priors.copy()
+        )
+        likelihood2 = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
+            priors=self.priors.copy()
         )
+        self.assertAlmostEqual(likelihood1.reference_chirp_mass, likelihood2.reference_chirp_mass)
 
-    @parameterized.expand([(False, ), (True, )])
-    def test_homs(self, add_cal_errors):
+    def test_no_reference_chirp_mass(self):
         """
-        Check if multi-banding likelihood matches the original likelihood for higher-order moments.
+        Check if an error is raised if either reference_chirp_mass or priors is not specified.
         """
-        self.non_mb_homs.parameters.update(self.test_parameters)
-        self.mb_homs.parameters.update(self.test_parameters)
-        if add_cal_errors:
-            self.non_mb_homs.parameters.update(self.calibration_parameters)
-            self.mb_homs.parameters.update(self.calibration_parameters)
-        self.assertLess(
-            abs(self.non_mb_homs.log_likelihood_ratio() - self.mb_homs.log_likelihood_ratio()),
-            1e-3
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant="IMRPhenomD"
+            )
         )
+        with self.assertRaises(TypeError):
+            bilby.gw.likelihood.MBGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=wfg_mb
+            )
 
-    def test_large_accuracy_factor(self):
+    def test_cannot_determine_reference_chirp_mass(self):
         """
-        Check if larger accuracy factor increases the accuracy.
+        Check if an error is raised if priors does not contain necessary information to determine reference chirp mass
         """
-        self.non_mb_22.parameters.update(self.test_parameters)
-        self.mb_22.parameters.update(self.test_parameters)
-        self.mb_more_accurate.parameters.update(self.test_parameters)
-        self.assertLess(
-            abs(self.non_mb_22.log_likelihood_ratio() - self.mb_more_accurate.log_likelihood_ratio()),
-            abs(self.non_mb_22.log_likelihood_ratio() - self.mb_22.log_likelihood_ratio()) / 2
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant="IMRPhenomD"
+            )
         )
+        for key in ["chirp_mass", "mass_1", "mass_2"]:
+            if key in self.priors:
+                self.priors.pop(key)
+        with self.assertRaises(Exception):
+            bilby.gw.likelihood.MBGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=wfg_mb, priors=self.priors
+            )
 
 
 if __name__ == "__main__":
-- 
GitLab