Skip to content
Snippets Groups Projects

Jitter time marginalisation

Merged Colm Talbot requested to merge jitter-time-marginalisation into master
+ 30
10
@@ -56,8 +56,9 @@ class GravitationalWaveTransient(likelihood.Likelihood):
grid.
In order to cover the whole space the prior is set to be uniform over
the spacing of the array of times.
If using time marginalisation a "jitter" parameter is added to the
prior which modifies the position of the grid of times.
If using time marginalisation and jitter_time is True a "jitter"
parameter is added to the prior which modifies the position of the
grid of times.
phase_marginalization: bool, optional
If true, marginalize over phase in the likelihood.
This is done analytically using a Bessel function.
@@ -72,6 +73,11 @@ class GravitationalWaveTransient(likelihood.Likelihood):
The lookup table is stored after construction in either the
provided string or a default location:
'.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
jitter_time: bool, optional
Whether to introduce a `time_jitter` parameter. This avoids either
missing the likelihood peak, or introducing biases in the
reconstructed time posterior due to an insufficient sampling frequency.
Default is False, however using this parameter is strongly encouraged.
Returns
-------
@@ -90,7 +96,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
def __init__(self, interferometers, waveform_generator,
time_marginalization=False, distance_marginalization=False,
phase_marginalization=False, priors=None,
distance_marginalization_lookup_table=None):
distance_marginalization_lookup_table=None,
jitter_time=False):
self.waveform_generator = waveform_generator
likelihood.Likelihood.__init__(self, dict())
@@ -100,13 +107,20 @@ class GravitationalWaveTransient(likelihood.Likelihood):
self.phase_marginalization = phase_marginalization
self.priors = priors
self._check_set_duration_and_sampling_frequency_of_waveform_generator()
self.jitter_time = jitter_time
if self.time_marginalization:
self._check_prior_is_set(key='geocent_time')
self._setup_time_marginalization()
priors['geocent_time'] = float(self.interferometers.start_time)
priors['time_jitter'] = Uniform(
minimum=- self._delta_tc / 2, maximum=self._delta_tc / 2)
if self.jitter_time:
priors['time_jitter'] = Uniform(
minimum=- self._delta_tc / 2, maximum=self._delta_tc / 2)
elif self.jitter_time:
logger.info(
"Time jittering requested with non-time-marginalised "
"likelihood, ignoring.")
self.jitter_time = False
if self.phase_marginalization:
self._check_prior_is_set(key='phase')
@@ -234,7 +248,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
optimal_snr_squared = 0.
complex_matched_filter_snr = 0.
if self.time_marginalization:
self.parameters['geocent_time'] += self.parameters['time_jitter']
if self.jitter_time:
self.parameters['geocent_time'] += self.parameters['time_jitter']
d_inner_h_tc_array = np.zeros(
self.interferometers.frequency_array[0:-1].shape,
dtype=np.complex128)
@@ -252,8 +267,11 @@ class GravitationalWaveTransient(likelihood.Likelihood):
d_inner_h_tc_array += per_detector_snr.d_inner_h_squared_tc_array
if self.time_marginalization:
times = self._times + self.parameters['time_jitter']
self.parameters['geocent_time'] -= self.parameters['time_jitter']
if self.jitter_time:
times = self._times + self.parameters['time_jitter']
self.parameters['geocent_time'] -= self.parameters['time_jitter']
else:
times = self._times
self.time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc
log_l = self.time_marginalized_likelihood(
d_inner_h_tc_array=d_inner_h_tc_array,
@@ -330,7 +348,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
new_time: float
Sample from the time posterior.
"""
self.parameters['geocent_time'] += self.parameters['time_jitter']
if self.jitter_time:
self.parameters['geocent_time'] += self.parameters['time_jitter']
if signal_polarizations is None:
signal_polarizations = \
self.waveform_generator.frequency_domain_strain(self.parameters)
@@ -362,7 +381,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
else:
time_log_like = (d_inner_h_tc_array.real - h_inner_h.real / 2)
times = self._times + self.parameters['time_jitter']
if self.jitter_time:
times = self._times + self.parameters['time_jitter']
time_prior_array = self.priors['geocent_time'].prob(times)
time_post = (
Loading