Forked from
lscsoft / bilby
This fork has diverged from the upstream repository.
-
Samson Leong authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
psi4_waveform_generator.py 7.20 KiB
import numpy as np
from ..core import utils
from .waveform_generator import WaveformGenerator
class Psi4_WaveformGenerator(WaveformGenerator):
    """Waveform generator whose "strain" outputs are actually Psi4.

    The model waveform (strain or Psi4, selected by the `is_psi4_model`
    entry of `waveform_arguments`) is multiplied in the frequency domain by a
    correction factor between derivative and finite differencing, following
    arXiv:2205.15029.
    """

    def __init__(self, duration=None, sampling_frequency=None, start_time=0,
                 frequency_domain_source_model=None,
                 time_domain_source_model=None, parameters=None,
                 parameter_conversion=None,
                 waveform_arguments=None):
        """
        The Psi4 waveform generator class.

        Parameters
        ==========
        sampling_frequency: float, optional
            The sampling frequency
        duration: float, optional
            Time duration of data
        start_time: float, optional
            Starting time of the time array
        frequency_domain_source_model: func, optional
            A python function taking some arguments and returning the frequency
            domain strain. Note the first argument must be the frequencies at
            which to compute the strain
        time_domain_source_model: func, optional
            A python function taking some arguments and returning the time
            domain strain. Note the first argument must be the times at
            which to compute the strain
        parameters: dict, optional
            Initial values for the parameters
        parameter_conversion: func, optional
            Function to convert from sampled parameters to parameters of the
            waveform generator. Default value is the identity, i.e. it leaves
            the parameters unaffected.
        waveform_arguments: dict, optional
            A dictionary of fixed keyword arguments to pass to either
            `frequency_domain_source_model` or `time_domain_source_model`.

            *** To distinguish between a strain model and a psi4 model, the
            keyword `is_psi4_model` should be in waveform_arguments; it
            defaults to False if not specified. ***
            The keyword is removed from a copy of the dictionary before that
            copy is stored, so it is not forwarded to the source model and
            the caller's dictionary is left unmodified.
            This design preserves the number of arguments this
            waveform generator takes, so that `bilby_pipe` can work properly.

        Note: the arguments of frequency_domain_source_model (except the first,
        which is the frequencies at which to compute the strain) will be added to
        the WaveformGenerator object and initialised to `None`.
        """
        super().__init__(duration=duration, sampling_frequency=sampling_frequency,
                         start_time=start_time,
                         frequency_domain_source_model=frequency_domain_source_model,
                         time_domain_source_model=time_domain_source_model,
                         parameters=parameters,
                         parameter_conversion=parameter_conversion,
                         waveform_arguments=None)
        # Work on a copy so popping the flag does not mutate the caller's dict
        # (e.g. when the same dict is reused to build several generators).
        waveform_arguments = dict(waveform_arguments or {})
        self.is_psi4_model = waveform_arguments.pop('is_psi4_model', False)
        self.waveform_arguments = waveform_arguments
        # The parent applied `parameters` while waveform_arguments was still
        # empty. Re-apply them now that waveform_arguments is final
        # (upstream bilby's `parameters` setter merges waveform_arguments into
        # the stored parameters).
        if isinstance(parameters, dict):
            self.parameters = parameters
        self.correction_factor = self._correction_factor(
            self.frequency_array, self.sampling_frequency, self.is_psi4_model)
        self._cache = dict(parameters=None, FD_waveform=None, TD_waveform=None)

    @staticmethod
    def _correction_factor(frequency_array, sampling_frequency, is_psi4_model):
        """Return the correction factor between derivative and differencing.

        For details, see arXiv:2205.15029. With x = 2 pi k Δf Δt
        (= 2 pi f / f_s):

        * psi4 model, eq. (A5):   (1 - cos x) / (x^2 / 2).
          The zero-frequency bin is set to 0, as in the original
          implementation; the analytic limit there would be 1.
          NOTE(review): confirm that zeroing it is intended.
        * strain model, eq. (A4): 2 f_s^2 (cos x - 1).

        ``1 - cos x`` is evaluated as ``2 sin^2(x / 2)``. The two are
        mathematically equal, but the sine form avoids the cancellation
        error of ``1 - cos x`` at low frequencies.
        """
        phase = 2 * np.pi * np.asarray(frequency_array) / sampling_frequency
        one_minus_cos = 2 * np.sin(0.5 * phase) ** 2
        if is_psi4_model:
            factor = np.zeros(len(phase))
            mask = phase != 0
            factor[mask] = one_minus_cos[mask] / (0.5 * phase[mask] ** 2)
            return factor
        return -2 * sampling_frequency * sampling_frequency * one_minus_cos

    @staticmethod
    def _apply_to_model_output(function, waveform):
        """Apply `function` to a model output.

        The output may be a single ndarray, None (returned unchanged), or a
        mapping of ndarrays. For a mapping, `function` is applied to each
        value and a new dict with the same keys is returned.
        """
        if isinstance(waveform, np.ndarray):
            return function(waveform)
        if waveform is None:
            return None
        return {key: function(waveform[key]) for key in waveform}

    def time_domain_psi4(self, parameters=None):
        """Return the time-domain Psi4 for `parameters` (or the current ones)."""
        return self._calculate_time_domain_psi4(parameters=parameters)

    def frequency_domain_psi4(self, parameters=None):
        """Return the frequency-domain Psi4 for `parameters` (or the current ones)."""
        return self._calculate_frequency_domain_psi4(parameters=parameters)

    def time_domain_strain(self, parameters=None):
        """
        Kept under this name because most of Bilby's internals call it.
        NOTE: it returns the time-domain Psi4, not the strain.
        """
        return self._calculate_time_domain_psi4(parameters=parameters)

    def frequency_domain_strain(self, parameters=None):
        """
        Kept under this name because most of Bilby's internals call it.
        NOTE: it returns the frequency-domain Psi4, not the strain.
        """
        return self._calculate_frequency_domain_psi4(parameters=parameters)

    def _calculate_frequency_domain_psi4(self, parameters):
        """Compute (or fetch from cache) the frequency-domain Psi4.

        Raises
        ======
        RuntimeError
            If neither a frequency- nor a time-domain source model was given.
        """
        if parameters is not None:
            self.parameters = parameters
        if (self.parameters == self._cache['parameters']
                and self._cache['FD_waveform'] is not None):
            return self._cache['FD_waveform']
        if self.frequency_domain_source_model is not None:
            model_waveform = self.frequency_domain_source_model(
                self.frequency_array, **self.parameters)
            to_psi4 = self._FD_psi4_from_FD_waveform
        elif self.time_domain_source_model is not None:
            model_waveform = self.time_domain_source_model(
                self.time_array, **self.parameters)
            to_psi4 = self._FD_psi4_from_TD_waveform
        else:
            raise RuntimeError("No source model given")
        FD_psi4 = self._apply_to_model_output(to_psi4, model_waveform)
        self._cache['FD_waveform'] = FD_psi4
        # Any cached TD waveform belongs to the previous parameters. Drop it
        # so _calculate_time_domain_psi4 cannot return a stale result.
        self._cache['TD_waveform'] = None
        self._cache['parameters'] = self.parameters.copy()
        return FD_psi4

    def _calculate_time_domain_psi4(self, parameters):
        """Compute (or fetch from cache) the time-domain Psi4.

        The result is the inverse FFT of the frequency-domain Psi4.
        """
        if parameters is not None:
            self.parameters = parameters
        if (self.parameters == self._cache['parameters']
                and self._cache['TD_waveform'] is not None):
            return self._cache['TD_waveform']
        # self.parameters is already up to date, so it is not passed again.
        FD_psi4 = self._calculate_frequency_domain_psi4(None)
        TD_psi4 = self._apply_to_model_output(
            lambda fd: utils.infft(fd, self.sampling_frequency), FD_psi4)
        self._cache['TD_waveform'] = TD_psi4
        self._cache['parameters'] = self.parameters.copy()
        return TD_psi4

    def _FD_psi4_from_FD_waveform(self, fd_waveform):
        """Apply the correction factor to a frequency-domain model waveform."""
        return self.correction_factor * fd_waveform

    def _FD_psi4_from_TD_waveform(self, td_waveform):
        """FFT a time-domain model waveform, then apply the correction factor."""
        fd_waveform, _ = utils.nfft(td_waveform, self.sampling_frequency)
        return self.correction_factor * fd_waveform