Skip to content
Snippets Groups Projects

Jitter time marginalisation

Merged Colm Talbot requested to merge jitter-time-marginalisation into master
Files
2
+ 66
33
@@ -18,8 +18,7 @@ from scipy.special import i0e
from ..core import likelihood
from ..core.utils import (
logger, UnsortedInterp2d, BilbyJsonEncoder, decode_bilby_json,
create_frequency_series, create_time_series, speed_of_light,
radius_of_earth)
create_frequency_series, speed_of_light, radius_of_earth)
from ..core.prior import Interped, Prior, Uniform
from .detector import InterferometerList
from .prior import BBHPriorDict
@@ -49,12 +48,21 @@ class GravitationalWaveTransient(likelihood.Likelihood):
distance_marginalization: bool, optional
If true, marginalize over distance in the likelihood.
This uses a look up table calculated at run time.
The distance prior is set to be a delta function at the minimum
distance allowed in the prior being marginalised over.
time_marginalization: bool, optional
If true, marginalize over time in the likelihood.
This uses a FFT.
This uses a FFT to calculate the likelihood over a regularly spaced
grid.
In order to cover the whole space the prior is set to be uniform over
the spacing of the array of times.
If using time marginalisation and jitter_time is True a "jitter"
parameter is added to the prior which modifies the position of the
grid of times.
phase_marginalization: bool, optional
If true, marginalize over phase in the likelihood.
This is done analytically using a Bessel function.
The phase prior is set to be a delta function at phase=0.
priors: dict, optional
If given, used in the distance and phase marginalization.
distance_marginalization_lookup_table: (dict, str), optional
@@ -65,6 +73,11 @@ class GravitationalWaveTransient(likelihood.Likelihood):
The lookup table is stored after construction in either the
provided string or a default location:
'.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
jitter_time: bool, optional
Whether to introduce a `time_jitter` parameter. This avoids either
missing the likelihood peak, or introducing biases in the
reconstructed time posterior due to an insufficient sampling frequency.
Default is False, however using this parameter is strongly encouraged.
Returns
-------
@@ -83,7 +96,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
def __init__(self, interferometers, waveform_generator,
time_marginalization=False, distance_marginalization=False,
phase_marginalization=False, priors=None,
distance_marginalization_lookup_table=None):
distance_marginalization_lookup_table=None,
jitter_time=True):
self.waveform_generator = waveform_generator
likelihood.Likelihood.__init__(self, dict())
@@ -93,11 +107,20 @@ class GravitationalWaveTransient(likelihood.Likelihood):
self.phase_marginalization = phase_marginalization
self.priors = priors
self._check_set_duration_and_sampling_frequency_of_waveform_generator()
self.jitter_time = jitter_time
if self.time_marginalization:
self._check_prior_is_set(key='geocent_time')
self._setup_time_marginalization()
priors['geocent_time'] = float(self.interferometers.start_time)
if self.jitter_time:
priors['time_jitter'] = Uniform(
minimum=- self._delta_tc / 2, maximum=self._delta_tc / 2)
elif self.jitter_time:
logger.info(
"Time jittering requested with non-time-marginalised "
"likelihood, ignoring.")
self.jitter_time = False
if self.phase_marginalization:
self._check_prior_is_set(key='phase')
@@ -225,6 +248,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
optimal_snr_squared = 0.
complex_matched_filter_snr = 0.
if self.time_marginalization:
if self.jitter_time:
self.parameters['geocent_time'] += self.parameters['time_jitter']
d_inner_h_tc_array = np.zeros(
self.interferometers.frequency_array[0:-1].shape,
dtype=np.complex128)
@@ -242,6 +267,12 @@ class GravitationalWaveTransient(likelihood.Likelihood):
d_inner_h_tc_array += per_detector_snr.d_inner_h_squared_tc_array
if self.time_marginalization:
if self.jitter_time:
times = self._times + self.parameters['time_jitter']
self.parameters['geocent_time'] -= self.parameters['time_jitter']
else:
times = self._times
self.time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc
log_l = self.time_marginalized_likelihood(
d_inner_h_tc_array=d_inner_h_tc_array,
h_inner_h=optimal_snr_squared)
@@ -317,46 +348,46 @@ class GravitationalWaveTransient(likelihood.Likelihood):
new_time: float
Sample from the time posterior.
"""
if self.jitter_time:
self.parameters['geocent_time'] += self.parameters['time_jitter']
if signal_polarizations is None:
signal_polarizations = \
self.waveform_generator.frequency_domain_strain(self.parameters)
n_time_steps = int(self.waveform_generator.duration * 16384)
d_inner_h = np.zeros(n_time_steps, dtype=np.complex)
psd = np.ones(n_time_steps)
signal_long = np.zeros(n_time_steps, dtype=np.complex)
data = np.zeros(n_time_steps, dtype=np.complex)
h_inner_h = np.zeros(1)
for ifo in self.interferometers:
ifo_length = len(ifo.frequency_domain_strain)
signal = ifo.get_detector_response(
signal_polarizations, self.parameters)
signal_long[:ifo_length] = signal
data[:ifo_length] = np.conj(ifo.frequency_domain_strain)
psd[:ifo_length] = ifo.power_spectral_density_array
d_inner_h += np.fft.fft(signal_long * data / psd)
h_inner_h += ifo.optimal_snr_squared(signal=signal).real
d_inner_h = 0.
h_inner_h = 0.
complex_matched_filter_snr = 0.
d_inner_h_tc_array = np.zeros(
self.interferometers.frequency_array[0:-1].shape,
dtype=np.complex128)
for interferometer in self.interferometers:
per_detector_snr = self.calculate_snrs(
signal_polarizations, interferometer)
d_inner_h += per_detector_snr.d_inner_h
h_inner_h += per_detector_snr.optimal_snr_squared
complex_matched_filter_snr += per_detector_snr.complex_matched_filter_snr
if self.time_marginalization:
d_inner_h_tc_array += per_detector_snr.d_inner_h_squared_tc_array
if self.distance_marginalization:
time_log_like = self.distance_marginalized_likelihood(
d_inner_h, h_inner_h)
elif self.phase_marginalization:
time_log_like = (self._bessel_function_interped(abs(d_inner_h)) -
h_inner_h.real / 2)
time_log_like = (
self._bessel_function_interped(abs(d_inner_h_tc_array)) -
h_inner_h.real / 2)
else:
time_log_like = (d_inner_h.real - h_inner_h.real / 2)
time_log_like = (d_inner_h_tc_array.real - h_inner_h.real / 2)
times = create_time_series(
sampling_frequency=16384,
starting_time=self.waveform_generator.start_time,
duration=self.waveform_generator.duration)
if self.jitter_time:
times = self._times + self.parameters['time_jitter']
time_prior_array = self.priors['geocent_time'].prob(times)
time_post = (
np.exp(time_log_like - max(time_log_like)) * time_prior_array)
keep = (time_post > max(time_post) / 1000)
time_post = time_post[keep]
times = times[keep]
new_time = Interped(times, time_post).sample()
return new_time
@@ -624,14 +655,14 @@ class GravitationalWaveTransient(likelihood.Likelihood):
bounds_error=False, fill_value=(0, np.nan))
def _setup_time_marginalization(self):
delta_tc = 2 / self.waveform_generator.sampling_frequency
self._delta_tc = 2 / self.waveform_generator.sampling_frequency
self._times =\
self.interferometers.start_time + np.linspace(
0, self.interferometers.duration,
int(self.interferometers.duration / 2 *
self.waveform_generator.sampling_frequency + 1))[1:]
self.time_prior_array =\
self.priors['geocent_time'].prob(self._times) * delta_tc
self.time_prior_array = \
self.priors['geocent_time'].prob(self._times) * self._delta_tc
@property
def interferometers(self):
@@ -795,7 +826,9 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
waveform_generator=waveform_generator, priors=priors,
distance_marginalization=distance_marginalization,
phase_marginalization=phase_marginalization,
distance_marginalization_lookup_table=distance_marginalization_lookup_table)
time_marginalization=False,
distance_marginalization_lookup_table=distance_marginalization_lookup_table,
jitter_time=False)
if isinstance(roq_params, np.ndarray) or roq_params is None:
self.roq_params = roq_params
Loading