Skip to content
Snippets Groups Projects
Commit 05b1bc2c authored by Colm Talbot's avatar Colm Talbot Committed by Gregory Ashton
Browse files

make roq weight generation use intersection of frequency arrays

parent 9abefc5a
No related branches found
No related tags found
No related merge requests found
......@@ -747,12 +747,14 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
waveform_generator: `bilby.waveform_generator.WaveformGenerator`
An object which computes the frequency-domain strain of the signal,
given some set of parameters
linear_matrix: str, array
linear_matrix: str, array_like
Either a string point to the file from which to load the linear_matrix
array, or the array itself.
quadratic_matrix: str, array
Either a string point to the file from which to load the quadratic_matrix
array, or the array itself.
quadratic_matrix: str, array_like
Either a string point to the file from which to load the
quadratic_matrix array, or the array itself.
roq_params: str, array_like
Parameters describing the domain of validity of the ROQ basis.
priors: dict, bilby.prior.PriorDict
A dictionary of priors containing at least the geocent_time prior
distance_marginalization_lookup_table: (dict, str), optional
......@@ -767,6 +769,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
"""
def __init__(self, interferometers, waveform_generator, priors,
weights=None, linear_matrix=None, quadratic_matrix=None,
roq_params=None,
distance_marginalization=False, phase_marginalization=False,
distance_marginalization_lookup_table=None):
GravitationalWaveTransient.__init__(
......@@ -776,6 +779,12 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
phase_marginalization=phase_marginalization,
distance_marginalization_lookup_table=distance_marginalization_lookup_table)
if isinstance(roq_params, np.ndarray) or roq_params is None:
self.roq_params = roq_params
elif isinstance(roq_params, str):
self.roq_params = np.genfromtxt(roq_params, names=True)
else:
raise TypeError("roq_params should be array or str")
if isinstance(weights, dict):
self.weights = weights
elif isinstance(weights, str):
......@@ -902,9 +911,18 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
self._get_time_resolution()) - self.interferometers.start_time
for ifo in self.interferometers:
# only get frequency components up to maximum_frequency
linear_matrix = linear_matrix[:, :sum(ifo.frequency_mask)]
quadratic_matrix = quadratic_matrix[:, :sum(ifo.frequency_mask)]
if self.roq_params is not None:
frequencies = np.arange(
self.roq_params['flow'],
self.roq_params['fhigh'] + 1 / self.roq_params['seglen'],
1 / self.roq_params['seglen'])
overlap_frequencies, ifo_idxs, roq_idxs = np.intersect1d(
ifo.frequency_array[ifo.frequency_mask], frequencies,
return_indices=True)
else:
overlap_frequencies = ifo.frequency_array[ifo.frequency_mask]
roq_idxs = np.arange(linear_matrix.shape[0], dtype=int)
ifo_idxs = ifo.frequency_mask
# array of relative time shifts to be applied to the data
# 0.045s comes from time for GW to traverse the Earth
......@@ -913,17 +931,16 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
# array to be filled with data, shifted by discrete time_samples
tc_shifted_data = np.zeros([
len(self.weights['time_samples']), sum(ifo.frequency_mask)],
len(self.weights['time_samples']), len(overlap_frequencies)],
dtype=complex)
# shift data to beginning of the prior increment by the time step
shifted_data =\
ifo.frequency_domain_strain[ifo.frequency_mask] * \
np.exp(2j * np.pi * ifo.frequency_array[ifo.frequency_mask] *
ifo.frequency_domain_strain[ifo_idxs] * \
np.exp(2j * np.pi * overlap_frequencies *
self.weights['time_samples'][0])
single_time_shift = np.exp(
2j * np.pi * ifo.frequency_array[ifo.frequency_mask] *
time_space)
2j * np.pi * overlap_frequencies * time_space)
for j in range(len(self.weights['time_samples'])):
tc_shifted_data[j] = shifted_data
shifted_data *= single_time_shift
......@@ -934,15 +951,16 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
max_elements = int((max_block_gigabytes * 2 ** 30) / 8)
self.weights[ifo.name + '_linear'] = blockwise_dot_product(
tc_shifted_data /
ifo.power_spectral_density_array[ifo.frequency_mask],
linear_matrix, max_elements) * 4 / ifo.strain_data.duration
tc_shifted_data / ifo.power_spectral_density_array[ifo_idxs],
linear_matrix[roq_idxs],
max_elements) * 4 / ifo.strain_data.duration
del tc_shifted_data
self.weights[ifo.name + '_quadratic'] = build_roq_weights(
1 / ifo.power_spectral_density_array[ifo.frequency_mask],
quadratic_matrix.real, 1 / ifo.strain_data.duration)
1 / ifo.power_spectral_density_array[ifo_idxs],
quadratic_matrix[roq_idxs].real,
1 / ifo.strain_data.duration)
def save_weights(self, filename):
with open(filename, 'w') as file:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment