Commit c579bfce authored by Gregory Ashton's avatar Gregory Ashton

Merge branch 'jitter-time-marginalisation' into 'master'

Jitter time marginalisation

Closes #373

See merge request lscsoft/bilby!535
parents c04a7e97 4c7a36b9
Pipeline #68524 passed with stage
in 5 minutes and 17 seconds
......@@ -18,8 +18,7 @@ from scipy.special import i0e
from ..core import likelihood
from ..core.utils import (
logger, UnsortedInterp2d, BilbyJsonEncoder, decode_bilby_json,
create_frequency_series, create_time_series, speed_of_light,
radius_of_earth)
create_frequency_series, speed_of_light, radius_of_earth)
from ..core.prior import Interped, Prior, Uniform
from .detector import InterferometerList
from .prior import BBHPriorDict
......@@ -49,12 +48,21 @@ class GravitationalWaveTransient(likelihood.Likelihood):
distance_marginalization: bool, optional
If true, marginalize over distance in the likelihood.
This uses a look up table calculated at run time.
The distance prior is set to be a delta function at the minimum
distance allowed in the prior being marginalised over.
time_marginalization: bool, optional
If true, marginalize over time in the likelihood.
This uses a FFT.
This uses a FFT to calculate the likelihood over a regularly spaced
grid.
In order to cover the whole space the prior is set to be uniform over
the spacing of the array of times.
If using time marginalisation and jitter_time is True a "jitter"
parameter is added to the prior which modifies the position of the
grid of times.
phase_marginalization: bool, optional
If true, marginalize over phase in the likelihood.
This is done analytically using a Bessel function.
The phase prior is set to be a delta function at phase=0.
priors: dict, optional
If given, used in the distance and phase marginalization.
distance_marginalization_lookup_table: (dict, str), optional
......@@ -65,6 +73,11 @@ class GravitationalWaveTransient(likelihood.Likelihood):
The lookup table is stored after construction in either the
provided string or a default location:
'.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
jitter_time: bool, optional
Whether to introduce a `time_jitter` parameter. This avoids either
missing the likelihood peak, or introducing biases in the
reconstructed time posterior due to an insufficient sampling frequency.
Default is False, however using this parameter is strongly encouraged.
Returns
-------
......@@ -83,7 +96,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
def __init__(self, interferometers, waveform_generator,
time_marginalization=False, distance_marginalization=False,
phase_marginalization=False, priors=None,
distance_marginalization_lookup_table=None):
distance_marginalization_lookup_table=None,
jitter_time=True):
self.waveform_generator = waveform_generator
likelihood.Likelihood.__init__(self, dict())
......@@ -93,11 +107,20 @@ class GravitationalWaveTransient(likelihood.Likelihood):
self.phase_marginalization = phase_marginalization
self.priors = priors
self._check_set_duration_and_sampling_frequency_of_waveform_generator()
self.jitter_time = jitter_time
if self.time_marginalization:
self._check_prior_is_set(key='geocent_time')
self._setup_time_marginalization()
priors['geocent_time'] = float(self.interferometers.start_time)
if self.jitter_time:
priors['time_jitter'] = Uniform(
minimum=- self._delta_tc / 2, maximum=self._delta_tc / 2)
elif self.jitter_time:
logger.info(
"Time jittering requested with non-time-marginalised "
"likelihood, ignoring.")
self.jitter_time = False
if self.phase_marginalization:
self._check_prior_is_set(key='phase')
......@@ -225,6 +248,8 @@ class GravitationalWaveTransient(likelihood.Likelihood):
optimal_snr_squared = 0.
complex_matched_filter_snr = 0.
if self.time_marginalization:
if self.jitter_time:
self.parameters['geocent_time'] += self.parameters['time_jitter']
d_inner_h_tc_array = np.zeros(
self.interferometers.frequency_array[0:-1].shape,
dtype=np.complex128)
......@@ -242,6 +267,12 @@ class GravitationalWaveTransient(likelihood.Likelihood):
d_inner_h_tc_array += per_detector_snr.d_inner_h_squared_tc_array
if self.time_marginalization:
if self.jitter_time:
times = self._times + self.parameters['time_jitter']
self.parameters['geocent_time'] -= self.parameters['time_jitter']
else:
times = self._times
self.time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc
log_l = self.time_marginalized_likelihood(
d_inner_h_tc_array=d_inner_h_tc_array,
h_inner_h=optimal_snr_squared)
......@@ -317,46 +348,46 @@ class GravitationalWaveTransient(likelihood.Likelihood):
new_time: float
Sample from the time posterior.
"""
if self.jitter_time:
self.parameters['geocent_time'] += self.parameters['time_jitter']
if signal_polarizations is None:
signal_polarizations = \
self.waveform_generator.frequency_domain_strain(self.parameters)
n_time_steps = int(self.waveform_generator.duration * 16384)
d_inner_h = np.zeros(n_time_steps, dtype=np.complex)
psd = np.ones(n_time_steps)
signal_long = np.zeros(n_time_steps, dtype=np.complex)
data = np.zeros(n_time_steps, dtype=np.complex)
h_inner_h = np.zeros(1)
for ifo in self.interferometers:
ifo_length = len(ifo.frequency_domain_strain)
signal = ifo.get_detector_response(
signal_polarizations, self.parameters)
signal_long[:ifo_length] = signal
data[:ifo_length] = np.conj(ifo.frequency_domain_strain)
psd[:ifo_length] = ifo.power_spectral_density_array
d_inner_h += np.fft.fft(signal_long * data / psd)
h_inner_h += ifo.optimal_snr_squared(signal=signal).real
d_inner_h = 0.
h_inner_h = 0.
complex_matched_filter_snr = 0.
d_inner_h_tc_array = np.zeros(
self.interferometers.frequency_array[0:-1].shape,
dtype=np.complex128)
for interferometer in self.interferometers:
per_detector_snr = self.calculate_snrs(
signal_polarizations, interferometer)
d_inner_h += per_detector_snr.d_inner_h
h_inner_h += per_detector_snr.optimal_snr_squared
complex_matched_filter_snr += per_detector_snr.complex_matched_filter_snr
if self.time_marginalization:
d_inner_h_tc_array += per_detector_snr.d_inner_h_squared_tc_array
if self.distance_marginalization:
time_log_like = self.distance_marginalized_likelihood(
d_inner_h, h_inner_h)
elif self.phase_marginalization:
time_log_like = (self._bessel_function_interped(abs(d_inner_h)) -
h_inner_h.real / 2)
time_log_like = (
self._bessel_function_interped(abs(d_inner_h_tc_array)) -
h_inner_h.real / 2)
else:
time_log_like = (d_inner_h.real - h_inner_h.real / 2)
time_log_like = (d_inner_h_tc_array.real - h_inner_h.real / 2)
times = create_time_series(
sampling_frequency=16384,
starting_time=self.waveform_generator.start_time,
duration=self.waveform_generator.duration)
if self.jitter_time:
times = self._times + self.parameters['time_jitter']
time_prior_array = self.priors['geocent_time'].prob(times)
time_post = (
np.exp(time_log_like - max(time_log_like)) * time_prior_array)
keep = (time_post > max(time_post) / 1000)
time_post = time_post[keep]
times = times[keep]
new_time = Interped(times, time_post).sample()
return new_time
......@@ -624,14 +655,14 @@ class GravitationalWaveTransient(likelihood.Likelihood):
bounds_error=False, fill_value=(0, np.nan))
def _setup_time_marginalization(self):
delta_tc = 2 / self.waveform_generator.sampling_frequency
self._delta_tc = 2 / self.waveform_generator.sampling_frequency
self._times =\
self.interferometers.start_time + np.linspace(
0, self.interferometers.duration,
int(self.interferometers.duration / 2 *
self.waveform_generator.sampling_frequency + 1))[1:]
self.time_prior_array =\
self.priors['geocent_time'].prob(self._times) * delta_tc
self.time_prior_array = \
self.priors['geocent_time'].prob(self._times) * self._delta_tc
@property
def interferometers(self):
......@@ -795,7 +826,9 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
waveform_generator=waveform_generator, priors=priors,
distance_marginalization=distance_marginalization,
phase_marginalization=phase_marginalization,
distance_marginalization_lookup_table=distance_marginalization_lookup_table)
time_marginalization=False,
distance_marginalization_lookup_table=distance_marginalization_lookup_table,
jitter_time=False)
if isinstance(roq_params, np.ndarray) or roq_params is None:
self.roq_params = roq_params
......
......@@ -184,11 +184,11 @@ class TestTimeMarginalization(unittest.TestCase):
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
start_time=1126259640)
self.prior = bilby.gw.prior.BBHPriorDict()
self.priors = bilby.gw.prior.BBHPriorDict()
self.likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator, priors=self.prior.copy()
waveform_generator=self.waveform_generator, priors=self.priors.copy()
)
self.likelihood.parameters = self.parameters.copy()
......@@ -199,7 +199,7 @@ class TestTimeMarginalization(unittest.TestCase):
del self.parameters
del self.interferometers
del self.waveform_generator
del self.prior
del self.priors
del self.likelihood
def test_time_marginalisation_full_segment(self):
......@@ -209,13 +209,13 @@ class TestTimeMarginalization(unittest.TestCase):
"""
likes = []
lls = []
self.prior['geocent_time'] = bilby.prior.Uniform(
self.priors['geocent_time'] = bilby.prior.Uniform(
minimum=self.waveform_generator.start_time,
maximum=self.waveform_generator.start_time + self.duration)
self.time = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
time_marginalization=True, priors=self.prior.copy()
time_marginalization=True, priors=self.priors
)
times = self.waveform_generator.start_time + np.linspace(
0, self.duration, 4097)[:-1]
......@@ -225,8 +225,9 @@ class TestTimeMarginalization(unittest.TestCase):
likes.append(np.exp(lls[-1]))
marg_like = np.log(np.trapz(
likes * self.prior['geocent_time'].prob(times), times))
likes * self.time.priors['geocent_time'].prob(times), times))
self.time.parameters = self.parameters.copy()
self.time.parameters['time_jitter'] = 0.0
self.time.parameters['geocent_time'] = self.waveform_generator.start_time
self.assertAlmostEqual(marg_like, self.time.log_likelihood_ratio(),
delta=0.5)
......@@ -238,13 +239,13 @@ class TestTimeMarginalization(unittest.TestCase):
"""
likes = []
lls = []
self.prior['geocent_time'] = bilby.prior.Uniform(
self.priors['geocent_time'] = bilby.prior.Uniform(
minimum=self.parameters['geocent_time'] + 1 - 0.1,
maximum=self.parameters['geocent_time'] + 1 + 0.1)
self.time = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
time_marginalization=True, priors=self.prior.copy()
time_marginalization=True, priors=self.priors
)
times = self.waveform_generator.start_time + np.linspace(
0, self.duration, 4097)[:-1]
......@@ -254,8 +255,9 @@ class TestTimeMarginalization(unittest.TestCase):
likes.append(np.exp(lls[-1]))
marg_like = np.log(np.trapz(
likes * self.prior['geocent_time'].prob(times), times))
likes * self.time.priors['geocent_time'].prob(times), times))
self.time.parameters = self.parameters.copy()
self.time.parameters['time_jitter'] = 0.0
self.time.parameters['geocent_time'] = self.waveform_generator.start_time
self.assertAlmostEqual(marg_like, self.time.log_likelihood_ratio(),
delta=0.5)
......@@ -410,33 +412,33 @@ class TestTimePhaseMarginalization(unittest.TestCase):
frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
start_time=1126259640)
self.prior = bilby.gw.prior.BBHPriorDict()
self.prior['geocent_time'] = bilby.prior.Uniform(
self.priors = bilby.gw.prior.BBHPriorDict()
self.priors['geocent_time'] = bilby.prior.Uniform(
minimum=self.parameters['geocent_time'] - self.duration / 2,
maximum=self.parameters['geocent_time'] + self.duration / 2)
self.likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator, priors=self.prior.copy()
waveform_generator=self.waveform_generator, priors=self.priors.copy()
)
self.time = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
time_marginalization=True, priors=self.prior.copy()
time_marginalization=True, priors=self.priors.copy()
)
self.phase = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
phase_marginalization=True, priors=self.prior.copy()
phase_marginalization=True, priors=self.priors.copy()
)
self.time_phase = bilby.gw.likelihood.GravitationalWaveTransient(
interferometers=self.interferometers,
waveform_generator=self.waveform_generator,
time_marginalization=True, phase_marginalization=True,
priors=self.prior.copy()
priors=self.priors
)
for like in [self.likelihood, self.time, self.phase, self.time_phase]:
like.parameters = self.parameters.copy()
......@@ -447,17 +449,21 @@ class TestTimePhaseMarginalization(unittest.TestCase):
del self.parameters
del self.interferometers
del self.waveform_generator
del self.prior
del self.priors
del self.likelihood
del self.time
del self.phase
del self.time_phase
def test_time_phase_marginalisation(self):
"""Test time and marginalised likelihood matches brute force version"""
def test_time_marginalisation(self):
"""
Test time marginalised likelihood matches brute force version when
also marginalising over phase.
"""
like = []
times = np.linspace(self.prior['geocent_time'].minimum,
self.prior['geocent_time'].maximum, 4097)[:-1]
times = np.linspace(
self.time_phase.priors['geocent_time'].minimum,
self.time_phase.priors['geocent_time'].maximum, 4097)[:-1]
for time in times:
self.phase.parameters['geocent_time'] = time
like.append(np.exp(self.phase.log_likelihood_ratio()))
......@@ -465,18 +471,26 @@ class TestTimePhaseMarginalization(unittest.TestCase):
marg_like = np.log(np.trapz(like, times)
/ self.waveform_generator.duration)
self.time_phase.parameters = self.parameters.copy()
self.time_phase.parameters['time_jitter'] = 0.0
self.assertAlmostEqual(marg_like,
self.time_phase.log_likelihood_ratio(),
delta=0.5)
def test_phase_marginalisation(self):
"""
Test phase marginalised likelihood matches brute force version when
also marginalising over time.
"""
like = []
phases = np.linspace(0, 2 * np.pi, 1000)
for phase in phases:
self.time.parameters['phase'] = phase
self.time.parameters['time_jitter'] = 0.0
like.append(np.exp(self.time.log_likelihood_ratio()))
marg_like = np.log(np.trapz(like, phases) / (2 * np.pi))
self.time_phase.parameters = self.parameters.copy()
self.time_phase.parameters['time_jitter'] = 0.0
self.assertAlmostEqual(marg_like,
self.time_phase.log_likelihood_ratio(),
delta=0.5)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment