From 56a58aebd3a2760634dd3c827a7954bb45fd8a14 Mon Sep 17 00:00:00 2001
From: Rhiannon Udall <rhiannon.udall@ligo.org>
Date: Thu, 4 Apr 2024 16:23:09 +0000
Subject: [PATCH] MAINT: Separate adding SNRs per IFO to a sample into a new
 function

---
 bilby/gw/conversion.py      | 18 ++++++++----------
 bilby/gw/likelihood/base.py | 14 ++++++++++++++
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 00d957f98..4be5634eb 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -2254,17 +2254,15 @@ def compute_snrs(sample, likelihood, npool=1):
                 new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)]
 
             for ii, ifo in enumerate(likelihood.interferometers):
-                matched_filter_snrs = list()
-                optimal_snrs = list()
-                mf_key = '{}_matched_filter_snr'.format(ifo.name)
-                optimal_key = '{}_optimal_snr'.format(ifo.name)
+                snr_updates = dict()
+                for key in new_samples[0][ii].snrs_as_sample.keys():
+                    snr_updates[f"{ifo.name}_{key}"] = list()
                 for new_sample in new_samples:
-                    matched_filter_snrs.append(new_sample[ii].complex_matched_filter_snr)
-                    optimal_snrs.append(new_sample[ii].optimal_snr_squared.real ** 0.5)
-
-                sample[mf_key] = matched_filter_snrs
-                sample[optimal_key] = optimal_snrs
-
+                    snr_update = new_sample[ii].snrs_as_sample
+                    for key, val in snr_update.items():
+                        snr_updates[f"{ifo.name}_{key}"].append(val)
+            for k, v in snr_updates.items():
+                sample[k] = v
     else:
         logger.debug('Not computing SNRs.')
 
diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py
index 256237bfb..e0a09c1e9 100644
--- a/bilby/gw/likelihood/base.py
+++ b/bilby/gw/likelihood/base.py
@@ -128,6 +128,20 @@ class GravitationalWaveTransient(Likelihood):
                     setattr(self, key, other)
             return self
 
+        @property
+        def snrs_as_sample(self) -> dict:
+            """Get the SNRs of this object as a sample dictionary
+
+            Returns
+            =======
+            dict
+                The dictionary of SNRs labelled accordingly
+            """
+            return {
+                "matched_filter_snr" : self.complex_matched_filter_snr,
+                "optimal_snr" : self.optimal_snr_squared.real ** 0.5
+            }
+
     def __init__(
             self, interferometers, waveform_generator, time_marginalization=False,
             distance_marginalization=False, phase_marginalization=False, calibration_marginalization=False, priors=None,
-- 
GitLab