From 53d6f2fab9b52f769be989647f2dfb3eb3f353a9 Mon Sep 17 00:00:00 2001
From: Hector Estelles <hector.estelles@ligo.org>
Date: Thu, 13 Jul 2023 13:56:33 +0000
Subject: [PATCH] FEATURE: Interface for gwsignal (new waveform interface)

---
 bilby/gw/source.py     | 236 +++++++++++++++++++++++++++++++++++++++++
 test/gw/source_test.py |  86 ++++++++++++++-
 2 files changed, 319 insertions(+), 3 deletions(-)

diff --git a/bilby/gw/source.py b/bilby/gw/source.py
index aa0ad5bbb..84b1a3e9d 100644
--- a/bilby/gw/source.py
+++ b/bilby/gw/source.py
@@ -11,6 +11,242 @@ from .utils import (lalsim_GetApproximantFromString,
                     lalsim_SimInspiralChooseFDWaveformSequence)
 
 
+def gwsignal_binary_black_hole(frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1,
+                               phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, **kwargs):
+    """
+    A binary black hole waveform model using GWsignal
+
+    Parameters
+    ==========
+    frequency_array: array_like
+        The frequencies at which we want to calculate the strain
+    mass_1: float
+        The mass of the heavier object in solar masses
+    mass_2: float
+        The mass of the lighter object in solar masses
+    luminosity_distance: float
+        The luminosity distance in megaparsec
+    a_1: float
+        Dimensionless primary spin magnitude
+    tilt_1: float
+        Primary tilt angle
+    phi_12: float
+        Azimuthal angle between the two component spins
+    a_2: float
+        Dimensionless secondary spin magnitude
+    tilt_2: float
+        Secondary tilt angle
+    phi_jl: float
+        Azimuthal angle between the total binary angular momentum and the
+        orbital angular momentum
+    theta_jn: float
+        Angle between the total binary angular momentum and the line of sight
+    phase: float
+        The phase at coalescence
+    kwargs: dict
+        Optional keyword arguments
+        Supported arguments:
+
+        - waveform_approximant
+        - reference_frequency
+        - minimum_frequency
+        - maximum_frequency
+        - catch_waveform_errors
+        - pn_amplitude_order
+        - mode_array:
+          Activate a specific mode array and evaluate the model using those
+          modes only.  e.g. waveform_arguments =
+          dict(waveform_approximant='IMRPhenomHM', mode_array=[[2,2],[2,-2]])
+          returns the 22 and 2-2 modes only of IMRPhenomHM.  You can only
+          specify modes that are included in that particular model.  e.g.
+          waveform_arguments = dict(waveform_approximant='IMRPhenomHM',
+          mode_array=[[2,2],[2,-2],[5,5],[5,-5]]) is not allowed because the
+          55 modes are not included in this model.  Be aware that some models
+          only take positive modes and return the positive and the negative
+          mode together, while others need to call both.  e.g.
+          waveform_arguments = dict(waveform_approximant='IMRPhenomHM',
+          mode_array=[[2,2],[4,-4]]) returns the 22 and 2-2 of IMRPhenomHM.
+          However, waveform_arguments =
+          dict(waveform_approximant='IMRPhenomXHM', mode_array=[[2,2],[4,-4]])
+          returns the 22 and 4-4 of IMRPhenomXHM.
+
+    Returns
+    =======
+    dict: A dictionary with the plus and cross polarisation strain modes
+
+    Notes
+    =====
+    This function is a temporary wrapper to the interface that will
+    likely be significantly changed or removed in a future release.
+    This version is only intended to be used with `SEOBNRv5HM` and `SEOBNRv5PHM` and
+    does not have full functionality for other waveform models.
+    """
+
+    from lalsimulation.gwsignal import GenerateFDWaveform
+    from lalsimulation.gwsignal.models import gwsignal_get_waveform_generator
+    import astropy.units as u
+
+    waveform_kwargs = dict(
+        waveform_approximant="SEOBNRv5PHM",
+        reference_frequency=50.0,
+        minimum_frequency=20.0,
+        maximum_frequency=frequency_array[-1],
+        catch_waveform_errors=False,
+        mode_array=None,
+        pn_amplitude_order=0,
+    )
+    waveform_kwargs.update(kwargs)
+
+    waveform_approximant = waveform_kwargs['waveform_approximant']
+    if waveform_approximant not in ["SEOBNRv5HM", "SEOBNRv5PHM"]:
+        if waveform_approximant == "IMRPhenomXPHM":
+            logger.warning("The new waveform interface is unreviewed for this model" +
+                           "and it is only intended for testing.")
+        else:
+            raise ValueError("The new waveform interface is unreviewed for this model.")
+    reference_frequency = waveform_kwargs['reference_frequency']
+    minimum_frequency = waveform_kwargs['minimum_frequency']
+    maximum_frequency = waveform_kwargs['maximum_frequency']
+    catch_waveform_errors = waveform_kwargs['catch_waveform_errors']
+    mode_array = waveform_kwargs['mode_array']
+    pn_amplitude_order = waveform_kwargs['pn_amplitude_order']
+
+    if pn_amplitude_order != 0:
+        # This is to mimic the behaviour in
+        # https://git.ligo.org/lscsoft/lalsuite/-/blob/master/lalsimulation/lib/LALSimInspiral.c#L5542
+        if pn_amplitude_order == -1:
+            if waveform_approximant in ["SpinTaylorT4", "SpinTaylorT5"]:
+                pn_amplitude_order = 3  # Equivalent to MAX_PRECESSING_AMP_PN_ORDER in LALSimulation
+            else:
+                pn_amplitude_order = 6  # Equivalent to MAX_NONPRECESSING_AMP_PN_ORDER in LALSimulation
+        start_frequency = minimum_frequency * 2. / (pn_amplitude_order + 2)
+    else:
+        start_frequency = minimum_frequency
+
+    # Call GWsignal generator
+    wf_gen = gwsignal_get_waveform_generator(waveform_approximant)
+
+    delta_frequency = frequency_array[1] - frequency_array[0]
+
+    frequency_bounds = ((frequency_array >= minimum_frequency) *
+                        (frequency_array <= maximum_frequency))
+
+    iota, spin_1x, spin_1y, spin_1z, spin_2x, spin_2y, spin_2z = bilby_to_lalsimulation_spins(
+        theta_jn=theta_jn, phi_jl=phi_jl, tilt_1=tilt_1, tilt_2=tilt_2,
+        phi_12=phi_12, a_1=a_1, a_2=a_2, mass_1=mass_1 * utils.solar_mass, mass_2=mass_2 * utils.solar_mass,
+        reference_frequency=reference_frequency, phase=phase)
+
+    eccentricity = 0.0
+    longitude_ascending_nodes = 0.0
+    mean_per_ano = 0.0
+
+    # Check if conditioning is needed
+    condition = 0
+    if wf_gen.metadata["implemented_domain"] == 'time':
+        condition = 1
+
+    # Create dict for gwsignal generator
+    gwsignal_dict = {'mass1' : mass_1 * u.solMass,
+                     'mass2' : mass_2 * u.solMass,
+                     'spin1x' : spin_1x * u.dimensionless_unscaled,
+                     'spin1y' : spin_1y * u.dimensionless_unscaled,
+                     'spin1z' : spin_1z * u.dimensionless_unscaled,
+                     'spin2x' : spin_2x * u.dimensionless_unscaled,
+                     'spin2y' : spin_2y * u.dimensionless_unscaled,
+                     'spin2z' : spin_2z * u.dimensionless_unscaled,
+                     'deltaF' : delta_frequency * u.Hz,
+                     'f22_start' : start_frequency * u.Hz,
+                     'f_max': maximum_frequency * u.Hz,
+                     'f22_ref': reference_frequency * u.Hz,
+                     'phi_ref' : phase * u.rad,
+                     'distance' : luminosity_distance * u.Mpc,
+                     'inclination' : iota * u.rad,
+                     'eccentricity' : eccentricity * u.dimensionless_unscaled,
+                     'longAscNodes' : longitude_ascending_nodes * u.rad,
+                     'meanPerAno' : mean_per_ano * u.rad,
+                     # 'ModeArray': mode_array,
+                     'condition': condition
+                     }
+
+    if mode_array is not None:
+        gwsignal_dict.update(ModeArray=mode_array)
+
+    # Pass extra waveform arguments to gwsignal
+    extra_args = waveform_kwargs.copy()
+
+    for key in [
+            "waveform_approximant",
+            "reference_frequency",
+            "minimum_frequency",
+            "maximum_frequency",
+            "catch_waveform_errors",
+            "mode_array",
+            "pn_spin_order",
+            "pn_amplitude_order",
+            "pn_tidal_order",
+            "pn_phase_order",
+            "numerical_relativity_file",
+    ]:
+        if key in extra_args.keys():
+            del extra_args[key]
+
+    gwsignal_dict.update(extra_args)
+
+    try:
+        hpc = GenerateFDWaveform(gwsignal_dict, wf_gen)
+    except Exception as e:
+        if not catch_waveform_errors:
+            raise
+        else:
+            EDOM = (
+                "Internal function call failed: Input domain error" in e.args[0]
+            ) or "Input domain error" in e.args[
+                0
+            ]
+            if EDOM:
+                failed_parameters = dict(mass_1=mass_1, mass_2=mass_2,
+                                         spin_1=(spin_1x, spin_2y, spin_1z),
+                                         spin_2=(spin_2x, spin_2y, spin_2z),
+                                         luminosity_distance=luminosity_distance,
+                                         iota=iota, phase=phase,
+                                         eccentricity=eccentricity,
+                                         start_frequency=minimum_frequency)
+                logger.warning("Evaluating the waveform failed with error: {}\n".format(e) +
+                               "The parameters were {}\n".format(failed_parameters) +
+                               "Likelihood will be set to -inf.")
+                return None
+            else:
+                raise
+
+    hplus = hpc.hp
+    hcross = hpc.hc
+
+    h_plus = np.zeros_like(frequency_array, dtype=complex)
+    h_cross = np.zeros_like(frequency_array, dtype=complex)
+
+    if len(hplus) > len(frequency_array):
+        logger.debug("GWsignal waveform longer than bilby's `frequency_array`" +
+                     "({} vs {}), ".format(len(hplus), len(frequency_array)) +
+                     "probably because padded with zeros up to the next power of two length." +
+                     " Truncating GWsignal array.")
+        h_plus = hplus[:len(h_plus)]
+        h_cross = hcross[:len(h_cross)]
+    else:
+        h_plus[:len(hplus)] = hplus
+        h_cross[:len(hcross)] = hcross
+
+    h_plus *= frequency_bounds
+    h_cross *= frequency_bounds
+
+    if condition:
+        dt = 1 / hplus.df.value + hplus.epoch.value
+        time_shift = np.exp(-1j * 2 * np.pi * dt * frequency_array[frequency_bounds])
+        h_plus[frequency_bounds] *= time_shift
+        h_cross[frequency_bounds] *= time_shift
+
+    return dict(plus=h_plus, cross=h_cross)
+
+
 def lal_binary_black_hole(
         frequency_array, mass_1, mass_2, luminosity_distance, a_1, tilt_1,
         phi_12, a_2, tilt_2, phi_jl, theta_jn, phase, **kwargs):
