From 638629ef8b55e59ae34c073f49900559ba223ba1 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Thu, 9 Sep 2021 16:30:53 +0000 Subject: [PATCH] Fix snr calculation --- bilby/gw/conversion.py | 4 ++-- bilby/gw/likelihood.py | 29 +++++++++++++++-------------- requirements.txt | 1 + 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py index 77772ad64..758b0fdbe 100644 --- a/bilby/gw/conversion.py +++ b/bilby/gw/conversion.py @@ -1165,10 +1165,10 @@ def compute_snrs(sample, likelihood, npool=1): logger.info( "Using a pool with size {} for nsamples={}".format(npool, len(sample)) ) - new_samples = np.array(pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout))) + new_samples = pool.map(_compute_snrs, tqdm(fill_args, file=sys.stdout)) pool.close() else: - new_samples = np.array([_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)]) + new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)] for ii, ifo in enumerate(likelihood.interferometers): matched_filter_snrs = list() diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index 476f585ac..a6ab1a2b7 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -4,6 +4,7 @@ import json import copy import math +import attr import numpy as np import pandas as pd from scipy.special import logsumexp @@ -23,7 +24,6 @@ from .utils import ( ln_i0 ) from .waveform_generator import WaveformGenerator -from collections import namedtuple class GravitationalWaveTransient(Likelihood): @@ -117,13 +117,14 @@ class GravitationalWaveTransient(Likelihood): """ - _CalculatedSNRs = namedtuple('CalculatedSNRs', - ['d_inner_h', - 'optimal_snr_squared', - 'complex_matched_filter_snr', - 'd_inner_h_array', - 'optimal_snr_squared_array', - 'd_inner_h_squared_tc_array']) + @attr.s + class _CalculatedSNRs: + d_inner_h = attr.ib() + optimal_snr_squared = attr.ib() + complex_matched_filter_snr = attr.ib() + d_inner_h_array = attr.ib() + optimal_snr_squared_array = attr.ib() + d_inner_h_squared_tc_array = attr.ib() def __init__( self, interferometers, waveform_generator, time_marginalization=False, @@ -218,19 +219,19 @@ class GravitationalWaveTransient(Likelihood): """ attributes = ['duration', 'sampling_frequency', 'start_time'] - for attr in attributes: - wfg_attr = getattr(self.waveform_generator, attr) - ifo_attr = getattr(self.interferometers, attr) + for attribute in attributes: + wfg_attr = getattr(self.waveform_generator, attribute) + ifo_attr = getattr(self.interferometers, attribute) if wfg_attr is None: logger.debug( "The waveform_generator {} is None. Setting from the " - "provided interferometers.".format(attr)) + "provided interferometers.".format(attribute)) elif wfg_attr != ifo_attr: logger.debug( "The waveform_generator {} is not equal to that of the " "provided interferometers. Overwriting the " - "waveform_generator.".format(attr)) - setattr(self.waveform_generator, attr, ifo_attr) + "waveform_generator.".format(attribute)) + setattr(self.waveform_generator, attribute, ifo_attr) def calculate_snrs(self, waveform_polarizations, interferometer): """ diff --git a/requirements.txt b/requirements.txt index 423a97d09..c3bf8ebb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ tqdm h5py tables astropy +attrs -- GitLab