Forked from
lscsoft / bilby
2754 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
waveform_generator.py 11.32 KiB
from tupak.core import utils
import numpy as np
class WaveformGenerator(object):
    """ A waveform generator.

    Wraps a frequency-domain and/or time-domain source model and evaluates it on
    frequency/time arrays derived from `duration`, `sampling_frequency` and
    `start_time`. When only one domain's model is supplied, strain in the other
    domain is obtained by transforming that model's output (`utils.nfft` /
    `utils.infft`).
    """

    def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None,
                 time_domain_source_model=None, parameters=None,
                 parameter_conversion=None,
                 non_standard_sampling_parameter_keys=None,
                 waveform_arguments=None):
        """ A waveform generator

        Parameters
        ----------
        duration: float, optional
            Time duration of data
        sampling_frequency: float, optional
            The sampling frequency
        start_time: float, optional
            Starting time of the time array
        frequency_domain_source_model: func, optional
            A python function taking some arguments and returning the frequency
            domain strain. Note the first argument must be the frequencies at
            which to compute the strain
        time_domain_source_model: func, optional
            A python function taking some arguments and returning the time
            domain strain. Note the first argument must be the times at
            which to compute the strain
        parameters: dict, optional
            Initial values for the parameters. These are merged into the
            parameters inferred from the source model; non-dict values are ignored.
        parameter_conversion: func, optional
            Function to convert from sampled parameters to parameters of the
            waveform generator. It is called as
            `parameter_conversion(parameters, non_standard_sampling_parameter_keys)`
            and must return `(converted_parameters, added_keys)`. Default value is
            the identity, i.e. it leaves the parameters unaffected and adds no keys.
        non_standard_sampling_parameter_keys: list, optional
            List of parameter names for *non-standard* sampling parameters.
            Passed as the second argument to `parameter_conversion`.
        waveform_arguments: dict, optional
            A dictionary of fixed keyword arguments to pass to either
            `frequency_domain_source_model` or `time_domain_source_model`.

        Note: the arguments of the source model (frequency-domain if given,
        otherwise time-domain; except the first, which is the frequencies/times
        at which to compute the strain) are added to `self.parameters` and
        initialised to `None`. Without any source model, `self.parameters`
        starts empty.
        """
        # The duration/sampling_frequency/start_time setters also mark the cached
        # frequency and time arrays as stale.
        self.duration = duration
        self.sampling_frequency = sampling_frequency
        self.start_time = start_time
        self.frequency_domain_source_model = frequency_domain_source_model
        self.time_domain_source_model = time_domain_source_model
        self.__parameters_from_source_model()
        if parameter_conversion is None:
            # Identity conversion: same parameters back, no keys added.
            self.parameter_conversion = lambda params, search_keys: (params, [])
        else:
            self.parameter_conversion = parameter_conversion
        self.non_standard_sampling_parameter_keys = non_standard_sampling_parameter_keys
        # The setter merges into (rather than replaces) the inferred parameters.
        self.parameters = parameters
        if waveform_arguments is not None:
            self.waveform_arguments = waveform_arguments
        else:
            self.waveform_arguments = dict()
        # Keyword arguments handed to the source model: fixed waveform arguments,
        # overlaid by the (converted) parameters on every strain evaluation.
        self.__full_source_model_keyword_arguments = {}
        self.__full_source_model_keyword_arguments.update(self.waveform_arguments)
        self.__full_source_model_keyword_arguments.update(self.parameters)
        self.__added_keys = []

    @staticmethod
    def _callable_name(func):
        """ Return the `__name__` of `func` (its repr if it has none), or None if `func` is None. """
        if func is None:
            return None
        return getattr(func, '__name__', repr(func))

    def __repr__(self):
        fdsm_name = self._callable_name(self.frequency_domain_source_model)
        tdsm_name = self._callable_name(self.time_domain_source_model)
        param_conv_name = self._callable_name(self.parameter_conversion)
        if param_conv_name == '<lambda>':
            # The default identity conversion is an anonymous lambda; report it as None.
            param_conv_name = None
        return self.__class__.__name__ + '(duration={}, sampling_frequency={}, start_time={}, ' \
                                         'frequency_domain_source_model={}, time_domain_source_model={}, ' \
                                         'parameters={}, parameter_conversion={}, ' \
                                         'non_standard_sampling_parameter_keys={}, waveform_arguments={})'\
            .format(self.duration, self.sampling_frequency, self.start_time, fdsm_name, tdsm_name, self.parameters,
                    param_conv_name, self.non_standard_sampling_parameter_keys, self.waveform_arguments)

    def frequency_domain_strain(self):
        """ Wrapper to source_model.

        Converts self.parameters with self.parameter_conversion before handing it off to the source model.
        Automatically refers to the time_domain_source_model via NFFT if no frequency_domain_source_model is given.

        Returns
        -------
        array_like or dict: The frequency domain strain for the given set of parameters
            (a dict of strains when the source model returns a dict)

        Raises
        -------
        RuntimeError: If no source model is given
        """
        return self._calculate_strain(model=self.frequency_domain_source_model,
                                      model_data_points=self.frequency_array,
                                      transformation_function=utils.nfft,
                                      transformed_model=self.time_domain_source_model,
                                      transformed_model_data_points=self.time_array)

    def time_domain_strain(self):
        """ Wrapper to source_model.

        Converts self.parameters with self.parameter_conversion before handing it off to the source model.
        Automatically refers to the frequency_domain_source_model via INFFT if no time_domain_source_model is
        given.

        Returns
        -------
        array_like or dict: The time domain strain for the given set of parameters
            (a dict of strains when the source model returns a dict)

        Raises
        -------
        RuntimeError: If no source model is given
        """
        return self._calculate_strain(model=self.time_domain_source_model,
                                      model_data_points=self.time_array,
                                      transformation_function=utils.infft,
                                      transformed_model=self.frequency_domain_source_model,
                                      transformed_model_data_points=self.frequency_array)

    def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model,
                          transformed_model_data_points):
        """ Evaluate `model` directly, or fall back to transforming `transformed_model`'s output.

        Keys added by the parameter conversion are removed from self.parameters
        afterwards, also when the evaluation raises.

        Raises
        ------
        RuntimeError: If both `model` and `transformed_model` are None
        """
        try:
            self._apply_parameter_conversion()
            if model is not None:
                return self._strain_from_model(model_data_points, model)
            if transformed_model is not None:
                return self._strain_from_transformed_model(transformed_model_data_points, transformed_model,
                                                           transformation_function)
            raise RuntimeError("No source model given")
        finally:
            self._remove_added_keys()

    def _apply_parameter_conversion(self):
        """ Run the parameter conversion and refresh the model keyword arguments with the result. """
        self.parameters, self.__added_keys = self.parameter_conversion(self.parameters,
                                                                       self.non_standard_sampling_parameter_keys)
        self.__full_source_model_keyword_arguments.update(self.parameters)

    def _strain_from_model(self, model_data_points, model):
        """ Call `model` on the data points with the full set of keyword arguments. """
        return model(model_data_points, **self.__full_source_model_keyword_arguments)

    def _strain_from_transformed_model(self, transformed_model_data_points, transformed_model, transformation_function):
        """ Evaluate the other-domain model and transform its output into this domain.

        Handles both a single strain array and a dict of strain arrays; the
        transformation is applied to each dict value, preserving the keys.
        """
        transformed_model_strain = self._strain_from_model(transformed_model_data_points, transformed_model)
        if isinstance(transformed_model_strain, np.ndarray):
            return self._apply_transformation(transformed_model_strain, transformation_function)
        return {key: self._apply_transformation(transformed_model_strain[key], transformation_function)
                for key in transformed_model_strain}

    def _apply_transformation(self, strain, transformation_function):
        """ Transform a single strain array and return the transformed strain only.

        `utils.nfft` returns a `(strain, frequencies)` pair whereas `utils.infft`
        returns the strain alone. A pair is unpacked here so that both the array
        and the dict paths return strain only; its frequencies are stored in
        self.frequency_array.
        """
        transformed = transformation_function(strain, self.sampling_frequency)
        if isinstance(transformed, tuple):
            transformed_strain, frequencies = transformed
            self.frequency_array = frequencies
            return transformed_strain
        return transformed

    def _remove_added_keys(self):
        """ Drop the keys the last parameter conversion added to self.parameters. """
        for key in self.__added_keys:
            self.parameters.pop(key, None)
        self.__added_keys = []

    @property
    def frequency_array(self):
        """ Frequency array for the waveforms. Automatically updates if sampling_frequency or duration are updated.

        Returns
        -------
        array_like: The frequency array
        """
        if not self.__frequency_array_updated:
            self.frequency_array = utils.create_frequency_series(
                self.sampling_frequency,
                self.duration)
        return self.__frequency_array

    @frequency_array.setter
    def frequency_array(self, frequency_array):
        self.__frequency_array = frequency_array
        self.__frequency_array_updated = True

    @property
    def time_array(self):
        """ Time array for the waveforms. Automatically updates if sampling_frequency, duration or start_time
        are updated.

        Returns
        -------
        array_like: The time array
        """
        if not self.__time_array_updated:
            self.time_array = utils.create_time_series(
                self.sampling_frequency,
                self.duration,
                self.start_time)
        return self.__time_array

    @time_array.setter
    def time_array(self, time_array):
        self.__time_array = time_array
        self.__time_array_updated = True

    @property
    def parameters(self):
        """ The dictionary of parameters for source model.

        Returns
        -------
        dict: The dictionary of parameter key-value pairs
        """
        return self.__parameters

    @parameters.setter
    def parameters(self, parameters):
        # Merge into the existing dictionary; non-dict values (e.g. None) are ignored.
        if isinstance(parameters, dict):
            for key in parameters.keys():
                self.__parameters[key] = parameters[key]

    def __parameters_from_source_model(self):
        """ Initialise the parameter dictionary from the source model's arguments.

        Keys come from `utils.infer_parameters_from_function` applied to the
        frequency-domain model if given, otherwise the time-domain model; all
        values start as None. With no source model the dictionary starts empty,
        so construction succeeds and the strain methods raise their RuntimeError.
        """
        if self.frequency_domain_source_model is not None:
            source_model = self.frequency_domain_source_model
        elif self.time_domain_source_model is not None:
            source_model = self.time_domain_source_model
        else:
            self.__parameters = dict()
            return
        self.__parameters = dict.fromkeys(utils.infer_parameters_from_function(source_model))

    @property
    def duration(self):
        """ Allows one to set the time duration and automatically updates the frequency and time array.

        Returns
        -------
        float: The time duration.
        """
        return self.__duration

    @duration.setter
    def duration(self, duration):
        self.__duration = duration
        self.__frequency_array_updated = False
        self.__time_array_updated = False

    @property
    def sampling_frequency(self):
        """ Allows one to set the sampling frequency and automatically updates the frequency and time array.

        Returns
        -------
        float: The sampling frequency.
        """
        return self.__sampling_frequency

    @sampling_frequency.setter
    def sampling_frequency(self, sampling_frequency):
        self.__sampling_frequency = sampling_frequency
        self.__frequency_array_updated = False
        self.__time_array_updated = False

    @property
    def start_time(self):
        """ Allows one to set the start time and automatically updates the time array.

        Returns
        -------
        float: The start time.
        """
        return self.__start_time

    @start_time.setter
    def start_time(self, starting_time):
        self.__start_time = starting_time
        self.__time_array_updated = False