Commit 35071349 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'fast_roq_weight' into 'master'

bilby/gw/likelihood.py: speed up ROQ weight calculation with IFFT

See merge request lscsoft/bilby!903
parents 70cec967 a8dee319
import gc
import os
import json
import copy
......@@ -19,8 +18,7 @@ from .detector import InterferometerList, get_empty_interferometer, calibration
from .prior import BBHPriorDict, CBCPriorDict, Cosmological
from .source import lal_binary_black_hole
from .utils import (
noise_weighted_inner_product, build_roq_weights, blockwise_dot_product,
zenith_azimuth_to_ra_dec)
noise_weighted_inner_product, build_roq_weights, zenith_azimuth_to_ra_dec)
from .waveform_generator import WaveformGenerator
from collections import namedtuple
......@@ -1459,7 +1457,9 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
roq_minimum_component_mass))
def _set_weights(self, linear_matrix, quadratic_matrix):
""" Setup the time-dependent ROQ weights.
"""
Setup the time-dependent ROQ weights.
See https://dcc.ligo.org/LIGO-T2100125 for the detail of how to compute them.
Parameters
==========
......@@ -1469,15 +1469,26 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
"""
time_space = self._get_time_resolution()
number_of_time_samples = int(self.interferometers.duration / time_space)
try:
import pyfftw
ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex)
ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex)
ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD')
except ImportError:
pyfftw = None
logger.warning("You do not have pyfftw installed, falling back to numpy.fft.")
ifft_input = np.zeros(number_of_time_samples, dtype=complex)
ifft = np.fft.ifft
# Maximum delay time to geocentre + 5 steps
earth_light_crossing_time = radius_of_earth / speed_of_light + 5 * time_space
delta_times = np.arange(
self.priors['{}_time'.format(self.time_reference)].minimum - earth_light_crossing_time,
self.priors['{}_time'.format(self.time_reference)].maximum + earth_light_crossing_time,
time_space)
time_samples = delta_times - self.interferometers.start_time
self.weights['time_samples'] = time_samples
logger.info("Using {} ROQ time samples".format(len(time_samples)))
start_idx = max(0, np.int(np.floor((self.priors['{}_time'.format(self.time_reference)].minimum -
earth_light_crossing_time - self.interferometers.start_time) / time_space)))
end_idx = min(number_of_time_samples - 1, np.int(np.ceil((
self.priors['{}_time'.format(self.time_reference)].maximum + earth_light_crossing_time -
self.interferometers.start_time) / time_space)))
self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * time_space
logger.info("Using {} ROQ time samples".format(len(self.weights['time_samples'])))
for ifo in self.interferometers:
if self.roq_params is not None:
......@@ -1498,8 +1509,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
else:
overlap_frequencies = ifo.frequency_array[ifo.frequency_mask]
roq_idxs = np.arange(linear_matrix.shape[0], dtype=int)
ifo_idxs = [True] * sum(ifo.frequency_mask)
if sum(ifo_idxs) != len(roq_idxs):
ifo_idxs = np.arange(sum(ifo.frequency_mask))
if len(ifo_idxs) != len(roq_idxs):
raise ValueError(
"Mismatch between ROQ basis and frequency array for "
"{}".format(ifo.name))
......@@ -1509,34 +1520,16 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
ifo.name, len(overlap_frequencies),
min(overlap_frequencies), max(overlap_frequencies)))
logger.debug("Preallocate array")
tc_shifted_data = np.zeros((
len(self.weights['time_samples']), len(overlap_frequencies)),
dtype=complex)
logger.debug("Calculate shifted data")
data = ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs]
prefactor = (
data /
ifft_input[:] *= 0.
self.weights[ifo.name + '_linear'] = \
np.zeros((len(self.weights['time_samples']), linear_matrix.shape[1]), dtype=complex)
data_over_psd = ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs] / \
ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs]
)
for j in range(len(self.weights['time_samples'])):
tc_shifted_data[j] = prefactor * np.exp(
2j * np.pi * overlap_frequencies * time_samples[j])
# to not kill all computers this minimises the memory usage of the
# required inner products
max_block_gigabytes = 4
max_elements = int((max_block_gigabytes * 2 ** 30) / 8)
logger.debug("Apply dot product")
self.weights[ifo.name + '_linear'] = blockwise_dot_product(
tc_shifted_data,
linear_matrix[roq_idxs],
max_elements) * 4 / ifo.strain_data.duration
del tc_shifted_data, overlap_frequencies
gc.collect()
nonzero_idxs = ifo_idxs + int(ifo.frequency_array[ifo.frequency_mask][0] * self.interferometers.duration)
for i, basis_element in enumerate(linear_matrix[roq_idxs].T):
ifft_input[nonzero_idxs] = data_over_psd * np.conj(basis_element)
self.weights[ifo.name + '_linear'][:, i] = ifft(ifft_input)[start_idx:end_idx + 1]
self.weights[ifo.name + '_linear'] *= 4. * number_of_time_samples / self.interferometers.duration
self.weights[ifo.name + '_quadratic'] = build_roq_weights(
1 /
......@@ -1546,6 +1539,9 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
logger.info("Finished building weights for {}".format(ifo.name))
if pyfftw is not None:
pyfftw.forget_wisdom()
def save_weights(self, filename, format='npz'):
if format not in filename:
filename += "." + format
......@@ -1630,6 +1626,12 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
# Apply a safety factor to ensure the time step is short enough
delta_t = delta_t / 5
# duration / delta_t needs to be a power of 2 for IFFT
number_of_time_samples = max(
self.interferometers.duration / delta_t,
self.interferometers.frequency_array[-1] * self.interferometers.duration + 1)
number_of_time_samples = np.int(2**np.ceil(np.log2(number_of_time_samples)))
delta_t = self.interferometers.duration / number_of_time_samples
logger.info("ROQ time-step = {}".format(delta_t))
return delta_t
......
......@@ -704,74 +704,6 @@ def build_roq_weights(data, basis, deltaF):
return weights
def blockwise_dot_product(matrix_a, matrix_b, max_elements=int(2 ** 27),
out=None):
"""
Memory efficient
Computes the dot product of two matrices in a block-wise fashion.
Only blocks of `matrix_a` with a maximum size of `max_elements` will be
processed simultaneously.
Parameters
==========
matrix_a, matrix_b: array-like
Matrices to be dot producted, matrix_b is complex conjugated.
max_elements: int
Maximum number of elements to consider simultaneously, should be memory
limited.
out: array-like
Output array
Returns
=======
out: array-like
Dot producted array
"""
def block_slices(dim_size, block_size):
"""Generator that yields slice objects for indexing into
sequential blocks of an array along a particular axis
Useful for blockwise dot
"""
count = 0
while True:
yield slice(count, count + block_size, 1)
count += block_size
if count > dim_size:
return
matrix_b = np.conjugate(matrix_b)
m, n = matrix_a.shape
n1, o = matrix_b.shape
if n1 != n:
raise ValueError(
'Matrices are not aligned, matrix a has shape ' +
'{}, matrix b has shape {}.'.format(matrix_a.shape, matrix_b.shape))
if matrix_a.flags.f_contiguous:
# prioritize processing as many columns of matrix_a as possible
max_cols = max(1, max_elements // m)
max_rows = max_elements // max_cols
else:
# prioritize processing as many rows of matrix_a as possible
max_rows = max(1, max_elements // n)
max_cols = max_elements // max_rows
if out is None:
out = np.empty((m, o), dtype=np.result_type(matrix_a, matrix_b))
elif out.shape != (m, o):
raise ValueError('Output array has incorrect dimensions.')
for mm in block_slices(m, max_rows):
out[mm, :] = 0
for nn in block_slices(n, max_cols):
a_block = matrix_a[mm, nn].copy() # copy to force a read
out[mm, :] += np.dot(a_block, matrix_b[nn, :])
del a_block
return out
def convert_args_list_to_float(*args_list):
""" Converts inputs to floats, returns a list in the same order as the input"""
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment