From aa53639d0a782a43b3cf10ba1e13cdaaf403e1d8 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Fri, 21 Oct 2022 11:30:06 -0700
Subject: [PATCH] make implementation a bit more generic

---
 bilby/gw/conversion.py              | 14 +------
 bilby/gw/detector/interferometer.py | 63 +++++------------------------
 2 files changed, 13 insertions(+), 64 deletions(-)

diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 92abe19da..6c3cfacf9 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -1394,22 +1394,12 @@ def compute_snrs(sample, likelihood, npool=1):
 
     """
     if likelihood is not None:
-        if likelihood.__class__.__name__ == "RelativeBinningGravitationalWaveTransient":
-            logger.info("Relative Binning Likelihood; Calculating SNRs from Summary Data")
-
         if isinstance(sample, dict):
-            if likelihood.__class__.__name__ == "RelativeBinningGravitationalWaveTransient":
-                waveform_ratio = likelihood.compute_waveform_ratio(sample)
-            else:
-                signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(sample)
+            signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(sample)
             likelihood.parameters.update(sample)
 
             for ifo in likelihood.interferometers:
-                if likelihood.__class__.__name__ == "RelativeBinningGravitationalWaveTransient":
-                    per_detector_snr = likelihood.calculate_snrs_relative_binning(waveform_ratio[ifo.name], ifo)
-                else:
-                    per_detector_snr = likelihood.calculate_snrs(
-                        signal_polarizations, ifo)
+                per_detector_snr = likelihood.calculate_snrs(signal_polarizations, ifo)
                 sample['{}_matched_filter_snr'.format(ifo.name)] =\
                     per_detector_snr.complex_matched_filter_snr
                 sample['{}_optimal_snr'.format(ifo.name)] = \
diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py
index 10e965e35..93457aa32 100644
--- a/bilby/gw/detector/interferometer.py
+++ b/bilby/gw/detector/interferometer.py
@@ -279,7 +279,7 @@ class Interferometer(object):
         polarization_tensor = get_polarization_tensor(ra, dec, time, psi, mode)
         return three_by_three_matrix_contraction(self.geometry.detector_tensor, polarization_tensor)
 
-    def get_detector_response(self, waveform_polarizations, parameters):
+    def get_detector_response(self, waveform_polarizations, parameters, frequencies=None):
         """ Get the detector response for a particular waveform
 
         Parameters
@@ -293,6 +293,12 @@ class Interferometer(object):
         =======
         array_like: A 3x3 array representation of the detector response (signal observed in the interferometer)
         """
+        if frequencies is None:
+            frequencies = self.frequency_array[self.frequency_mask]
+            mask = self.frequency_mask
+        else:
+            mask = np.ones(len(frequencies), dtype=bool)
+
         signal = {}
         for mode in waveform_polarizations.keys():
             det_response = self.antenna_response(
@@ -314,58 +320,11 @@ class Interferometer(object):
         dt_geocent = parameters['geocent_time'] - self.strain_data.start_time
         dt = dt_geocent + time_shift
 
-        signal_ifo[self.strain_data.frequency_mask] = signal_ifo[self.strain_data.frequency_mask] * np.exp(
-            -1j * 2 * np.pi * dt * self.strain_data.frequency_array[self.strain_data.frequency_mask])
-
-        signal_ifo[self.strain_data.frequency_mask] *= self.calibration_model.get_calibration_factor(
-            self.strain_data.frequency_array[self.strain_data.frequency_mask],
-            prefix='recalib_{}_'.format(self.name), **parameters)
-
-        return signal_ifo
-
-    def get_detector_response_relative_binning(self, waveform_polarizations,
-                                               parameters, bin_frequencies):
-        """Get the detector response for a particular waveform, where the frequencies
-        of the waveform polarizations are only the binning frequencies and we
-        assume that there is no frequency mask necessary for the data. Kind of
-        a hacky workaround. Should do something better probably.
-
-        Parameters
-        -------
-        waveform_polarizations: dict
-            polarizations of the waveform
-        parameters: dict
-            parameters describing position and time of arrival of the signal
-
-        Returns
-        -------
-        array_like: A 3x3 array representation of the detector response (signal observed in the interferometer)
-        """
-        signal = {}
-        for mode in waveform_polarizations.keys():
-            det_response = self.antenna_response(
-                parameters['ra'],
-                parameters['dec'],
-                parameters['geocent_time'],
-                parameters['psi'], mode)
-
-            signal[mode] = waveform_polarizations[mode] * det_response
-        signal_ifo = sum(signal.values())
-
-        time_shift = self.time_delay_from_geocenter(
-            parameters['ra'], parameters['dec'], parameters['geocent_time'])
-
-        # Be careful to first subtract the two GPS times which are ~1e9 sec.
-        # And then add the time_shift which varies at ~1e-5 sec
-        dt_geocent = parameters['geocent_time'] - self.strain_data.start_time
-        dt = dt_geocent + time_shift
-
-        signal_ifo = signal_ifo * \
-            np.exp(-1j * 2 * np.pi * dt * bin_frequencies)
+        signal_ifo[mask] = signal_ifo[mask] * np.exp(-1j * 2 * np.pi * dt * frequencies)
 
-        signal_ifo *= self.calibration_model.get_calibration_factor(
-            bin_frequencies, prefix='recalib_{}_'.format(self.name),
-            **parameters)
+        signal_ifo[mask] *= self.calibration_model.get_calibration_factor(
+            frequencies, prefix='recalib_{}_'.format(self.name), **parameters
+        )
 
         return signal_ifo
 
-- 
GitLab