diff --git a/test/gw/source_test.py b/test/gw/source_test.py
index 8a4b7625b..979699b34 100644
--- a/test/gw/source_test.py
+++ b/test/gw/source_test.py
@@ -1,12 +1,12 @@
 import unittest
 
-import numpy as np
-from copy import copy
-
 import bilby
 import lal
 import lalsimulation
 
+import numpy as np
+from copy import copy
+
 
 class TestLalBBH(unittest.TestCase):
     def setUp(self):
@@ -95,6 +95,86 @@ class TestLalBBH(unittest.TestCase):
         self.assertFalse(np.all(out_v223["plus"] == out_v102["plus"]))
 
 
+class TestGWSignalBBH(unittest.TestCase):
+    def setUp(self):
+        self.parameters = dict(
+            mass_1=30.0,
+            mass_2=30.0,
+            luminosity_distance=400.0,
+            a_1=0.4,
+            tilt_1=0.2,
+            phi_12=1.0,
+            a_2=0.8,
+            tilt_2=2.7,
+            phi_jl=2.9,
+            theta_jn=0.3,
+            phase=0.0,
+        )
+        self.waveform_kwargs = dict(
+            waveform_approximant="IMRPhenomXPHM",
+            reference_frequency=50.0,
+            minimum_frequency=20.0,
+            catch_waveform_errors=True,
+        )
+        self.frequency_array = bilby.core.utils.create_frequency_series(2048, 4)
+        self.bad_parameters = copy(self.parameters)
+        self.bad_parameters["mass_1"] = -30.0
+
+    def tearDown(self):
+        del self.parameters
+        del self.waveform_kwargs
+        del self.frequency_array
+        del self.bad_parameters
+
+    def test_gwsignal_bbh_works_runs_valid_parameters(self):
+        self.parameters.update(self.waveform_kwargs)
+        self.assertIsInstance(
+            bilby.gw.source.gwsignal_binary_black_hole(
+                self.frequency_array, **self.parameters
+            ),
+            dict,
+        )
+
+    def test_waveform_error_catching(self):
+        self.bad_parameters.update(self.waveform_kwargs)
+        self.assertIsNone(
+            bilby.gw.source.gwsignal_binary_black_hole(
+                self.frequency_array, **self.bad_parameters
+            )
+        )
+
+    def test_waveform_error_raising(self):
+        raise_error_parameters = copy(self.bad_parameters)
+        raise_error_parameters.update(self.waveform_kwargs)
+        raise_error_parameters["catch_waveform_errors"] = False
+        with self.assertRaises(Exception):
+            bilby.gw.source.gwsignal_binary_black_hole(
+                self.frequency_array, **raise_error_parameters
+            )
+    # def test_gwsignal_bbh_works_without_waveform_parameters(self):
+    #    self.assertIsInstance(
+    #        bilby.gw.source.gwsignal_binary_black_hole(
+    #            self.frequency_array, **self.parameters
+    #        ),
+    #        dict,
+    #    )
+
+    def test_gwsignal_lal_bbh_consistency(self):
+        self.parameters.update(self.waveform_kwargs)
+        hpc_gwsignal = bilby.gw.source.gwsignal_binary_black_hole(
+            self.frequency_array, **self.parameters
+        )
+        hpc_lal = bilby.gw.source.lal_binary_black_hole(
+            self.frequency_array, **self.parameters
+        )
+        self.assertTrue(
+            np.allclose(hpc_gwsignal["plus"], hpc_lal["plus"], atol=0, rtol=1e-7)
+        )
+        self.assertTrue(
+            np.allclose(hpc_gwsignal["cross"], hpc_lal["cross"], atol=0, rtol=1e-7)
+        )
+
+
 class TestLalBNS(unittest.TestCase):
     def setUp(self):
         self.parameters = dict(
-- 
GitLab