diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py index 88e7234a20820d1eee151118b717a5d0c6f91cb8..c53ee9ff8063eafd53bb60f9585a7167a976d4d4 100644 --- a/bilby/gw/likelihood/multiband.py +++ b/bilby/gw/likelihood/multiband.py @@ -1,12 +1,15 @@ import math +import numbers import numpy as np from .base import GravitationalWaveTransient from ...core.utils import ( logger, speed_of_light, solar_mass, radius_of_earth, - gravitational_constant, round_up_to_power_of_two + gravitational_constant, round_up_to_power_of_two, + recursively_load_dict_contents_from_group, + recursively_save_dict_contents_to_group ) from ..prior import CBCPriorDict @@ -42,6 +45,8 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): A maximum frequency for multi-banding. If specified, the low-frequency limit of a band does not exceed it. minimum_banding_duration: float, optional A minimum duration for multi-banding. If specified, the duration of a band is not smaller than it. + weights: str or dict, optional + Pre-computed multiband weights for calculating inner products. distance_marginalization: bool, optional If true, marginalize over distance in the likelihood. This uses a look up table calculated at run time. The distance prior is set to be a delta function at the minimum distance allowed in the prior being marginalised @@ -76,9 +81,9 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): def __init__( self, interferometers, waveform_generator, reference_chirp_mass=None, highest_mode=2, linear_interpolation=True, accuracy_factor=5, time_offset=None, delta_f_end=None, - maximum_banding_frequency=None, minimum_banding_duration=0., distance_marginalization=False, - phase_marginalization=False, priors=None, distance_marginalization_lookup_table=None, - reference_frame="sky", time_reference="geocenter" + maximum_banding_frequency=None, minimum_banding_duration=0., weights=None, + distance_marginalization=False, phase_marginalization=False, priors=None, + distance_marginalization_lookup_table=None, reference_frame="sky", time_reference="geocenter" ): super(MBGravitationalWaveTransient, self).__init__( interferometers=interferometers, waveform_generator=waveform_generator, priors=priors, @@ -86,17 +91,23 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): time_marginalization=False, distance_marginalization_lookup_table=distance_marginalization_lookup_table, jitter_time=False, reference_frame=reference_frame, time_reference=time_reference ) - self.reference_chirp_mass = reference_chirp_mass - self.highest_mode = highest_mode - self.linear_interpolation = linear_interpolation - self.accuracy_factor = accuracy_factor - self.time_offset = time_offset - self.delta_f_end = delta_f_end - self.minimum_frequency = np.min([i.minimum_frequency for i in self.interferometers]) - self.maximum_frequency = np.max([i.maximum_frequency for i in self.interferometers]) - self.maximum_banding_frequency = maximum_banding_frequency - self.minimum_banding_duration = minimum_banding_duration - self.setup_multibanding() + if weights is None: + self.reference_chirp_mass = reference_chirp_mass + self.highest_mode = highest_mode + self.linear_interpolation = linear_interpolation + self.accuracy_factor = accuracy_factor + self.time_offset = time_offset + self.delta_f_end = delta_f_end + self.maximum_banding_frequency = maximum_banding_frequency + self.minimum_banding_duration = minimum_banding_duration + self.setup_multibanding() + else: + if isinstance(weights, str): + import h5py + logger.info(f"Loading multiband weights from {weights}.") + with h5py.File(weights, 'r') as f: + weights = recursively_load_dict_contents_from_group(f, '/') + self.setup_multibanding_from_weights(weights) @property def reference_chirp_mass(self): @@ -108,7 +119,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): @reference_chirp_mass.setter def reference_chirp_mass(self, reference_chirp_mass): - if isinstance(reference_chirp_mass, int) or isinstance(reference_chirp_mass, float): + if isinstance(reference_chirp_mass, numbers.Number): self._reference_chirp_mass = reference_chirp_mass else: logger.info( @@ -136,7 +147,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): @highest_mode.setter def highest_mode(self, highest_mode): - if isinstance(highest_mode, int) or isinstance(highest_mode, float): + if isinstance(highest_mode, numbers.Number): self._highest_mode = highest_mode else: raise TypeError("highest_mode must be a number") @@ -147,7 +158,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): @linear_interpolation.setter def linear_interpolation(self, linear_interpolation): - if isinstance(linear_interpolation, bool): + if isinstance(linear_interpolation, bool) or isinstance(linear_interpolation, np.bool_): self._linear_interpolation = linear_interpolation else: raise TypeError("linear_interpolation must be a bool") @@ -158,7 +169,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): @accuracy_factor.setter def accuracy_factor(self, accuracy_factor): - if isinstance(accuracy_factor, int) or isinstance(accuracy_factor, float): + if isinstance(accuracy_factor, numbers.Number): self._accuracy_factor = accuracy_factor else: raise TypeError("accuracy_factor must be a number") @@ -182,7 +193,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): else: safety = 2 * radius_of_earth / speed_of_light if time_offset is not None: - if isinstance(time_offset, int) or isinstance(time_offset, float): + if isinstance(time_offset, numbers.Number): self._time_offset = time_offset else: raise TypeError("time_offset must be a number") @@ -216,7 +227,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): else: safety = 2 * radius_of_earth / speed_of_light if delta_f_end is not None: - if isinstance(delta_f_end, int) or isinstance(delta_f_end, float): + if isinstance(delta_f_end, numbers.Number): self._delta_f_end = delta_f_end else: raise TypeError("delta_f_end must be a number") @@ -248,7 +259,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): / self.reference_chirp_mass_in_second ) if maximum_banding_frequency is not None: - if isinstance(maximum_banding_frequency, int) or isinstance(maximum_banding_frequency, float): + if isinstance(maximum_banding_frequency, numbers.Number): if maximum_banding_frequency < fmax_tmp: fmax_tmp = maximum_banding_frequency else: @@ -264,11 +275,23 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): @minimum_banding_duration.setter def minimum_banding_duration(self, minimum_banding_duration): - if isinstance(minimum_banding_duration, int) or isinstance(minimum_banding_duration, float): + if isinstance(minimum_banding_duration, numbers.Number): self._minimum_banding_duration = minimum_banding_duration else: raise TypeError("minimum_banding_duration must be a number") + @property + def minimum_frequency(self): + return np.min([i.minimum_frequency for i in self.interferometers]) + + @property + def maximum_frequency(self): + return np.max([i.maximum_frequency for i in self.interferometers]) + + @property + def number_of_bands(self): + return len(self.durations) + def setup_multibanding(self): """Set up frequency bands and coefficients needed for likelihood evaluations""" self._setup_frequency_bands() @@ -365,25 +388,26 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): original duration. This sets the following instance variables. durations: durations of bands (T^(b) in the paper) - fb_dfb: the list of tuples, which contain starting frequencies (f^(b) in the paper) and frequency scales for + fb_dfb: 2-dimensional ndarray, which contain starting frequencies (f^(b) in the paper) and frequency scales for smoothing waveforms (\Delta f^(b) in the paper) of bands """ - self.durations = [self.interferometers.duration] - self.fb_dfb = [(self.minimum_frequency, 0.)] + self.durations = np.array([self.interferometers.duration]) + self.fb_dfb = [[self.minimum_frequency, 0.]] dnext = self.interferometers.duration / 2 while dnext > max(self.time_offset, self.minimum_banding_duration): fnow, _ = self.fb_dfb[-1] fnext, dfnext = self._find_starting_frequency(dnext, fnow) if fnext is not None and fnext < min(self.maximum_frequency, self.maximum_banding_frequency): - self.durations.append(dnext) - self.fb_dfb.append((fnext, dfnext)) + self.durations = np.append(self.durations, dnext) + self.fb_dfb.append([fnext, dfnext]) dnext /= 2 else: break - self.fb_dfb.append((self.maximum_frequency + self.delta_f_end, self.delta_f_end)) + self.fb_dfb.append([self.maximum_frequency + self.delta_f_end, self.delta_f_end]) + self.fb_dfb = np.array(self.fb_dfb) logger.info("The total frequency range is divided into {} bands with frequency intervals of {}.".format( - len(self.durations), ", ".join(["1/{} Hz".format(d) for d in self.durations]))) + self.number_of_bands, ", ".join(["1/{} Hz".format(d) for d in self.durations]))) def _setup_integers(self): """Set up integers needed for likelihood evaluations. This sets the following instance variables. @@ -393,17 +417,18 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): Ks_Ke: start and end frequency indices of bands (K^(b)_s and K^(b)_e in the paper) """ - self.Nbs = [] - self.Mbs = [] + self.Nbs = np.array([], dtype=int) + self.Mbs = np.array([], dtype=int) self.Ks_Ke = [] - for b in range(len(self.durations)): + for b in range(self.number_of_bands): dnow = self.durations[b] fnow, dfnow = self.fb_dfb[b] fnext, _ = self.fb_dfb[b + 1] Nb = max(round_up_to_power_of_two(2. * (fnext * self.interferometers.duration + 1.)), 2**b) - self.Nbs.append(Nb) - self.Mbs.append(Nb // 2**b) - self.Ks_Ke.append((math.ceil((fnow - dfnow) * dnow), math.floor(fnext * dnow))) + self.Nbs = np.append(self.Nbs, Nb) + self.Mbs = np.append(self.Mbs, Nb // 2**b) + self.Ks_Ke.append([math.ceil((fnow - dfnow) * dnow), math.floor(fnext * dnow)]) + self.Ks_Ke = np.array(self.Ks_Ke) def _setup_waveform_frequency_points(self): """Set up frequency points where waveforms are evaluated. Frequency points are reordered because some waveform @@ -419,13 +444,14 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): self.banded_frequency_points = np.array([]) self.start_end_idxs = [] start_idx = 0 - for i in range(len(self.fb_dfb) - 1): + for i in range(self.number_of_bands): d = self.durations[i] Ks, Ke = self.Ks_Ke[i] self.banded_frequency_points = np.append(self.banded_frequency_points, np.arange(Ks, Ke + 1) / d) end_idx = start_idx + Ke - Ks - self.start_end_idxs.append((start_idx, end_idx)) + self.start_end_idxs.append([start_idx, end_idx]) start_idx = end_idx + 1 + self.start_end_idxs = np.array(self.start_end_idxs) unique_frequencies, idxs = np.unique(self.banded_frequency_points, return_inverse=True) self.waveform_generator.waveform_arguments['frequencies'] = unique_frequencies self.unique_to_original_frequencies = idxs @@ -435,35 +461,54 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): (self.maximum_frequency - self.minimum_frequency) * self.interferometers.duration / len(unique_frequencies))) - def _window(self, f, b): - """Compute window function in the b-th band + def _get_window_sequence(self, delta_f, start_idx, length, b): + """Compute window function on frequencies with a fixed frequency interval Parameters ---------- - f: float or ndarray - frequency at which the window function is computed + delta_f: float + frequency interval + start_idx: int + starting frequency per delta_f + length: int + number of frequencies b: int + band number Returns ------- - window: float - window function at f + window_sequence: array + """ fnow, dfnow = self.fb_dfb[b] fnext, dfnext = self.fb_dfb[b + 1] - @np.vectorize - def _vectorized_window(f): - if fnow - dfnow < f < fnow: - return (1. + np.cos(np.pi * (f - fnow) / dfnow)) / 2. - elif fnow <= f <= fnext - dfnext: - return 1. - elif fnext - dfnext < f < fnext: - return (1. - np.cos(np.pi * (f - fnext) / dfnext)) / 2. - else: - return 0. + window_sequence = np.zeros(length) + increase_start = np.clip( + math.floor((fnow - dfnow) / delta_f) - start_idx + 1, 0, length + ) + unity_start = np.clip(math.ceil(fnow / delta_f) - start_idx, 0, length) + decrease_start = np.clip( + math.floor((fnext - dfnext) / delta_f) - start_idx + 1, 0, length + ) + decrease_stop = np.clip(math.ceil(fnext / delta_f) - start_idx, 0, length) + + window_sequence[unity_start:decrease_start] = 1. + + # this if statement avoids overflow caused by vanishing dfnow + if increase_start < unity_start: + frequencies = (np.arange(increase_start, unity_start) + start_idx) * delta_f + window_sequence[increase_start:unity_start] = ( + 1. + np.cos(np.pi * (frequencies - fnow) / dfnow) + ) / 2. - return _vectorized_window(f) + if decrease_start < decrease_stop: + frequencies = (np.arange(decrease_start, decrease_stop) + start_idx) * delta_f + window_sequence[decrease_start:decrease_stop] = ( + 1. - np.cos(np.pi * (frequencies - fnext) / dfnext) + ) / 2. + + return window_sequence def _setup_linear_coefficients(self): """Set up coefficients by which waveforms are multiplied to compute (d, h)""" @@ -474,9 +519,9 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): fddata = np.zeros(N // 2 + 1, dtype=complex) fddata[:len(ifo.frequency_domain_strain)][ifo.frequency_mask] += \ ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask] - for b in range(len(self.fb_dfb) - 1): - start_idx, end_idx = self.start_end_idxs[b] - windows = self._window(self.banded_frequency_points[start_idx:end_idx + 1], b) + for b in range(self.number_of_bands): + Ks, Ke = self.Ks_Ke[b] + windows = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b) fddata_in_ith_band = np.copy(fddata[:int(self.Nbs[b] / 2 + 1)]) fddata_in_ith_band[-1] = 0. # zeroing data at the Nyquist frequency tddata = np.fft.irfft(fddata_in_ith_band)[-self.Mbs[b]:] @@ -490,36 +535,60 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): linear-interpolation algorithm""" logger.info("Linear-interpolation algorithm is used for (h, h).") self.quadratic_coeffs = dict((ifo.name, np.array([])) for ifo in self.interferometers) - N = self.Nbs[-1] - for ifo in self.interferometers: - logger.info("Pre-computing quadratic coefficients for {}".format(ifo.name)) - full_frequencies = np.arange(N // 2 + 1) / ifo.duration - full_inv_psds = np.zeros(N // 2 + 1) - full_inv_psds[:len(ifo.power_spectral_density_array)][ifo.frequency_mask] = \ - 1. / ifo.power_spectral_density_array[ifo.frequency_mask] - for i in range(len(self.fb_dfb) - 1): - start_idx, end_idx = self.start_end_idxs[i] - banded_frequencies = self.banded_frequency_points[start_idx:end_idx + 1] + original_duration = self.interferometers.duration + + for b in range(self.number_of_bands): + logger.info(f"Pre-computing quadratic coefficients for the {b}-th band") + _start, _end = self.start_end_idxs[b] + banded_frequencies = self.banded_frequency_points[_start:_end + 1] + prefactor = 4 * self.durations[b] / original_duration + + # precompute window values + _fnow, _dfnow = self.fb_dfb[b] + _fnext, _ = self.fb_dfb[b + 1] + start_idx_in_band = math.ceil((_fnow - _dfnow) * original_duration) + window_sequence = self._get_window_sequence( + 1 / original_duration, + start_idx_in_band, + math.floor(_fnext * original_duration) - start_idx_in_band + 1, + b + ) + + for ifo in self.interferometers: + end_idx_in_band = min( + start_idx_in_band + len(window_sequence) - 1, + len(ifo.power_spectral_density_array) - 1 + ) + _frequency_mask = ifo.frequency_mask[start_idx_in_band:end_idx_in_band + 1] + window_over_psd = np.zeros(end_idx_in_band + 1 - start_idx_in_band) + window_over_psd[_frequency_mask] = \ + 1. / ifo.power_spectral_density_array[start_idx_in_band:end_idx_in_band + 1][_frequency_mask] + window_over_psd *= window_sequence[:len(window_over_psd)] + coeffs = np.zeros(len(banded_frequencies)) for k in range(len(coeffs) - 1): if k == 0: - start_idx_in_sum = 0 + start_idx_in_sum = start_idx_in_band else: - start_idx_in_sum = math.ceil(ifo.duration * banded_frequencies[k]) + start_idx_in_sum = max( + start_idx_in_band, + math.ceil(original_duration * banded_frequencies[k]) + ) if k == len(coeffs) - 2: - end_idx_in_sum = len(full_frequencies) - 1 + end_idx_in_sum = end_idx_in_band else: - end_idx_in_sum = math.ceil(ifo.duration * banded_frequencies[k + 1]) - 1 - window_over_psd = ( - full_inv_psds[start_idx_in_sum:end_idx_in_sum + 1] - * self._window(full_frequencies[start_idx_in_sum:end_idx_in_sum + 1], i) - ) - frequencies_in_sum = full_frequencies[start_idx_in_sum:end_idx_in_sum + 1] - coeffs[k] += 4 * self.durations[i] / ifo.duration * np.sum( - (banded_frequencies[k + 1] - frequencies_in_sum) * window_over_psd + end_idx_in_sum = min( + end_idx_in_band, + math.ceil(original_duration * banded_frequencies[k + 1]) - 1 + ) + frequencies_in_sum = np.arange(start_idx_in_sum, end_idx_in_sum + 1) / original_duration + coeffs[k] += prefactor * np.sum( + (banded_frequencies[k + 1] - frequencies_in_sum) * + window_over_psd[start_idx_in_sum - start_idx_in_band:end_idx_in_sum - start_idx_in_band + 1] ) - coeffs[k + 1] += 4 * self.durations[i] / ifo.duration * np.sum( - (frequencies_in_sum - banded_frequencies[k]) * window_over_psd + coeffs[k + 1] += prefactor * np.sum( + (frequencies_in_sum - banded_frequencies[k]) * + window_over_psd[start_idx_in_sum - start_idx_in_band:end_idx_in_sum - start_idx_in_band + 1] ) self.quadratic_coeffs[ifo.name] = np.append(self.quadratic_coeffs[ifo.name], coeffs) @@ -540,7 +609,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): full_inv_psds[:len(ifo.power_spectral_density_array)][ifo.frequency_mask] = ( 1 / ifo.power_spectral_density_array[ifo.frequency_mask] ) - for b in range(len(self.fb_dfb) - 1): + for b in range(self.number_of_bands): Imb = np.fft.irfft(full_inv_psds[:self.Nbs[b] // 2 + 1]) half_length = Nhatbs[b] // 2 Imbc = np.append(Imb[:half_length + 1], Imb[-(Nhatbs[b] - half_length - 1):]) @@ -551,12 +620,79 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): # precompute windows and their squares self.windows = np.array([]) self.square_root_windows = np.array([]) - for b in range(len(self.fb_dfb) - 1): - start, end = self.start_end_idxs[b] - ws = self._window(self.banded_frequency_points[start:end + 1], b) + for b in range(self.number_of_bands): + Ks, Ke = self.Ks_Ke[b] + ws = self._get_window_sequence(1. / self.durations[b], Ks, Ke - Ks + 1, b) self.windows = np.append(self.windows, ws) self.square_root_windows = np.append(self.square_root_windows, np.sqrt(ws)) + @property + def weights(self): + _weights = {} + for key in [ + "reference_chirp_mass", "highest_mode", "linear_interpolation", + "accuracy_factor", "time_offset", "delta_f_end", + "maximum_banding_frequency", "minimum_banding_duration", + "durations", "fb_dfb", "Nbs", "Mbs", "Ks_Ke", + "banded_frequency_points", "start_end_idxs", + "unique_to_original_frequencies", "linear_coeffs" + ]: + _weights[key] = getattr(self, key) + _weights["waveform_frequencies"] = \ + self.waveform_generator.waveform_arguments['frequencies'] + if self.linear_interpolation: + _weights["quadratic_coeffs"] = self.quadratic_coeffs + else: + for key in ["Tbhats", "windows", "square_root_windows"]: + _weights[key] = getattr(self, key) + for key in ["wths", "hbcs", "Ibcs"]: + _weights[key] = {} + value = getattr(self, key) + for ifo_name, data in value.items(): + _weights[key][ifo_name] = dict((str(b), v) for b, v in enumerate(data)) + return _weights + + def save_weights(self, filename): + """ + Save multiband weights into a .hdf5 file. + + Parameters + ========== + filename : str + + """ + import h5py + if not filename.endswith(".hdf5"): + filename += ".hdf5" + logger.info(f"Saving multiband weights to {filename}") + with h5py.File(filename, 'w') as f: + recursively_save_dict_contents_to_group(f, '/', self.weights) + + def setup_multibanding_from_weights(self, weights): + """ + Set multiband weights from dictionary-like weights + + Parameters + ========== + weights : dict + + """ + keys = list(weights.keys()) + # reference_chirp_mass needs to be set first as it is required for the setter of maximum_banding_frequency + self.reference_chirp_mass = weights["reference_chirp_mass"] + keys.remove("reference_chirp_mass") + for key in keys: + value = weights[key] + if key in ["wths", "hbcs", "Ibcs"]: + to_set = {} + for ifo_name, data in value.items(): + to_set[ifo_name] = [data[str(b)] for b in range(len(data.keys()))] + setattr(self, key, to_set) + elif key == "waveform_frequencies": + self.waveform_generator.waveform_arguments['frequencies'] = weights["waveform_frequencies"] + else: + setattr(self, key, value) + def calculate_snrs(self, waveform_polarizations, interferometer): """ Compute the snrs for multi-banding @@ -601,7 +737,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient): ) else: optimal_snr_squared = 0. - for b in range(len(self.fb_dfb) - 1): + for b in range(self.number_of_bands): Ks, Ke = self.Ks_Ke[b] start_idx, end_idx = self.start_end_idxs[b] Mb = self.Mbs[b] diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index e7e707c287aeb473cf3bb0528d9c78e876a39dd4..47a4a6aca3b92be3bc841d60f571f4f8029c7cab 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -2,6 +2,7 @@ import itertools import os import pytest import unittest +import tempfile from copy import deepcopy from itertools import product from parameterized import parameterized @@ -1794,6 +1795,113 @@ class TestMBLikelihood(unittest.TestCase): interferometers=self.ifos, waveform_generator=wfg_mb, priors=self.priors ) + @parameterized.expand([(True, ), (False, )]) + def test_inout_weights(self, linear_interpolation): + """ + Check if multiband weights can be saved as a file, and a likelihood object constructed from the weights file + produces the same likelihood value. + """ + approximant = "IMRPhenomD" + wfg = bilby.gw.WaveformGenerator( + duration=self.duration, sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, + waveform_arguments=dict( + reference_frequency=self.fmin, approximant=approximant + ) + ) + self.ifos.inject_signal( + parameters=self.test_parameters, waveform_generator=wfg + ) + + wfg_mb = bilby.gw.WaveformGenerator( + duration=self.duration, sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, + waveform_arguments=dict( + reference_frequency=self.fmin, approximant=approximant + ) + ) + likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( + interferometers=self.ifos, waveform_generator=wfg_mb, + reference_chirp_mass=self.test_parameters['chirp_mass'], + linear_interpolation=linear_interpolation, + ) + likelihood_mb.parameters.update(self.test_parameters) + llr = likelihood_mb.log_likelihood_ratio() + + with tempfile.TemporaryDirectory() as tmpdirname: + # check if weights can be saved as a file + filepath = os.path.join(tmpdirname, "weights.hdf5") + likelihood_mb.save_weights(filepath) + self.assertTrue(os.path.exists(filepath)) + + # reset waveform generator to check if likelihood recovered from the weights file properly adds banded + # frequency points to waveform arguments + wfg_mb = bilby.gw.WaveformGenerator( + duration=self.duration, sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, + waveform_arguments=dict( + reference_frequency=self.fmin, approximant=approximant + ) + ) + likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient( + interferometers=self.ifos, waveform_generator=wfg_mb, weights=filepath + ) + + likelihood_mb_from_weights.parameters.update(self.test_parameters) + llr_from_weights = likelihood_mb_from_weights.log_likelihood_ratio() + + self.assertAlmostEqual(llr, llr_from_weights) + + @parameterized.expand([(True, ), (False, )]) + def test_from_dict_weights(self, linear_interpolation): + """ + Check if a likelihood object constructed from dictionary-like weights produce the same likelihood value + """ + approximant = "IMRPhenomD" + wfg = bilby.gw.WaveformGenerator( + duration=self.duration, sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole, + waveform_arguments=dict( + reference_frequency=self.fmin, approximant=approximant + ) + ) + self.ifos.inject_signal( + parameters=self.test_parameters, waveform_generator=wfg + ) + + wfg_mb = bilby.gw.WaveformGenerator( + duration=self.duration, sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, + waveform_arguments=dict( + reference_frequency=self.fmin, approximant=approximant + ) + ) + likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient( + interferometers=self.ifos, waveform_generator=wfg_mb, + reference_chirp_mass=self.test_parameters['chirp_mass'], + linear_interpolation=linear_interpolation, + ) + likelihood_mb.parameters.update(self.test_parameters) + llr = likelihood_mb.log_likelihood_ratio() + + # reset waveform generator to check if likelihood recovered from the weights properly adds banded + # frequency points to waveform arguments + wfg_mb = bilby.gw.WaveformGenerator( + duration=self.duration, sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence, + waveform_arguments=dict( + reference_frequency=self.fmin, approximant=approximant + ) + ) + weights = likelihood_mb.weights + likelihood_mb_from_weights = bilby.gw.likelihood.MBGravitationalWaveTransient( + interferometers=self.ifos, waveform_generator=wfg_mb, weights=weights + ) + likelihood_mb_from_weights.parameters.update(self.test_parameters) + llr_from_weights = likelihood_mb_from_weights.log_likelihood_ratio() + + self.assertAlmostEqual(llr, llr_from_weights) + if __name__ == "__main__": unittest.main()