Skip to content
Snippets Groups Projects
Commit 7b007480 authored by Samson Leong's avatar Samson Leong :stuck_out_tongue:
Browse files

fix some bugs in psi4 wfm gen.

parent c24ba667
No related branches found
No related tags found
No related merge requests found
import numpy as np
from ..core import utils
from ..core.series import CoupledTimeAndFrequencySeries
from .conversion import convert_to_lal_binary_black_hole_parameters
from .waveform_generator import WaveformGenerator
class Psi4_WaveformGenerator(WaveformGenerator):
......@@ -38,11 +37,11 @@ class Psi4_WaveformGenerator(WaveformGenerator):
waveform_arguments: dict, optional
A dictionary of fixed keyword arguments to pass to either
`frequency_domain_source_model` or `time_domain_source_model`.
*** In order to distinguish between strain model and psi4 model,
*** In order to distinguish between strain model and psi4 model,
the keyword `is_psi4_model` should be within the waveform_arguments,
default is False if not specified. ***
Such implementation is to perserve the number of arguments this
Such implementation is to perserve the number of arguments this
waveform_generator takes, such that `bilby_pipe` can work properly.
Note: the arguments of frequency_domain_source_model (except the first,
......@@ -50,55 +49,40 @@ class Psi4_WaveformGenerator(WaveformGenerator):
the WaveformGenerator object and initialised to `None`.
"""
self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration,
sampling_frequency=sampling_frequency,
start_time=start_time)
self.frequency_domain_source_model = frequency_domain_source_model
self.time_domain_source_model = time_domain_source_model
self.source_parameter_keys = self.__parameters_from_source_model()
if parameter_conversion is None:
self.parameter_conversion = convert_to_lal_binary_black_hole_parameters
else:
self.parameter_conversion = parameter_conversion
super().__init__(duration=duration, sampling_frequency=sampling_frequency, start_time=start_time,
frequency_domain_source_model=frequency_domain_source_model,
time_domain_source_model=time_domain_source_model, parameters=parameters,
parameter_conversion=parameter_conversion,
waveform_arguments=None)
if waveform_arguments is not None:
self.waveform_arguments = waveform_arguments
self.is_psi4_model = waveform_arguments.pop('is_psi4_model', False)
else:
self.waveform_arguments = dict()
self.is_psi4_model = False
# Correction factor between derivative and differencing
# The argument in the correction factor: 2π kΔf Δt
phase = 2 * np.pi * self.frequency_array * self.sampling_frequency
phase = 2 * np.pi * self.frequency_array / self.sampling_frequency
self.correction_factor = \
(1 - np.cos(phase)) / (0.5 * phase * phase) \
(1 - np.cos(phase)) / (0.5 * phase * phase) \
if self.is_psi4_model else \
2 * self.sampling_frequency * self.sampling_frequency * (np.cos(phase) - 1)
if isinstance(parameters, dict):
self.parameters = parameters
self._cache = dict(parameters=None, waveform=None, model=None)
utils.logger.info(
"Waveform generator initiated with\n"
" frequency_domain_source_model: {}\n"
" time_domain_source_model: {}\n"
" parameter_conversion: {}"
.format(utils.get_function_path(self.frequency_domain_source_model),
utils.get_function_path(self.time_domain_source_model),
utils.get_function_path(self.parameter_conversion))
)
self._cache = dict(parameters=None, FD_waveform=None, TD_waveform=None)
def time_domain_psi4(self, parameters=None):
return self._calculate_time_domain_psi4(parameters=parameters)
def frequency_domain_psi4(self, parameters=None):
return self._calculate_frequency_domain_psi4(parameters=parameters)
def time_domain_strain(self, parameters=None):
"""
Retain this function as most internal process in Bilby relies on this attribute name.
......@@ -106,31 +90,31 @@ class Psi4_WaveformGenerator(WaveformGenerator):
"""
return self._calculate_time_domain_psi4(parameters=parameters)
def frequency_domain_strain(self, parameters=None):
"""
Retain this function as most internal process in Bilby relies on this attribute name.
But returning frequency domain Psi4.
"""
return self._calculate_frequency_domain_psi4(parameters=parameters)
def _calculate_frequency_domain_psi4(self, parameters):
if parameters is not None:
self.parameters = parameters
if self.parameters == self._cache['parameters']:
if self.parameters == self._cache['parameters'] and self._cache['FD_waveform'] is not None:
return self._cache['FD_waveform']
is_FD_model = True
if self.frequency_domain_source_model is not None:
model_waveform = self.frequency_domain_source_model(self.frequency_array)
model_waveform = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
elif self.time_domain_source_model is not None:
model_waveform = self.time_domain_source_model(self.time_array)
model_waveform = self.time_domain_source_model(self.time_array, **self.parameters)
is_FD_model = False
else:
raise RuntimeError("No source model given")
if isinstance(model_waveform, np.ndarray):
FD_psi4 = self._FD_psi4_from_FD_waveform(model_waveform) \
if is_FD_model else \
......@@ -141,32 +125,39 @@ class Psi4_WaveformGenerator(WaveformGenerator):
FD_psi4[key] = self._FD_psi4_from_FD_waveform(model_waveform[key]) \
if is_FD_model else \
self._FD_psi4_from_TD_waveform(model_waveform[key])
self._cache['FD_waveform'] = FD_psi4
self._cache['parameters'] = self.parameters.copy()
return FD_psi4
def _calculate_time_domain_psi4(self, parameters):
if parameters is not None:
self.parameters = parameters
if self.parameters == self._cache['parameters']:
if self.parameters == self._cache['parameters'] and self._cache['TD_waveform'] is not None:
return self._cache['TD_waveform']
TD_psi4 = utils.infft(self._calculate_frequency_domain_psi4(parameters), self.sampling_frequency)
FD_psi4 = self._calculate_frequency_domain_psi4(parameters)
if isinstance(FD_psi4, np.ndarray):
TD_psi4 = utils.infft(FD_psi4, self.sampling_frequency)
else:
TD_psi4 = dict()
for key in FD_psi4:
TD_psi4[key] = utils.infft(FD_psi4[key], self.sampling_frequency)
self._cache['TD_waveform'] = TD_psi4
self._cache['parameters'] = self.parameters.copy()
return TD_psi4
def _FD_psi4_from_FD_waveform(self, fd_waveform):
return self.correction_factor * fd_waveform
def _FD_psi4_from_TD_waveform(self, td_waveform):
fd_waveform, _ = utils.nfft(td_waveform, self.sampling_frequency)
fd_psi4 = self.correction_factor * fd_waveform
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment