Skip to content
Snippets Groups Projects

make roq weight generation use intersection of frequency arrays

Merged Colm Talbot requested to merge roq-frequency-fixing into master
1 file
+ 35
17
Compare changes
  • Side-by-side
  • Inline
+ 35
17
@@ -741,12 +741,14 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
waveform_generator: `bilby.waveform_generator.WaveformGenerator`
An object which computes the frequency-domain strain of the signal,
given some set of parameters
linear_matrix: str, array
linear_matrix: str, array_like
Either a string point to the file from which to load the linear_matrix
array, or the array itself.
quadratic_matrix: str, array
Either a string point to the file from which to load the quadratic_matrix
array, or the array itself.
quadratic_matrix: str, array_like
Either a string point to the file from which to load the
quadratic_matrix array, or the array itself.
roq_params: str, array_like
Parameters describing the domain of validity of the ROQ basis.
priors: dict, bilby.prior.PriorDict
A dictionary of priors containing at least the geocent_time prior
distance_marginalization_lookup_table: (dict, str), optional
@@ -761,6 +763,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
"""
def __init__(self, interferometers, waveform_generator, priors,
weights=None, linear_matrix=None, quadratic_matrix=None,
roq_params=None,
distance_marginalization=False, phase_marginalization=False,
distance_marginalization_lookup_table=None):
GravitationalWaveTransient.__init__(
@@ -770,6 +773,12 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
phase_marginalization=phase_marginalization,
distance_marginalization_lookup_table=distance_marginalization_lookup_table)
if isinstance(roq_params, np.ndarray) or roq_params is None:
self.roq_params = roq_params
elif isinstance(roq_params, str):
self.roq_params = np.genfromtxt(roq_params, names=True)
else:
raise TypeError("roq_params should be array or str")
if isinstance(weights, dict):
self.weights = weights
elif isinstance(weights, str):
@@ -886,9 +895,18 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
self._get_time_resolution()) - self.interferometers.start_time
for ifo in self.interferometers:
# only get frequency components up to maximum_frequency
linear_matrix = linear_matrix[:, :sum(ifo.frequency_mask)]
quadratic_matrix = quadratic_matrix[:, :sum(ifo.frequency_mask)]
if self.roq_params is not None:
frequencies = np.arange(
self.roq_params['flow'],
self.roq_params['fhigh'] + 1 / self.roq_params['seglen'],
1 / self.roq_params['seglen'])
overlap_frequencies, ifo_idxs, roq_idxs = np.intersect1d(
ifo.frequency_array[ifo.frequency_mask], frequencies,
return_indices=True)
else:
overlap_frequencies = ifo.frequency_array[ifo.frequency_mask]
roq_idxs = np.arange(linear_matrix.shape[0], dtype=int)
ifo_idxs = ifo.frequency_mask
# array of relative time shifts to be applied to the data
# 0.045s comes from time for GW to traverse the Earth
@@ -897,17 +915,16 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
# array to be filled with data, shifted by discrete time_samples
tc_shifted_data = np.zeros([
len(self.weights['time_samples']), sum(ifo.frequency_mask)],
len(self.weights['time_samples']), len(overlap_frequencies)],
dtype=complex)
# shift data to beginning of the prior increment by the time step
shifted_data =\
ifo.frequency_domain_strain[ifo.frequency_mask] * \
np.exp(2j * np.pi * ifo.frequency_array[ifo.frequency_mask] *
ifo.frequency_domain_strain[ifo_idxs] * \
np.exp(2j * np.pi * overlap_frequencies *
self.weights['time_samples'][0])
single_time_shift = np.exp(
2j * np.pi * ifo.frequency_array[ifo.frequency_mask] *
time_space)
2j * np.pi * overlap_frequencies * time_space)
for j in range(len(self.weights['time_samples'])):
tc_shifted_data[j] = shifted_data
shifted_data *= single_time_shift
@@ -918,15 +935,16 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
max_elements = int((max_block_gigabytes * 2 ** 30) / 8)
self.weights[ifo.name + '_linear'] = blockwise_dot_product(
tc_shifted_data /
ifo.power_spectral_density_array[ifo.frequency_mask],
linear_matrix, max_elements) * 4 / ifo.strain_data.duration
tc_shifted_data / ifo.power_spectral_density_array[ifo_idxs],
linear_matrix[roq_idxs],
max_elements) * 4 / ifo.strain_data.duration
del tc_shifted_data
self.weights[ifo.name + '_quadratic'] = build_roq_weights(
1 / ifo.power_spectral_density_array[ifo.frequency_mask],
quadratic_matrix.real, 1 / ifo.strain_data.duration)
1 / ifo.power_spectral_density_array[ifo_idxs],
quadratic_matrix[roq_idxs].real,
1 / ifo.strain_data.duration)
def save_weights(self, filename):
with open(filename, 'w') as file:
Loading