From 638629ef8b55e59ae34c073f49900559ba223ba1 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 9 Sep 2021 16:30:53 +0000
Subject: [PATCH] Fix snr calculation

---
 bilby/gw/conversion.py |  4 ++--
 bilby/gw/likelihood.py | 29 +++++++++++++++--------------
 requirements.txt       |  1 +
 3 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 77772ad64..758b0fdbe 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -1165,10 +1165,10 @@ def compute_snrs(sample, likelihood, npool=1):
                 logger.info(
                     "Using a pool with size {} for nsamples={}".format(npool, len(sample))
                 )
-                new_samples = np.array(pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout)))
+                new_samples = pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout))
                 pool.close()
             else:
-                new_samples = np.array([_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)])
+                new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)]
 
             for ii, ifo in enumerate(likelihood.interferometers):
                 matched_filter_snrs = list()
diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py
index 476f585ac..a6ab1a2b7 100644
--- a/bilby/gw/likelihood.py
+++ b/bilby/gw/likelihood.py
@@ -4,6 +4,7 @@ import json
 import copy
 import math
 
+import attr
 import numpy as np
 import pandas as pd
 from scipy.special import logsumexp
@@ -23,7 +24,6 @@ from .utils import (
     ln_i0
 )
 from .waveform_generator import WaveformGenerator
-from collections import namedtuple
 
 
 class GravitationalWaveTransient(Likelihood):
@@ -117,13 +117,14 @@ class GravitationalWaveTransient(Likelihood):
 
     """
 
-    _CalculatedSNRs = namedtuple('CalculatedSNRs',
-                                 ['d_inner_h',
-                                  'optimal_snr_squared',
-                                  'complex_matched_filter_snr',
-                                  'd_inner_h_array',
-                                  'optimal_snr_squared_array',
-                                  'd_inner_h_squared_tc_array'])
+    @attr.s
+    class _CalculatedSNRs:
+        d_inner_h = attr.ib()
+        optimal_snr_squared = attr.ib()
+        complex_matched_filter_snr = attr.ib()
+        d_inner_h_array = attr.ib()
+        optimal_snr_squared_array = attr.ib()
+        d_inner_h_squared_tc_array = attr.ib()
 
     def __init__(
         self, interferometers, waveform_generator, time_marginalization=False,
@@ -218,19 +219,19 @@ class GravitationalWaveTransient(Likelihood):
         """
 
         attributes = ['duration', 'sampling_frequency', 'start_time']
-        for attr in attributes:
-            wfg_attr = getattr(self.waveform_generator, attr)
-            ifo_attr = getattr(self.interferometers, attr)
+        for attribute in attributes:
+            wfg_attr = getattr(self.waveform_generator, attribute)
+            ifo_attr = getattr(self.interferometers, attribute)
             if wfg_attr is None:
                 logger.debug(
                     "The waveform_generator {} is None. Setting from the "
-                    "provided interferometers.".format(attr))
+                    "provided interferometers.".format(attribute))
             elif wfg_attr != ifo_attr:
                 logger.debug(
                     "The waveform_generator {} is not equal to that of the "
                     "provided interferometers. Overwriting the "
-                    "waveform_generator.".format(attr))
-            setattr(self.waveform_generator, attr, ifo_attr)
+                    "waveform_generator.".format(attribute))
+            setattr(self.waveform_generator, attribute, ifo_attr)
 
     def calculate_snrs(self, waveform_polarizations, interferometer):
         """
diff --git a/requirements.txt b/requirements.txt
index 423a97d09..c3bf8ebb8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,3 +11,4 @@ tqdm
 h5py
 tables
 astropy
+attrs
-- 
GitLab