diff --git a/bilby/gw/likelihood.py b/bilby/gw/likelihood.py index cc9e37bb9429c6553c938736c05883d7ff5e7c08..7fc78e82d330d11cf2162397029259c18666ac25 100644 --- a/bilby/gw/likelihood.py +++ b/bilby/gw/likelihood.py @@ -741,12 +741,14 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): waveform_generator: `bilby.waveform_generator.WaveformGenerator` An object which computes the frequency-domain strain of the signal, given some set of parameters - linear_matrix: str, array + linear_matrix: str, array_like Either a string point to the file from which to load the linear_matrix array, or the array itself. - quadratic_matrix: str, array - Either a string point to the file from which to load the quadratic_matrix - array, or the array itself. + quadratic_matrix: str, array_like + Either a string point to the file from which to load the + quadratic_matrix array, or the array itself. + roq_params: str, array_like + Parameters describing the domain of validity of the ROQ basis. priors: dict, bilby.prior.PriorDict A dictionary of priors containing at least the geocent_time prior distance_marginalization_lookup_table: (dict, str), optional @@ -761,6 +763,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): """ def __init__(self, interferometers, waveform_generator, priors, weights=None, linear_matrix=None, quadratic_matrix=None, + roq_params=None, distance_marginalization=False, phase_marginalization=False, distance_marginalization_lookup_table=None): GravitationalWaveTransient.__init__( @@ -770,6 +773,12 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): phase_marginalization=phase_marginalization, distance_marginalization_lookup_table=distance_marginalization_lookup_table) + if isinstance(roq_params, np.ndarray) or roq_params is None: + self.roq_params = roq_params + elif isinstance(roq_params, str): + self.roq_params = np.genfromtxt(roq_params, names=True) + else: + raise TypeError("roq_params should be array or str") if isinstance(weights, dict): self.weights = weights elif isinstance(weights, str): @@ -886,9 +895,18 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): self._get_time_resolution()) - self.interferometers.start_time for ifo in self.interferometers: - # only get frequency components up to maximum_frequency - linear_matrix = linear_matrix[:, :sum(ifo.frequency_mask)] - quadratic_matrix = quadratic_matrix[:, :sum(ifo.frequency_mask)] + if self.roq_params is not None: + frequencies = np.arange( + self.roq_params['flow'], + self.roq_params['fhigh'] + 1 / self.roq_params['seglen'], + 1 / self.roq_params['seglen']) + overlap_frequencies, ifo_idxs, roq_idxs = np.intersect1d( + ifo.frequency_array[ifo.frequency_mask], frequencies, + return_indices=True) + else: + overlap_frequencies = ifo.frequency_array[ifo.frequency_mask] + roq_idxs = np.arange(linear_matrix.shape[0], dtype=int) + ifo_idxs = ifo.frequency_mask # array of relative time shifts to be applied to the data # 0.045s comes from time for GW to traverse the Earth @@ -897,17 +915,16 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): # array to be filled with data, shifted by discrete time_samples tc_shifted_data = np.zeros([ - len(self.weights['time_samples']), sum(ifo.frequency_mask)], + len(self.weights['time_samples']), len(overlap_frequencies)], dtype=complex) # shift data to beginning of the prior increment by the time step shifted_data =\ - ifo.frequency_domain_strain[ifo.frequency_mask] * \ - np.exp(2j * np.pi * ifo.frequency_array[ifo.frequency_mask] * + ifo.frequency_domain_strain[ifo_idxs] * \ + np.exp(2j * np.pi * overlap_frequencies * self.weights['time_samples'][0]) single_time_shift = np.exp( - 2j * np.pi * ifo.frequency_array[ifo.frequency_mask] * - time_space) + 2j * np.pi * overlap_frequencies * time_space) for j in range(len(self.weights['time_samples'])): tc_shifted_data[j] = shifted_data shifted_data *= single_time_shift @@ -918,15 +935,16 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): max_elements = int((max_block_gigabytes * 2 ** 30) / 8) self.weights[ifo.name + '_linear'] = blockwise_dot_product( - tc_shifted_data / - ifo.power_spectral_density_array[ifo.frequency_mask], - linear_matrix, max_elements) * 4 / ifo.strain_data.duration + tc_shifted_data / ifo.power_spectral_density_array[ifo_idxs], + linear_matrix[roq_idxs], + max_elements) * 4 / ifo.strain_data.duration del tc_shifted_data self.weights[ifo.name + '_quadratic'] = build_roq_weights( - 1 / ifo.power_spectral_density_array[ifo.frequency_mask], - quadratic_matrix.real, 1 / ifo.strain_data.duration) + 1 / ifo.power_spectral_density_array[ifo_idxs], + quadratic_matrix[roq_idxs].real, + 1 / ifo.strain_data.duration) def save_weights(self, filename): with open(filename, 'w') as file: