diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 4fedef7717dfccc44e74d0f4a0c089c8f7b7af39..0a3abcbcc6dd1ee286de291832c229c49432f1bb 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -106,44 +106,28 @@ class GravitationalWaveTransient(Likelihood): """ - @attr.s + @attr.s(slots=True, weakref_slot=False) class _CalculatedSNRs: - d_inner_h = attr.ib() - optimal_snr_squared = attr.ib() - complex_matched_filter_snr = attr.ib() - d_inner_h_array = attr.ib() - optimal_snr_squared_array = attr.ib() - d_inner_h_squared_tc_array = attr.ib() + d_inner_h = attr.ib(default=0j, converter=complex) + optimal_snr_squared = attr.ib(default=0, converter=float) + complex_matched_filter_snr = attr.ib(default=0j, converter=complex) + d_inner_h_array = attr.ib(default=None) + optimal_snr_squared_array = attr.ib(default=None) def __add__(self, other_snr): - - total_d_inner_h = self.d_inner_h + other_snr.d_inner_h - total_optimal_snr_squared = self.optimal_snr_squared + \ - np.real(other_snr.optimal_snr_squared) - total_complex_matched_filter_snr = self.complex_matched_filter_snr + \ - other_snr.complex_matched_filter_snr - - total_d_inner_h_array = self.d_inner_h_array - if other_snr.d_inner_h_array is not None \ - and self.d_inner_h_array is not None: - total_d_inner_h_array += other_snr.d_inner_h_array - - total_optimal_snr_squared_array = self.optimal_snr_squared_array - if other_snr.optimal_snr_squared_array is not None \ - and self.optimal_snr_squared_array is not None: - total_optimal_snr_squared_array += other_snr.optimal_snr_squared_array - - total_d_inner_h_squared_tc_array = self.d_inner_h_squared_tc_array - if other_snr.d_inner_h_squared_tc_array is not None \ - and self.d_inner_h_squared_tc_array is not None: - total_d_inner_h_squared_tc_array += other_snr.d_inner_h_squared_tc_array - - return self.__class__(d_inner_h=total_d_inner_h, - optimal_snr_squared=total_optimal_snr_squared, - complex_matched_filter_snr=total_complex_matched_filter_snr, - d_inner_h_array=total_d_inner_h_array, - optimal_snr_squared_array=total_optimal_snr_squared_array, - d_inner_h_squared_tc_array=total_d_inner_h_squared_tc_array) + new = copy.deepcopy(self) + new += other_snr + return new + + def __iadd__(self, other_snr): + for key in self.__slots__: + this = getattr(self, key) + other = getattr(other_snr, key) + if this is not None and other is not None: + setattr(self, key, this + other) + elif this is None: + setattr(self, key, other) + return self def __init__( self, interferometers, waveform_generator, time_marginalization=False, @@ -331,7 +315,7 @@ class GravitationalWaveTransient(Likelihood): complex_matched_filter_snr=complex_matched_filter_snr, d_inner_h_array=d_inner_h_array, optimal_snr_squared_array=optimal_snr_squared_array, - d_inner_h_squared_tc_array=None) + ) def _check_marginalized_prior_is_set(self, key): if key in self.priors and self.priors[key].is_fixed: @@ -395,34 +379,15 @@ class GravitationalWaveTransient(Likelihood): def log_likelihood_ratio(self): waveform_polarizations = \ self.waveform_generator.frequency_domain_strain(self.parameters) + if waveform_polarizations is None: + return np.nan_to_num(-np.inf) if self.time_marginalization and self.jitter_time: self.parameters['geocent_time'] += self.parameters['time_jitter'] self.parameters.update(self.get_sky_frame_parameters()) - if waveform_polarizations is None: - return np.nan_to_num(-np.inf) - - total_snrs = self._CalculatedSNRs( - d_inner_h=0., optimal_snr_squared=0., complex_matched_filter_snr=0., - d_inner_h_array=None, optimal_snr_squared_array=None, d_inner_h_squared_tc_array=None) - - if self.time_marginalization and self.calibration_marginalization: - total_snrs.d_inner_h_array = np.zeros( - (self.number_of_response_curves, len(self.interferometers.frequency_array[0:-1])), - dtype=np.complex128) - total_snrs.optimal_snr_squared_array = \ - np.zeros(self.number_of_response_curves, dtype=np.complex128) - - elif self.time_marginalization: - total_snrs.d_inner_h_array = np.zeros(len(self._times), dtype=np.complex128) - - elif self.calibration_marginalization: - total_snrs.d_inner_h_array = \ - np.zeros(self.number_of_response_curves, dtype=np.complex128) - total_snrs.optimal_snr_squared_array = \ - np.zeros(self.number_of_response_curves, dtype=np.complex128) + total_snrs = self._CalculatedSNRs() for interferometer in self.interferometers: per_detector_snr = self.calculate_snrs( @@ -827,10 +792,7 @@ class GravitationalWaveTransient(Likelihood): signal_polarizations = \ self.waveform_generator.frequency_domain_strain(self.parameters) - total_snrs = self._CalculatedSNRs( - d_inner_h=0., optimal_snr_squared=0., complex_matched_filter_snr=0., - d_inner_h_array=np.zeros(self.number_of_response_curves, dtype=np.complex128), - optimal_snr_squared_array=np.zeros(self.number_of_response_curves, dtype=np.complex128)) + total_snrs = self._CalculatedSNRs() for interferometer in self.interferometers: per_detector_snr = self.calculate_snrs( diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index c53ee9ff8063eafd53bb60f9585a7167a976d4d4..f18bd394c8c4c30e6c941e53795eb60f8671250c 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -760,9 +760,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): return self._CalculatedSNRs( d_inner_h=d_inner_h, optimal_snr_squared=optimal_snr_squared, complex_matched_filter_snr=complex_matched_filter_snr, - d_inner_h_squared_tc_array=None, - d_inner_h_array=None, - optimal_snr_squared_array=None) + ) def _rescale_signal(self, signal, new_distance): for mode in signal: diff --git a/bilby/gw/likelihood/relative.py b/bilby/gw/likelihood/relative.py index efd366a74742b15281778cb5b1d194e23e6eb80f..5202a3534c52b07e5ff7aa2ee5bb02222951b52f 100644 --- a/bilby/gw/likelihood/relative.py +++ b/bilby/gw/likelihood/relative.py @@ -393,5 +393,5 @@ class RelativeBinningGravitationalWaveTransient(GravitationalWaveTransient): return self._CalculatedSNRs( d_inner_h=d_inner_h, optimal_snr_squared=optimal_snr_squared, complex_matched_filter_snr=complex_matched_filter_snr, - d_inner_h_array=d_inner_h_array, optimal_snr_squared_array=None, - d_inner_h_squared_tc_array=None) + d_inner_h_array=d_inner_h_array + ) diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index f7ba3b1db5da686e693dfcde267d7ae5c6e418fe..aa4371b17783340b776628f7c3c3dec90328c839 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -411,11 +411,10 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): if not in_bounds: logger.debug("SNR calculation error: requested time at edge of ROQ time samples") return self._CalculatedSNRs( - d_inner_h=np.nan_to_num(-np.inf), optimal_snr_squared=0, + d_inner_h=np.nan_to_num(-np.inf), + optimal_snr_squared=0, complex_matched_filter_snr=np.nan_to_num(-np.inf), - d_inner_h_squared_tc_array=None, - d_inner_h_array=None, - optimal_snr_squared_array=None) + ) d_inner_h_tc_array = np.einsum( 'i,ji->j', np.conjugate(h_linear), @@ -441,9 +440,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): return self._CalculatedSNRs( d_inner_h=d_inner_h, optimal_snr_squared=optimal_snr_squared, complex_matched_filter_snr=complex_matched_filter_snr, - d_inner_h_squared_tc_array=None, d_inner_h_array=d_inner_h_array, - optimal_snr_squared_array=None) + ) @staticmethod def _closest_time_indices(time, samples):