From c29fc37fd173793f9b50581e1392f4cb6e506c9b Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Mon, 6 Sep 2021 04:10:52 +0000
Subject: [PATCH] Add pool to compute_snrs

---
 bilby/gw/conversion.py | 65 +++++++++++++++++++++++++-----------------
 1 file changed, 39 insertions(+), 26 deletions(-)

diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index bec64ef87..77772ad64 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -829,7 +829,7 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion,
                     .format(type(output_sample))
                 )
     if likelihood is not None:
-        compute_snrs(output_sample, likelihood)
+        compute_snrs(output_sample, likelihood, npool=npool)
     for key, func in zip(["mass", "spin", "source frame"], [
             generate_mass_parameters, generate_spin_parameters,
             generate_source_frame_parameters]):
@@ -1130,7 +1130,7 @@ def generate_source_frame_parameters(sample):
     return output_sample
 
 
-def compute_snrs(sample, likelihood):
+def compute_snrs(sample, likelihood, npool=1):
     """
     Compute the optimal and matched filter snrs of all posterior samples
     and print it out.
@@ -1157,36 +1157,49 @@ def compute_snrs(sample, likelihood):
                     per_detector_snr.optimal_snr_squared.real ** 0.5
         else:
             from tqdm.auto import tqdm
-            logger.info(
-                'Computing SNRs for every sample.')
-
-            matched_filter_snrs = {
-                ifo.name: [] for ifo in likelihood.interferometers}
-            optimal_snrs = {ifo.name: [] for ifo in likelihood.interferometers}
-            for ii in tqdm(range(len(sample)), file=sys.stdout):
-                signal_polarizations =\
-                    likelihood.waveform_generator.frequency_domain_strain(
-                        dict(sample.iloc[ii]))
-                likelihood.parameters.update(sample.iloc[ii])
-                for ifo in likelihood.interferometers:
-                    per_detector_snr = likelihood.calculate_snrs(
-                        signal_polarizations, ifo)
-
-                    matched_filter_snrs[ifo.name].append(
-                        per_detector_snr.complex_matched_filter_snr)
-                    optimal_snrs[ifo.name].append(
-                        per_detector_snr.optimal_snr_squared.real ** 0.5)
+            logger.info('Computing SNRs for every sample.')
 
-            for ifo in likelihood.interferometers:
-                sample['{}_matched_filter_snr'.format(ifo.name)] =\
-                    matched_filter_snrs[ifo.name]
-                sample['{}_optimal_snr'.format(ifo.name)] =\
-                    optimal_snrs[ifo.name]
+            fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()]
+            if npool > 1:
+                pool = multiprocessing.Pool(processes=npool)
+                logger.info(
+                    "Using a pool with size {} for nsamples={}".format(npool, len(sample))
+                )
+                new_samples = np.array(pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout)))
+                pool.close()
+            else:
+                new_samples = np.array([_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)])
+
+            for ii, ifo in enumerate(likelihood.interferometers):
+                matched_filter_snrs = list()
+                optimal_snrs = list()
+                mf_key = '{}_matched_filter_snr'.format(ifo.name)
+                optimal_key = '{}_optimal_snr'.format(ifo.name)
+                for new_sample in new_samples:
+                    matched_filter_snrs.append(new_sample[ii].complex_matched_filter_snr)
+                    optimal_snrs.append(new_sample[ii].optimal_snr_squared.real ** 0.5)
+
+                sample[mf_key] = matched_filter_snrs
+                sample[optimal_key] = optimal_snrs
 
     else:
         logger.debug('Not computing SNRs.')
 
 
+def _compute_snrs(args):
+    """A wrapper of computing the SNRs to enable multiprocessing"""
+    ii, sample, likelihood = args
+    sample = dict(sample).copy()
+    signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(
+        sample
+    )
+    likelihood.parameters.update(sample)
+    snrs = list()
+    for ifo in likelihood.interferometers:
+        snrs.append(likelihood.calculate_snrs(signal_polarizations, ifo))
+    return snrs
+
+
 def generate_posterior_samples_from_marginalized_likelihood(
         samples, likelihood, npool=1):
     """
-- 
GitLab