Skip to content
Snippets Groups Projects
Commit a90aaaa2 authored by Soichiro Morisaki's avatar Soichiro Morisaki Committed by Colm Talbot
Browse files

Add informative error messages in ROQ likelihood

parent 379a6ed6
No related branches found
No related tags found
No related merge requests found
......@@ -136,13 +136,28 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
linear_matrix['duration_s'][()])],
dtype=[('flow', float), ('fhigh', float), ('seglen', float)]
)
elif is_hdf5_quadratic:
self.roq_params = np.array(
[(quadratic_matrix['minimum_frequency_hz'][()],
quadratic_matrix['maximum_frequency_hz'][()],
quadratic_matrix['duration_s'][()])],
dtype=[('flow', float), ('fhigh', float), ('seglen', float)]
)
if is_hdf5_quadratic:
if self.roq_params is None:
self.roq_params = np.array(
[(quadratic_matrix['minimum_frequency_hz'][()],
quadratic_matrix['maximum_frequency_hz'][()],
quadratic_matrix['duration_s'][()])],
dtype=[('flow', float), ('fhigh', float), ('seglen', float)]
)
else:
self.roq_params['flow'] = max(
self.roq_params['flow'], quadratic_matrix['minimum_frequency_hz'][()]
)
self.roq_params['fhigh'] = min(
self.roq_params['fhigh'], quadratic_matrix['maximum_frequency_hz'][()]
)
self.roq_params['seglen'] = min(
self.roq_params['seglen'], quadratic_matrix['duration_s'][()]
)
if self.roq_params is not None:
for ifo in self.interferometers:
self.perform_roq_params_check(ifo)
self.weights = dict()
self._set_weights(linear_matrix=linear_matrix, quadratic_matrix=quadratic_matrix)
if is_hdf5_linear:
......@@ -158,9 +173,10 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
for basis_type in ['linear', 'quadratic']:
number_of_bases = getattr(self, f'number_of_bases_{basis_type}')
if number_of_bases > 1:
self._verify_prior_ranges_and_frequency_nodes(basis_type)
self._verify_numbers_of_prior_ranges_and_frequency_nodes(basis_type)
else:
self._check_frequency_nodes_exist_for_single_basis(basis_type)
self._verify_prior_ranges(basis_type)
self._set_unique_frequency_nodes_and_inverse()
# need to fill waveform_arguments here if single basis is used, as they will never be updated.
......@@ -171,7 +187,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices
self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices
def _verify_prior_ranges_and_frequency_nodes(self, basis_type):
def _verify_numbers_of_prior_ranges_and_frequency_nodes(self, basis_type):
"""
Check if self.weights contains lists of prior ranges and frequency nodes, and their sizes are equal to the
number of bases.
......@@ -205,6 +221,35 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
raise ValueError(
f'The number of arrays of frequency nodes does not match the number of {basis_type} bases')
def _verify_prior_ranges(self, basis_type):
"""Check if the union of prior ranges is within the ROQ basis bounds.
Parameters
==========
basis_type: str
"""
key = f'prior_range_{basis_type}'
if key not in self.weights:
return
prior_ranges = self.weights[key]
for param_name, prior_ranges_of_this_param in prior_ranges.items():
prior_minimum = self.priors[param_name].minimum
basis_minimum = np.min(prior_ranges_of_this_param[:, 0])
if prior_minimum < basis_minimum:
raise BilbyROQParamsRangeError(
f"Prior minimum of {param_name} {prior_minimum} less "
f"than ROQ basis bound {basis_minimum}"
)
prior_maximum = self.priors[param_name].maximum
basis_maximum = np.max(prior_ranges_of_this_param[:, 1])
if prior_maximum > basis_maximum:
raise BilbyROQParamsRangeError(
f"Prior maximum of {param_name} {prior_maximum} greater "
f"than ROQ basis bound {basis_maximum}"
)
def _check_frequency_nodes_exist_for_single_basis(self, basis_type):
"""
For a single-basis case, frequency nodes should be contained in self._waveform_generator.waveform_arguments or
......@@ -701,6 +746,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
roq_scale_factor = 1.
prior_ranges[param_name] = matrix[key][param_name][()] * roq_scale_factor
selected_idxs, selected_prior_ranges = self._select_prior_ranges(prior_ranges)
if len(selected_idxs) == 0:
raise BilbyROQParamsRangeError(f"There are no {basis_type} ROQ bases within the prior range.")
self.weights[key] = selected_prior_ranges
idxs_in_prior_range[basis_type] = selected_idxs
else:
......@@ -725,7 +772,6 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
ifo_idxs = {}
for ifo in self.interferometers:
if self.roq_params is not None:
self.perform_roq_params_check(ifo)
# Get scaled ROQ quantities
roq_scaled_minimum_frequency = self.roq_params['flow'] * self.roq_scale_factor
roq_scaled_maximum_frequency = self.roq_params['fhigh'] * self.roq_scale_factor
......
......@@ -656,6 +656,80 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
self.injection_parameters["geocent_time"] + 0.1
)
@parameterized.expand(
[(_path_to_basis, 20., 2048., 16),
(_path_to_basis, 10., 1024., 16),
(_path_to_basis, 20., 1024., 32),
(_path_to_basis_mb, 20., 2048., 16)]
)
def test_fails_with_frequency_duration_mismatch(
self, basis, minimum_frequency, maximum_frequency, duration
):
"""Test if likelihood fails as expected, when data frequency range is
not within the basis range or data duration does not match the basis
duration. The basis frequency range and duration are 20--1024Hz and
16s"""
self.priors["chirp_mass"].minimum = 8
self.priors["chirp_mass"].maximum = 9
interferometers = bilby.gw.detector.InterferometerList(["H1"])
interferometers.set_strain_data_from_power_spectral_densities(
sampling_frequency=2 * maximum_frequency,
duration=duration,
start_time=self.injection_parameters["geocent_time"] - duration + 1
)
for ifo in interferometers:
ifo.minimum_frequency = minimum_frequency
ifo.maximum_frequency = maximum_frequency
search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
duration=duration,
sampling_frequency=2 * maximum_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
waveform_arguments=dict(
reference_frequency=self.reference_frequency,
waveform_approximant=self.waveform_approximant
)
)
with self.assertRaises(BilbyROQParamsRangeError):
bilby.gw.likelihood.ROQGravitationalWaveTransient(
interferometers=interferometers,
priors=self.priors,
waveform_generator=search_waveform_generator,
linear_matrix=basis,
quadratic_matrix=basis,
)
@parameterized.expand([(_path_to_basis, 7, 13), (_path_to_basis, 9, 15), (_path_to_basis, 16, 17)])
def test_fails_with_prior_mismatch(self, basis, chirp_mass_min, chirp_mass_max):
"""Test if likelihood fails as expected, when prior range is not within
the basis bounds. Basis chirp-mass range is 8Msun--14Msun."""
self.priors["chirp_mass"].minimum = chirp_mass_min
self.priors["chirp_mass"].maximum = chirp_mass_max
interferometers = bilby.gw.detector.InterferometerList(["H1"])
interferometers.set_strain_data_from_power_spectral_densities(
sampling_frequency=self.sampling_frequency,
duration=self.duration,
start_time=self.injection_parameters["geocent_time"] - self.duration + 1
)
for ifo in interferometers:
ifo.minimum_frequency = self.minimum_frequency
search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
duration=self.duration,
sampling_frequency=self.sampling_frequency,
frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
waveform_arguments=dict(
reference_frequency=self.reference_frequency,
waveform_approximant=self.waveform_approximant
)
)
with self.assertRaises(BilbyROQParamsRangeError):
bilby.gw.likelihood.ROQGravitationalWaveTransient(
interferometers=interferometers,
priors=self.priors,
waveform_generator=search_waveform_generator,
linear_matrix=basis,
quadratic_matrix=basis,
)
@parameterized.expand(
product(
[_path_to_basis, _path_to_basis_mb],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment