diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 7eb65b2b5ab004ce8cb6fb814958c51420d45844..d405a7e0d01c05783a37663058e5b19aaa1cabad 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -1465,7 +1465,7 @@ def _compute_snrs(args): ) snrs = list() for ifo in likelihood.interferometers: - snrs.append(likelihood.calculate_snrs(signal_polarizations, ifo)) + snrs.append(likelihood.calculate_snrs(signal_polarizations, ifo, return_array=False)) return snrs diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index e2d447a2a7f062fe7d4b0be4be1b1583ba544bc9..1daa9a1e5eee341764342a814c5ab19859eab977 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -236,16 +236,26 @@ class GravitationalWaveTransient(Likelihood): "waveform_generator.".format(attribute)) setattr(self.waveform_generator, attribute, ifo_attr) - def calculate_snrs(self, waveform_polarizations, interferometer): + def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True): """ Compute the snrs Parameters - ========== + ---------- waveform_polarizations: dict A dictionary of waveform polarizations and the corresponding array interferometer: bilby.gw.detector.Interferometer The bilby interferometer object + return_array: bool + If true, calculate and return internal array objects + (d_inner_h_array and optimal_snr_squared_array), otherwise + these are returned as None. + + Returns + ------- + calculated_snrs: _CalculatedSNRs + An object containing the SNR quantities and (if return_array=True) + the internal array objects. """ signal = self._compute_full_waveform( @@ -266,7 +276,10 @@ class GravitationalWaveTransient(Likelihood): normalization = 4 / self.waveform_generator.duration - if self.time_marginalization and self.calibration_marginalization: + if return_array is False: + d_inner_h_array = None + optimal_snr_squared_array = None + elif self.time_marginalization and self.calibration_marginalization: d_inner_h_integrand = np.tile( interferometer.frequency_domain_strain.conjugate() * signal / diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index ee1bb1723fb0da9aebc72f435ebad83741f72511..3ed8819afaa81aca1e9712cf435596bee46131d1 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -693,18 +693,26 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): else: setattr(self, key, value) - def calculate_snrs(self, waveform_polarizations, interferometer): + def calculate_snrs(self, waveform_polarizations, interferometer, return_array=False): """ - Compute the snrs for multi-banding + Compute the snrs Parameters ---------- - waveform_polarizations: waveform + waveform_polarizations: dict + A dictionary of waveform polarizations and the corresponding array interferometer: bilby.gw.detector.Interferometer + The bilby interferometer object + return_array: bool + If true, calculate and return internal array objects + (d_inner_h_array and optimal_snr_squared_array), otherwise + these are returned as None. This parameter is ignored for the multiband + model as these arrays are never calculated. Returns ------- - snrs: named tuple of snrs + calculated_snrs: _CalculatedSNRs + An object containing the SNR quantities. """ strain = np.zeros(len(self.banded_frequency_points), dtype=complex) diff --git a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index 43a45e5ce92005eb58b9d0951eaab81780328922..64fd81ab130753b079eaa9f6694088a3c6ef4667 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -382,7 +382,7 @@ class RelativeBinningGravitationalWaveTransient(GravitationalWaveTransient): full_waveform_ratio = duplicated_r0 + duplicated_r1 * (f - duplicated_fm) return fiducial_waveform * full_waveform_ratio - def calculate_snrs(self, waveform_polarizations, interferometer): + def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True): r0, r1 = self.compute_waveform_ratio_per_interferometer( waveform_polarizations=waveform_polarizations, interferometer=interferometer, @@ -393,7 +393,7 @@ class RelativeBinningGravitationalWaveTransient(GravitationalWaveTransient): optimal_snr_squared = h_inner_h complex_matched_filter_snr = d_inner_h / (optimal_snr_squared ** 0.5) - if self.time_marginalization: + if return_array and self.time_marginalization: full_waveform = self._compute_full_waveform( signal_polarizations=waveform_polarizations, interferometer=interferometer, diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index 96665c5047ae45bd0368727357cc13254fdf3360..a0e67fc8d1381d44e294d18f35dfe1fa8d367176 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -394,7 +394,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): def waveform_generator(self, waveform_generator): self._waveform_generator = waveform_generator - def calculate_snrs(self, waveform_polarizations, interferometer): + def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True): """ Compute the snrs for ROQ @@ -458,7 +458,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): with np.errstate(invalid="ignore"): complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) - if self.time_marginalization: + if return_array and self.time_marginalization: ifo_times = self._times - interferometer.strain_data.start_time ifo_times += dt if self.jitter_time: