From 0e3bbe0d2ffa057219071aa630d1713fef898a88 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Tue, 14 Feb 2023 14:36:59 +0000
Subject: [PATCH] FEAT: Precomputed calibration model

---
 bilby/gw/detector/calibration.py              | 109 ++++++++++++++++++
 .../injection_examples/calibration_example.py |  44 ++++---
 test/gw/likelihood/marginalization_test.py    |  23 ++--
 3 files changed, 148 insertions(+), 28 deletions(-)

diff --git a/bilby/gw/detector/calibration.py b/bilby/gw/detector/calibration.py
index 7c9559a49..a5811c1ae 100644
--- a/bilby/gw/detector/calibration.py
+++ b/bilby/gw/detector/calibration.py
@@ -9,6 +9,7 @@ from scipy.interpolate import interp1d
 
 from ...core.utils.log import logger
 from ...core.prior.dict import PriorDict
+from ..prior import CalibrationPriorDict
 
 
 def read_calibration_file(filename, frequency_array, number_of_response_curves, starting_index=0):
@@ -225,6 +226,108 @@ class CubicSpline(Recalibrate):
         return calibration_factor
 
 
+class Precomputed(Recalibrate):
+
+    name = "precomputed"
+
+    def __init__(self, label, curves, frequency_array, parameters=None):
+        """
+        A class for accessing an array of precomputed recalibration curves.
+
+        Parameters
+        ==========
+        label: str
+            The label for the interferometer, e.g., H1. The corresponding
+            parameter is :code:`recalib_index_{label}`.
+        curves: array-like
+            Array with shape (n_curves, n_frequencies) with the recalibration
+            curves.
+        frequency_array: array-like
+            Array of frequencies at which the curves are evaluated.
+        """
+        self.label = label
+        self.curves = curves
+        self.frequency_array = frequency_array
+        self.parameters = parameters
+        super(Precomputed, self).__init__(prefix=f"recalib_index_{self.label}")
+
+    def get_calibration_factor(self, frequency_array, **params):
+        idx = int(params.get(self.prefix, None))
+        if idx is None:
+            raise KeyError(f"Calibration index for {self.label} not found.")
+        if not np.array_equal(frequency_array, self.frequency_array):
+            raise ValueError("Frequency grid passed to calibrator doesn't match.")
+        return self.curves[idx]
+
+    @classmethod
+    def constant_uncertainty_spline(
+        cls, amplitude_sigma, phase_sigma, frequency_array, n_nodes, label, n_curves
+    ):
+        priors = CalibrationPriorDict.constant_uncertainty_spline(
+            amplitude_sigma=amplitude_sigma,
+            phase_sigma=phase_sigma,
+            minimum_frequency=frequency_array[0],
+            maximum_frequency=frequency_array[-1],
+            n_nodes=n_nodes,
+            label=label,
+        )
+        parameters = pd.DataFrame(priors.sample(n_curves))
+        curves = curves_from_spline_and_prior(
+            label=label,
+            frequency_array=frequency_array,
+            n_points=n_nodes,
+            parameters=parameters,
+            n_curves=n_curves
+        )
+        return cls(
+            label=label,
+            curves=np.array(curves),
+            frequency_array=frequency_array,
+            parameters=parameters,
+        )
+
+    @classmethod
+    def from_envelope_file(
+        cls, envelope, frequency_array, n_nodes, label, n_curves
+    ):
+        priors = CalibrationPriorDict.from_envelope_file(
+            envelope_file=envelope,
+            minimum_frequency=frequency_array[0],
+            maximum_frequency=frequency_array[-1],
+            n_nodes=n_nodes,
+            label=label,
+        )
+        parameters = pd.DataFrame(priors.sample(n_curves))
+        curves = curves_from_spline_and_prior(
+            label=label,
+            frequency_array=frequency_array,
+            n_points=n_nodes,
+            parameters=parameters,
+            n_curves=n_curves,
+        )
+        return cls(
+            label=label,
+            curves=np.array(curves),
+            frequency_array=frequency_array,
+            parameters=parameters,
+        )
+
+    @classmethod
+    def from_calibration_file(cls, label, filename, frequency_array, n_curves, starting_index=0):
+        curves, parameters = read_calibration_file(
+            filename=filename,
+            frequency_array=frequency_array,
+            number_of_response_curves=n_curves,
+            starting_index=starting_index,
+        )
+        return cls(
+            label=label,
+            curves=np.array(curves),
+            frequency_array=frequency_array,
+            parameters=parameters,
+        )
+
+
 def build_calibration_lookup(
     interferometers,
     lookup_files=None,
@@ -255,6 +358,12 @@ def build_calibration_lookup(
                 number_of_response_curves,
                 starting_index,
             )
+        elif isinstance(interferometer.calibration_model, Precomputed):
+            model = interferometer.calibration_model
+            idxs = np.arange(number_of_response_curves, dtype=int) + starting_index
+            draws[name] = model.curves[idxs]
+            parameters[name] = pd.DataFrame(model.parameters.iloc[idxs])
+            parameters[name][model.prefix] = idxs
         else:
             if priors is None:
                 raise ValueError(
diff --git a/examples/gw_examples/injection_examples/calibration_example.py b/examples/gw_examples/injection_examples/calibration_example.py
index 91629dd47..9f7ceaf62 100644
--- a/examples/gw_examples/injection_examples/calibration_example.py
+++ b/examples/gw_examples/injection_examples/calibration_example.py
@@ -5,6 +5,11 @@ uncertainties included.
 
 We set up the full problem as is required and then just sample over a small
 number of calibration parameters.
+
+We demonstrate, two formulations of the calibration model:
+- a cubic spline described by gaussian distributions at a set of nodes.
+- a set of precomputed curves, in this example we use cubic spline realizations,
+  however, it also applies to physically motivated models.
 """
 
 import bilby
@@ -62,22 +67,32 @@ waveform_generator = bilby.gw.WaveformGenerator(
 # (LIGO-Hanford (H1), LIGO-Livingston (L1), and Virgo (V1)).
 # These default to their design sensitivity
 ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
-for ifo in ifos:
-    injection_parameters.update(
-        {f"recalib_{ifo.name}_amplitude_{ii}": 0.1 for ii in range(5)}
-    )
-    injection_parameters.update(
-        {f"recalib_{ifo.name}_phase_{ii}": 0.01 for ii in range(5)}
-    )
-    ifo.calibration_model = bilby.gw.calibration.CubicSpline(
-        prefix=f"recalib_{ifo.name}_",
-        minimum_frequency=ifo.minimum_frequency,
-        maximum_frequency=ifo.maximum_frequency,
-        n_points=5,
-    )
 ifos.set_strain_data_from_power_spectral_densities(
     sampling_frequency=sampling_frequency, duration=duration
 )
+
+ifo = ifos[0]
+injection_parameters.update(
+    {f"recalib_{ifo.name}_amplitude_{ii}": 0.1 for ii in range(5)}
+)
+injection_parameters.update({f"recalib_{ifo.name}_phase_{ii}": 0.01 for ii in range(5)})
+ifo.calibration_model = bilby.gw.calibration.CubicSpline(
+    prefix=f"recalib_{ifo.name}_",
+    minimum_frequency=ifo.minimum_frequency,
+    maximum_frequency=ifo.maximum_frequency,
+    n_points=5,
+)
+ifo = ifos[1]
+injection_parameters["recalib_index_L1"] = 3
+ifo.calibration_model = bilby.gw.calibration.Precomputed.constant_uncertainty_spline(
+    amplitude_sigma=0.1,
+    phase_sigma=0.01,
+    label="L1",
+    frequency_array=ifo.frequency_array[ifo.frequency_mask],
+    n_nodes=5,
+    n_curves=100,
+)
+
 ifos.inject_signal(
     parameters=injection_parameters, waveform_generator=waveform_generator
 )
@@ -94,6 +109,9 @@ for name in ["recalib_H1_amplitude_0", "recalib_H1_amplitude_1"]:
     priors[name] = bilby.core.prior.Gaussian(
         mu=0, sigma=0.2, name=name, latex_label=f"H1 $A_{name[-1]}$"
     )
+priors["recalib_index_L1"] = bilby.core.prior.Categorical(
+    ncategories=100, latex_label="recalib index L1"
+)
 
 # Initialise the likelihood by passing in the interferometer data (IFOs) and
 # the waveform generator
diff --git a/test/gw/likelihood/marginalization_test.py b/test/gw/likelihood/marginalization_test.py
index b8f0aeb1d..42fa0332f 100644
--- a/test/gw/likelihood/marginalization_test.py
+++ b/test/gw/likelihood/marginalization_test.py
@@ -455,11 +455,13 @@ class CalibrationMarginalization(unittest.TestCase):
             maximum_frequency=512,
             n_points=5,
         )
-        self.ifos[1].calibration_model = calibration.CubicSpline(
-            prefix="recalib_L1_",
-            minimum_frequency=20,
-            maximum_frequency=512,
-            n_points=5,
+        self.ifos[1].calibration_model = calibration.Precomputed.constant_uncertainty_spline(
+            amplitude_sigma=0.1,
+            phase_sigma=0.1,
+            frequency_array=self.ifos[1].frequency_array[self.ifos[1].frequency_mask],
+            n_nodes=5,
+            label="L1",
+            n_curves=1000,
         )
         self.priors = bilby.gw.prior.BBHPriorDict()
         self.priors["geocent_time"] = bilby.core.prior.Uniform(0, 4)
@@ -471,14 +473,6 @@ class CalibrationMarginalization(unittest.TestCase):
             n_nodes=5,
             label="H1",
         ))
-        self.priors.update(bilby.gw.prior.CalibrationPriorDict.constant_uncertainty_spline(
-            amplitude_sigma=0.1,
-            phase_sigma=0.1,
-            minimum_frequency=20,
-            maximum_frequency=512,
-            n_nodes=5,
-            label="L1",
-        ))
         self.wfg = bilby.gw.waveform_generator.WaveformGenerator(
             frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
             parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
@@ -518,8 +512,7 @@ class CalibrationMarginalization(unittest.TestCase):
         for ii in range(100):
             for name, value in draws["H1"].items():
                 non_marginalized.parameters[name] = value[ii]
-            for name, value in draws["L1"].items():
-                non_marginalized.parameters[name] = value[ii]
+            non_marginalized.parameters["recalib_index_L1"] = draws["L1"]["recalib_index_L1"][ii]
             non_marg_ln_ls.append(non_marginalized.log_likelihood_ratio())
         non_marg_ln_l = logsumexp(non_marg_ln_ls, b=1 / 100)
         self.assertAlmostEqual(marg_ln_l, non_marg_ln_l)
-- 
GitLab