From bb30ff01114a7171e3aaa46b4f8d8f0501874627 Mon Sep 17 00:00:00 2001
From: Sylvia Biscoveanu <sylvia.biscoveanu@ligo.org>
Date: Fri, 10 Feb 2023 13:19:20 +0000
Subject: [PATCH] BUGFIX: Fix whitening procedure

---
 bilby/gw/detector/interferometer.py     |  8 ++++++--
 bilby/gw/result.py                      | 22 ++++++++++++++--------
 test/gw/detector/interferometer_test.py | 18 ++++++++++++++++++
 3 files changed, 38 insertions(+), 10 deletions(-)

diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py
index 86c5c03da..6fa91a1de 100644
--- a/bilby/gw/detector/interferometer.py
+++ b/bilby/gw/detector/interferometer.py
@@ -618,13 +618,17 @@ class Interferometer(object):
 
     @property
     def whitened_frequency_domain_strain(self):
-        """ Calculates the whitened data by dividing data by the amplitude spectral density
+        """ Calculates the whitened data by dividing the frequency domain data by
+        ((amplitude spectral density) * (duration / 4) ** 0.5). The resulting
+        data will have unit variance.
 
         Returns
         =======
         array_like: The whitened data
         """
-        return self.strain_data.frequency_domain_strain / self.amplitude_spectral_density_array
+        return self.strain_data.frequency_domain_strain / (
+            self.amplitude_spectral_density_array * np.sqrt(self.duration / 4)
+        )
 
     def save_data(self, outdir, label=None):
         """ Creates save files for interferometer data in plain text format.
diff --git a/bilby/gw/result.py b/bilby/gw/result.py
index 197173b7a..5d1b6e7ed 100644
--- a/bilby/gw/result.py
+++ b/bilby/gw/result.py
@@ -377,6 +377,10 @@ class CompactBinaryCoalescenceResult(CoreResult):
         logger.debug("Downsampling frequency mask to {} values".format(
             len(frequency_idxs))
         )
+        frequency_window_factor = (
+            np.sum(interferometer.frequency_mask)
+            / len(interferometer.frequency_mask)
+        )
         plot_times = interferometer.time_array[time_idxs]
         plot_times -= interferometer.strain_data.start_time
         start_time -= interferometer.strain_data.start_time
@@ -447,10 +451,11 @@ class CompactBinaryCoalescenceResult(CoreResult):
                 fig.add_trace(
                     go.Scatter(
                         x=plot_times,
-                        y=infft(
-                            interferometer.whitened_frequency_domain_strain *
-                            np.sqrt(2. / interferometer.sampling_frequency),
-                            sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs],
+                        y=np.fft.irfft(
+                            interferometer.whitened_frequency_domain_strain
+                            * np.sqrt(np.sum(interferometer.frequency_mask))
+                            / frequency_window_factor
+                        )[time_idxs],
                         fill=None,
                         mode='lines', line_color=DATA_COLOR,
                         opacity=0.5,
@@ -473,10 +478,11 @@ class CompactBinaryCoalescenceResult(CoreResult):
                     interferometer.amplitude_spectral_density_array[frequency_idxs],
                     color=DATA_COLOR, label='ASD')
                 axs[1].plot(
-                    plot_times, infft(
-                        interferometer.whitened_frequency_domain_strain *
-                        np.sqrt(2. / interferometer.sampling_frequency),
-                        sampling_frequency=interferometer.strain_data.sampling_frequency)[time_idxs],
+                    plot_times, np.fft.irfft(
+                        interferometer.whitened_frequency_domain_strain
+                        * np.sqrt(np.sum(interferometer.frequency_mask))
+                        / frequency_window_factor
+                    )[time_idxs],
                     color=DATA_COLOR, alpha=0.3)
             logger.debug('Plotted interferometer data.')
 
diff --git a/test/gw/detector/interferometer_test.py b/test/gw/detector/interferometer_test.py
index ad324e007..358045418 100644
--- a/test/gw/detector/interferometer_test.py
+++ b/test/gw/detector/interferometer_test.py
@@ -558,5 +558,23 @@ class TestInterferometerAntennaPatternAgainstLAL(unittest.TestCase):
                 self.assertAlmostEqual(std, 0.0, places=10)
 
 
+class TestInterferometerWhitenedStrain(unittest.TestCase):
+    def setUp(self):
+        self.ifo = bilby.gw.detector.get_empty_interferometer('H1')
+        self.ifo.set_strain_data_from_power_spectral_density(
+            sampling_frequency=4096, duration=64)
+
+    def tearDown(self):
+        del self.ifo
+
+    def test_whitened_strain(self):
+        mask = self.ifo.frequency_mask
+        white = self.ifo.whitened_frequency_domain_strain[mask]
+        std_real = np.std(white.real)
+        std_imag = np.std(white.imag)
+        self.assertAlmostEqual(std_real, 1, places=2)
+        self.assertAlmostEqual(std_imag, 1, places=2)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab