diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py index b0bac463ddf0a5fb0613413664d1aa69394ce926..fddfa92b56e5f95f2b2c3e3c9fa7f4ddc67a5bb9 100644 --- a/bilby/gw/likelihood/roq.py +++ b/bilby/gw/likelihood/roq.py @@ -136,13 +136,28 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): linear_matrix['duration_s'][()])], dtype=[('flow', float), ('fhigh', float), ('seglen', float)] ) - elif is_hdf5_quadratic: - self.roq_params = np.array( - [(quadratic_matrix['minimum_frequency_hz'][()], - quadratic_matrix['maximum_frequency_hz'][()], - quadratic_matrix['duration_s'][()])], - dtype=[('flow', float), ('fhigh', float), ('seglen', float)] - ) + if is_hdf5_quadratic: + if self.roq_params is None: + self.roq_params = np.array( + [(quadratic_matrix['minimum_frequency_hz'][()], + quadratic_matrix['maximum_frequency_hz'][()], + quadratic_matrix['duration_s'][()])], + dtype=[('flow', float), ('fhigh', float), ('seglen', float)] + ) + else: + self.roq_params['flow'] = max( + self.roq_params['flow'], quadratic_matrix['minimum_frequency_hz'][()] + ) + self.roq_params['fhigh'] = min( + self.roq_params['fhigh'], quadratic_matrix['maximum_frequency_hz'][()] + ) + self.roq_params['seglen'] = min( + self.roq_params['seglen'], quadratic_matrix['duration_s'][()] + ) + if self.roq_params is not None: + for ifo in self.interferometers: + self.perform_roq_params_check(ifo) + self.weights = dict() self._set_weights(linear_matrix=linear_matrix, quadratic_matrix=quadratic_matrix) if is_hdf5_linear: @@ -158,9 +173,10 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): for basis_type in ['linear', 'quadratic']: number_of_bases = getattr(self, f'number_of_bases_{basis_type}') if number_of_bases > 1: - self._verify_prior_ranges_and_frequency_nodes(basis_type) + self._verify_numbers_of_prior_ranges_and_frequency_nodes(basis_type) else: self._check_frequency_nodes_exist_for_single_basis(basis_type) + self._verify_prior_ranges(basis_type) self._set_unique_frequency_nodes_and_inverse() # need to fill waveform_arguments here if single basis is used, as they will never be updated. @@ -171,7 +187,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices - def _verify_prior_ranges_and_frequency_nodes(self, basis_type): + def _verify_numbers_of_prior_ranges_and_frequency_nodes(self, basis_type): """ Check if self.weights contains lists of prior ranges and frequency nodes, and their sizes are equal to the number of bases. @@ -205,6 +221,35 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): raise ValueError( f'The number of arrays of frequency nodes does not match the number of {basis_type} bases') + def _verify_prior_ranges(self, basis_type): + """Check if the union of prior ranges is within the ROQ basis bounds. + + Parameters + ========== + basis_type: str + + """ + key = f'prior_range_{basis_type}' + if key not in self.weights: + return + prior_ranges = self.weights[key] + for param_name, prior_ranges_of_this_param in prior_ranges.items(): + prior_minimum = self.priors[param_name].minimum + basis_minimum = np.min(prior_ranges_of_this_param[:, 0]) + if prior_minimum < basis_minimum: + raise BilbyROQParamsRangeError( + f"Prior minimum of {param_name} {prior_minimum} less " + f"than ROQ basis bound {basis_minimum}" + ) + + prior_maximum = self.priors[param_name].maximum + basis_maximum = np.max(prior_ranges_of_this_param[:, 1]) + if prior_maximum > basis_maximum: + raise BilbyROQParamsRangeError( + f"Prior maximum of {param_name} {prior_maximum} greater " + f"than ROQ basis bound {basis_maximum}" + ) + def _check_frequency_nodes_exist_for_single_basis(self, basis_type): """ For a single-basis case, frequency nodes should be contained in self._waveform_generator.waveform_arguments or @@ -701,6 +746,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): roq_scale_factor = 1. prior_ranges[param_name] = matrix[key][param_name][()] * roq_scale_factor selected_idxs, selected_prior_ranges = self._select_prior_ranges(prior_ranges) + if len(selected_idxs) == 0: + raise BilbyROQParamsRangeError(f"There are no {basis_type} ROQ bases within the prior range.") self.weights[key] = selected_prior_ranges idxs_in_prior_range[basis_type] = selected_idxs else: @@ -725,7 +772,6 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient): ifo_idxs = {} for ifo in self.interferometers: if self.roq_params is not None: - self.perform_roq_params_check(ifo) # Get scaled ROQ quantities roq_scaled_minimum_frequency = self.roq_params['flow'] * self.roq_scale_factor roq_scaled_maximum_frequency = self.roq_params['fhigh'] * self.roq_scale_factor diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 0c1575945247360f39ddeaa24acf0d48ba18430e..b12ffd59bdc9101bca7170c5764cd73c7a9eecbb 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -656,6 +656,80 @@ class TestROQLikelihoodHDF5(unittest.TestCase): self.injection_parameters["geocent_time"] + 0.1 ) + @parameterized.expand( + [(_path_to_basis, 20., 2048., 16), + (_path_to_basis, 10., 1024., 16), + (_path_to_basis, 20., 1024., 32), + (_path_to_basis_mb, 20., 2048., 16)] + ) + def test_fails_with_frequency_duration_mismatch( + self, basis, minimum_frequency, maximum_frequency, duration + ): + """Test if likelihood fails as expected, when data frequency range is + not within the basis range or data duration does not match the basis + duration. The basis frequency range and duration are 20--1024Hz and + 16s""" + self.priors["chirp_mass"].minimum = 8 + self.priors["chirp_mass"].maximum = 9 + interferometers = bilby.gw.detector.InterferometerList(["H1"]) + interferometers.set_strain_data_from_power_spectral_densities( + sampling_frequency=2 * maximum_frequency, + duration=duration, + start_time=self.injection_parameters["geocent_time"] - duration + 1 + ) + for ifo in interferometers: + ifo.minimum_frequency = minimum_frequency + ifo.maximum_frequency = maximum_frequency + search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + duration=duration, + sampling_frequency=2 * maximum_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + waveform_arguments=dict( + reference_frequency=self.reference_frequency, + waveform_approximant=self.waveform_approximant + ) + ) + with self.assertRaises(BilbyROQParamsRangeError): + bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=interferometers, + priors=self.priors, + waveform_generator=search_waveform_generator, + linear_matrix=basis, + quadratic_matrix=basis, + ) + + @parameterized.expand([(_path_to_basis, 7, 13), (_path_to_basis, 9, 15), (_path_to_basis, 16, 17)]) + def test_fails_with_prior_mismatch(self, basis, chirp_mass_min, chirp_mass_max): + """Test if likelihood fails as expected, when prior range is not within + the basis bounds. Basis chirp-mass range is 8Msun--14Msun.""" + self.priors["chirp_mass"].minimum = chirp_mass_min + self.priors["chirp_mass"].maximum = chirp_mass_max + interferometers = bilby.gw.detector.InterferometerList(["H1"]) + interferometers.set_strain_data_from_power_spectral_densities( + sampling_frequency=self.sampling_frequency, + duration=self.duration, + start_time=self.injection_parameters["geocent_time"] - self.duration + 1 + ) + for ifo in interferometers: + ifo.minimum_frequency = self.minimum_frequency + search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( + duration=self.duration, + sampling_frequency=self.sampling_frequency, + frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq, + waveform_arguments=dict( + reference_frequency=self.reference_frequency, + waveform_approximant=self.waveform_approximant + ) + ) + with self.assertRaises(BilbyROQParamsRangeError): + bilby.gw.likelihood.ROQGravitationalWaveTransient( + interferometers=interferometers, + priors=self.priors, + waveform_generator=search_waveform_generator, + linear_matrix=basis, + quadratic_matrix=basis, + ) + @parameterized.expand( product( [_path_to_basis, _path_to_basis_mb],