Commit bc4c7077 authored by Gregory Ashton's avatar Gregory Ashton

Minor adjust to the ROQ weight generation

1) Convert the light travel time to calculated value
2) Use a single fixed time_space (rather than recalculating)
3) Add info statements
4) Generate frequencies using the seglen directly
5) Adds a check if the seglen is shorter than the duration
6) Apply a safety factor of 5 to ensure the time step is short enough
parent 7fc2c6d1
from __future__ import division
import gc
import os
import json
import copy
......@@ -17,7 +18,8 @@ from scipy.special import i0e
from ..core import likelihood
from ..core.utils import (
logger, UnsortedInterp2d, BilbyJsonEncoder, decode_bilby_json,
create_frequency_series, create_time_series)
create_frequency_series, create_time_series, speed_of_light,
radius_of_earth)
from ..core.prior import Interped, Prior, Uniform
from .detector import InterferometerList
from .prior import BBHPriorDict
......@@ -842,7 +844,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
dt = interferometer.time_delay_from_geocenter(
self.parameters['ra'], self.parameters['dec'],
interferometer.strain_data.start_time)
self.parameters['geocent_time'])
ifo_time = self.parameters['geocent_time'] + dt - \
interferometer.strain_data.start_time
......@@ -863,6 +865,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
indices, in_bounds = self._closest_time_indices(
ifo_time, self.weights['time_samples'])
if not in_bounds:
logger.debug("SNR calculation error: requested time at edge of ROQ time samples")
return self._CalculatedSNRs(
d_inner_h=np.nan_to_num(-np.inf), optimal_snr_squared=0,
complex_matched_filter_snr=np.nan_to_num(-np.inf),
......@@ -874,7 +877,7 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
d_inner_h = interp1d(
self.weights['time_samples'][indices],
d_inner_h_tc_array, kind='cubic')(ifo_time)
d_inner_h_tc_array, kind='cubic', assume_sorted=True)(ifo_time)
optimal_snr_squared = \
np.vdot(np.abs(h_plus_quadratic + h_cross_quadratic)**2,
......@@ -913,28 +916,38 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
return indices, in_bounds
def _set_weights(self, linear_matrix, quadratic_matrix):
"""
Setup the time-dependent ROQ weights.
This follows FIXME: Smith et al.
""" Setup the time-dependent ROQ weights.
Parameters
----------
linear_matrix, quadratic_matrix: array_like
Arrays of the linear and quadratic basis
The times are chosen to allow all the merger times allows in the time
prior.
"""
self.weights['time_samples'] = np.arange(
self.priors['geocent_time'].minimum - 0.045,
self.priors['geocent_time'].maximum + 0.045,
self._get_time_resolution()) - self.interferometers.start_time
# Maximum delay time to geocentre plus 10%
earth_light_crossing_time = 1.1 * radius_of_earth / speed_of_light
time_space = self._get_time_resolution()
delta_times = np.arange(
self.priors['geocent_time'].minimum - earth_light_crossing_time,
self.priors['geocent_time'].maximum + earth_light_crossing_time,
time_space)
time_samples = delta_times - self.interferometers.start_time
self.weights['time_samples'] = time_samples
logger.info("Using {} ROQ time samples".format(len(time_samples)))
for ifo in self.interferometers:
if self.roq_params is not None:
frequencies = create_frequency_series(
if ifo.maximum_frequency > self.roq_params['fhigh']:
raise ValueError("Requested maximum frequency larger than ROQ basis fhigh")
# Generate frequencies for the ROQ
roq_frequencies = create_frequency_series(
sampling_frequency=self.roq_params['fhigh'] * 2,
duration=self.roq_params['seglen'])
roq_mask = [frequencies >= self.roq_params['flow']]
frequencies = frequencies[roq_mask]
roq_mask = roq_frequencies >= self.roq_params['flow']
roq_frequencies = roq_frequencies[roq_mask]
overlap_frequencies, ifo_idxs, roq_idxs = np.intersect1d(
ifo.frequency_array[ifo.frequency_mask], frequencies,
ifo.frequency_array[ifo.frequency_mask], roq_frequencies,
return_indices=True)
else:
overlap_frequencies = ifo.frequency_array[ifo.frequency_mask]
......@@ -950,26 +963,9 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
ifo.name, len(overlap_frequencies),
min(overlap_frequencies), max(overlap_frequencies)))
# array of relative time shifts to be applied to the data
# 0.045s comes from time for GW to traverse the Earth
time_space = (self.weights['time_samples'][1] -
self.weights['time_samples'][0])
# array to be filled with data, shifted by discrete time_samples
tc_shifted_data = np.zeros([
len(self.weights['time_samples']), len(overlap_frequencies)],
dtype=complex)
# shift data to beginning of the prior increment by the time step
shifted_data =\
ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs] * \
np.exp(2j * np.pi * overlap_frequencies *
self.weights['time_samples'][0])
single_time_shift = np.exp(
2j * np.pi * overlap_frequencies * time_space)
for j in range(len(self.weights['time_samples'])):
tc_shifted_data[j] = shifted_data
shifted_data *= single_time_shift
data = ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs]
tc_shifted_data = data * np.exp(
2j * np.pi * overlap_frequencies * time_samples[:, np.newaxis])
# to not kill all computers this minimises the memory usage of the
# required inner products
......@@ -982,7 +978,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
linear_matrix[roq_idxs],
max_elements) * 4 / ifo.strain_data.duration
del tc_shifted_data
del tc_shifted_data, overlap_frequencies
gc.collect()
self.weights[ifo.name + '_quadratic'] = build_roq_weights(
1 /
......@@ -990,6 +987,8 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
quadratic_matrix[roq_idxs].real,
1 / ifo.strain_data.duration)
logger.info("Finished building weights for {}".format(ifo.name))
def save_weights(self, filename):
with open(filename, 'w') as file:
json.dump(self.weights, file, indent=2, cls=BilbyJsonEncoder)
......@@ -1055,6 +1054,10 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
delta_t = fhigh**-1
# Apply a safety factor to ensure the time step is short enough
delta_t = delta_t / 5
logger.info("ROQ time-step = {}".format(delta_t))
return delta_t
def _rescale_signal(self, signal, new_distance):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment