Commit 787ec872 authored by RUdall's avatar RUdall
Browse files

Some updates, maybe closing in on joint likelihood (without distance...

Some updates, maybe closing in on joint likelihood (without distance marginalization; will test next
parent 69c41de1
Pipeline #478141 failed with stages
in 50 minutes and 50 seconds
......@@ -313,7 +313,9 @@ class Interferometer(object):
for mode in waveform_polarizations.keys():
if "glitch" in mode:
# If it's a glitch mode, check if it is in this detector or not
if mode == f"glitch-{self.name}":
# In the joint inference likelihood this should be handled already
# But just in case
if mode.split("_")[1] == self.name:
signal[mode] = waveform_polarizations[mode]
else:
signal[mode] = np.zeros(waveform_polarizations[mode].shape)
......@@ -617,7 +619,7 @@ class Interferometer(object):
Returns
=======
float: The noise weighted inner product of the two templates
float: The noise weighted inner product of the two templates
"""
return gwutils.noise_weighted_inner_product(
aa=signal_1[self.strain_data.frequency_mask],
......
......@@ -7,6 +7,7 @@ from .waveform_generator import WaveformGenerator
logger = logging.getLogger(__name__)
def slow_scattering(f,
geocent_time=0,
f_harm0=30,
......@@ -108,7 +109,8 @@ def slow_scattering(f,
Whether all higher arches are centered at the same time - if True will compute more efficiently
time_array :
Precomputed time array corresponding to the data characteristics, passed for efficiency
glitch_mode_name : str
The name to assign the glitch mode - should take the form {ifo}-{nth glitch}-glitch
Returns
------------
strain_dict : dict
......@@ -212,22 +214,52 @@ def slow_scattering(f,
htilde_shifted = htilde_temp[0] * np.exp(-2 * np.pi * 1j * htilde_temp[1] * (time_to_rotate_base))
htilde = htilde_shifted
return {'glitch': htilde, 'plus': htilde * 0, 'cross': htilde * 0}
return {kwargs["glitch_mode_name"]: htilde, 'plus': htilde * 0, 'cross': htilde * 0}
_model_map = dict(slow_scattering=slow_scattering)
_model_map=dict(slow_scattering=slow_scattering)
def glitch_waveform_generator_factory(model_name, duration, sampling_frequency, start_time, model_kwargs=dict()):
def glitch_waveform_generator_factory(model_name, duration, sampling_frequency, start_time,
glitch_name, interferometer, model_kwargs=dict()):
""""""
model_kwargs["duration"] = duration
model_kwargs["sampling_rate"] = sampling_frequency
model_kwargs["glitch_mode_name"] = f"glitch_{interferometer}_{glitch_name}"
waveform_generator = WaveformGenerator(
duration=duration,
sampling_frequency=sampling_frequency,
frequency_domain_source_model=_model_map[model_name],
waveform_arguments=model_kwargs,
parameter_conversion=None,
start_time=start_time,
)
logger.info(waveform_generator)
duration=duration,
sampling_frequency=sampling_frequency,
frequency_domain_source_model=_model_map[model_name],
waveform_arguments=model_kwargs,
parameter_conversion=None,
start_time=start_time,
)
waveform_generator.waveform_arguments["time_array"] = waveform_generator.time_array
logger.info(waveform_generator)
pass
\ No newline at end of file
return waveform_generator
# def joint_waveform_glitch_factory(
# glitches, cbc_waveform_frequency_domain_source_model, duration, sampling_frequency,
# start_time, cbc_parameter_conversion=None, glitch_waveform_arguments=dict(), cbc_waveform_arguments=dict(),):
# """
# Parameters
# ==========
# sampling_frequency: float, optional
# The sampling frequency
# duration: float, optional
# Time duration of data
# start_time: float, optional
# Starting time of the time array
# cbc_waveform_frequency_domain_source_model: func, optional
# A python function taking some arguments and returning the frequency
# domain strain of the CBC signal. Note the first argument must be the frequencies at
# which to compute the strain
# parameter_conversion: func, optional
# Function to convert from sampled parameters to parameters of the
# waveform generator. Default value is the identity, i.e. it leaves
# the parameters unaffected.
# cbc_waveform_arguments: dict, optional
# A dictionary of fixed keyword arguments to pass to
# `cbc_waveform_frequency_domain_source_model`.
# """
import os
import copy
from time import time
from weakref import ref
import attr
from typing import Dict, List
import numpy as np
import pandas as pd
from scipy.special import logsumexp
......@@ -16,8 +13,6 @@ from ...core.prior import Interped, Prior, Uniform, PriorDict, DeltaFunction
from ..detector import InterferometerList, get_empty_interferometer, calibration
from ..prior import BBHPriorDict, Cosmological
from ..utils import noise_weighted_inner_product, zenith_azimuth_to_ra_dec, ln_i0
from bilby.bilby.core import prior
from bilby.gw.waveform_generator import WaveformGenerator
class GravitationalWaveTransient(Likelihood):
......@@ -238,10 +233,9 @@ class GravitationalWaveTransient(Likelihood):
self.priors)
def _check_set_duration_and_sampling_frequency_of_waveform_generator(self):
""" Check the waveform_generator has the same duration and
sampling_frequency as the interferometers. If they are unset, then
set them, if they differ, raise an error
"""
"""Check the waveform_generator has the same duration and
sampling_frequency as the interferometers. If they are unset, then
set them, if they differ, raise an error"""
attributes = ['duration', 'sampling_frequency', 'start_time']
for attribute in attributes:
......@@ -1179,31 +1173,38 @@ class GravitationalWaveTransient(Likelihood):
lal_version=self.lal_version,
lalsimulation_version=self.lalsimulation_version)
class JointGlitchGravitationalWaveTransient(GravitationalWaveTransient):
"""A likelihood class for joint inference of glitches and gravitational waves.
"""A likelihood class for joint inference of glitches and gravitational waves.
Parameters
==========
See superclass `GravitationalWaveTransient`
glitch_generators : Dict[str:List[`bilby.gw.waveform_generator.WaveformGenerator`]]
A dict associating each IFOs to glitches which occurred within them in the analysis segment.
cbc_waveform_generator : `bilby.gw.waveform_generator.WaveformGenerator`
Equivalent to the 'waveform_generator' in `
glitch_generators : Dict[str, `bilby.gw.waveform_generator.WaveformGenerator`]
A dict mapping names of glitches to the waveform generators which generate them.
Priors passed to these should have parameter names formatted as
"glitch_{glitch_name}_{param_name}".
The mode output should take the format
"glitch_{IFO}_{glitch_name}", where here the {glitch_name} element is simply
for bookkeeping.
"""
def __init__(self, interferometers, waveform_generator, glitch_generators : Dict[str:List[WaveformGenerator]], time_marginalization=False,
distance_marginalization=False, phase_marginalization=False, calibration_marginalization=False, priors=None,
distance_marginalization_lookup_table=None, calibration_lookup_table=None,
number_of_response_curves=1000, starting_index=0, jitter_time=True, reference_frame="sky",
time_reference="geocenter"):
def __init__(self, interferometers, cbc_waveform_generator, glitch_generators, time_marginalization=False,
distance_marginalization=False, phase_marginalization=False,
calibration_marginalization=False, priors=None,
distance_marginalization_lookup_table=None, calibration_lookup_table=None,
number_of_response_curves=1000, starting_index=0, jitter_time=True, reference_frame="sky",
time_reference="geocenter"):
# For now, simplify by only applying distance marginalization
#TODO if possible apply others (should be for e.g. time, shouldn't be for e.g. phase)
# TODO if possible apply others (should be for e.g. time, shouldn't be for e.g. phase)
assert not time_marginalization, "Time marginalization is not presently supported."
assert not phase_marginalization, "Phase marginalization is not presently supported."
assert not calibration_marginalization, "Calibration marginalization is not presently supported."
super(GravitationalWaveTransient, self).__init__(
interferometers,
waveform_and_glitch_generator,
cbc_waveform_generator,
time_marginalization=False,
distance_marginalization=distance_marginalization,
phase_marginalization=False,
......@@ -1217,6 +1218,7 @@ class JointGlitchGravitationalWaveTransient(GravitationalWaveTransient):
reference_frame=reference_frame,
time_reference=time_reference,
)
self.glitch_generators = glitch_generators
@attr.s
class _CalculatedSNRs:
......@@ -1230,30 +1232,42 @@ class JointGlitchGravitationalWaveTransient(GravitationalWaveTransient):
def __repr__(self):
return self.__class__.__name__ + '(interferometers={},\n\twaveform_generator={},\n\t ' \
'glitch_generator={},\n\tdistance_marginalization={}, priors={})' \
.format(self.interferometers, self.waveform_generator, self.glitch_generator,
'glitch_generators={},\n\tdistance_marginalization={}, priors={})' \
.format(self.interferometers, self.waveform_generator, self.glitch_generators,
self.distance_marginalization, self.priors)
def log_likelihood_ratio(self):
waveform_polarizations = \
self.waveform_and_glitch_generator.frequency_domain_strain(self.parameters)
self.parameters.update(self.get_sky_frame_parameters())
if waveform_polarizations is None:
return np.nan_to_num(-np.inf)
waveform_polarizations = dict()
assert self.glitch_generators != dict(), "At least one glitch generator\
should be passed, else why are you using this?"
for glitch_name, glitch_generator in self.glitch_generators.items():
# Input parameters are formatted as "glitch_{glitch_name}_{param_name}"
# We want the parameters associated with the glitch in question
# Then, transform to what frequency domain strain will expect
parameters_of_interest = {
key.split("_")[-1] : parameter for key, parameter in self.parameters if glitch_name in key
}
waveform_polarizations.update(glitch_generator.frequency_domain_strain(parameters_of_interest))
if self.waveform_generator is not None:
# Allow passing None to do the case of PE on a glitch (or glitches) alone
cbc_polarizations = self.waveform_generator.frequency_domain_strain(self.parameters)
if cbc_polarizations is None:
return np.nan_to_num(-np.inf)
else:
waveform_polarizations.update(cbc_polarizations)
self.parameters.update(self.get_sky_frame_parameters())
d_inner_h = 0.
d_inner_g = 0.
g_inner_h = 0.
optimal_gw_snr_squared = 0.
optimal_glitch_snr_squard = 0.
optimal_glitch_snr_squared = 0.
complex_matched_filter_gw_snr = 0.
complex_matched_filter_glitch_snr = 0.
for interferometer in self.interferometers:
per_detector_snr = self.calculate_snrs(
waveform_polarizations=waveform_polarizations[[mode for mode in waveform_polarizations if "glitch" in mode]],
waveform_polarizations=waveform_polarizations,
interferometer=interferometer)
d_inner_h += per_detector_snr.d_inner_h
......@@ -1269,11 +1283,11 @@ class JointGlitchGravitationalWaveTransient(GravitationalWaveTransient):
d_inner_h=d_inner_h, h_inner_h=optimal_gw_snr_squared)
else:
log_l = (np.real(d_inner_h) + np.real(d_inner_g) - np.real(g_inner_h) \
- optimal_gw_snr_squared / 2 - optimal_glitch_snr_squared / 2)
log_l = (np.real(d_inner_h) + np.real(d_inner_g) - np.real(g_inner_h)
- optimal_gw_snr_squared / 2 - optimal_glitch_snr_squared / 2)
return float(log_l.real)
def distance_marginalized_likelihood(self, d_inner_h, g_inner_h, h_inner_h):
d_inner_h_ref, g_inner_h_ref, h_inner_h_ref = self._setup_rho(
d_inner_h, g_inner_h, h_inner_h)
......@@ -1343,34 +1357,34 @@ class JointGlitchGravitationalWaveTransient(GravitationalWaveTransient):
The bilby interferometer object
"""
#TODO glitches are only in 1 detector - make it possible to track which one!!
if "glitch" in waveform_polarizations:
glitch_signal = interferometer.get_detector_response(
dict(glitch = waveform_polarizations.pop("glitch")),
self.parameters
)
glitch_signal = interferometer.get_detector_response(
waveform_polarizations[
[mode_name for mode_name in waveform_polarizations.keys()
if "glitch" in mode_name
and interferometer.name in mode_name]
],
self.parameters
)
gw_signal = interferometer.get_detector_response(
waveform_polarizations, self.parameters)
_mask = interferometer.frequency_mask
waveform_polarizations[["plus", "cross"]], self.parameters)
d_inner_h = interferometer.inner_product(signal=gw_signal)
d_inner_g = interferometer.inner_product(signal=glitch_signal)
g_inner_h = interferometer.template_template_inner_product(signal_1=gw_signal, signal_2=glitch_signal)
optimal_gw_snr_squared = interferometer.optimal_snr_squared(signal=gw_signal)
optimal_glitch_snr_squard = interferometer.optimal_snr_squared(signal=glitch_signal)
optimal_glitch_snr_squared = interferometer.optimal_snr_squared(signal=glitch_signal)
complex_matched_filter_gw_snr = d_inner_h / (optimal_gw_snr_squared**0.5)
complex_matched_filter_glitch_snr = d_inner_g / (optimal_glitch_snr_squard**0.5)
normalization = 4 / self.waveform_generator.duration
complex_matched_filter_glitch_snr = d_inner_g / (optimal_glitch_snr_squared**0.5)
return self._CalculatedSNRs(
d_inner_h=d_inner_h, d_inner_g=d_inner_g, g_inner_h=g_inner_h,
optimal_gw_snr_squared=optimal_gw_snr_squared,
optimal_glitch_snr_squard=optimal_glitch_snr_squard,
optimal_glitch_snr_squared=optimal_glitch_snr_squared,
complex_matched_filter_gw_snr=complex_matched_filter_gw_snr,
complex_matched_filter_glitch_snr=complex_matched_filter_glitch_snr
)
)
def _setup_rho(self, d_inner_h, g_inner_h, optimal_snr_squared):
optimal_snr_squared_ref = (optimal_snr_squared.real *
......@@ -1378,6 +1392,6 @@ class JointGlitchGravitationalWaveTransient(GravitationalWaveTransient):
self._ref_dist ** 2.)
d_inner_h_ref = (d_inner_h * self.parameters['luminosity_distance'] /
self._ref_dist)
g_inner_h_ref = (g_inner_h * self.parameters['luminosity_distance'] /
g_inner_h_ref = (g_inner_h * self.parameters['luminosity_distance'] /
self._ref_dist)
return d_inner_h_ref, g_inner_h_ref, optimal_snr_squared_ref
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment