Skip to content
Snippets Groups Projects
Forked from lscsoft / bilby
This fork has diverged from the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
psi4_waveform_generator.py 7.20 KiB
import numpy as np

from ..core import utils
from .waveform_generator import WaveformGenerator


class Psi4_WaveformGenerator(WaveformGenerator):
    """
    Waveform generator whose `*_strain` methods return psi4 instead of strain.

    The source model output (strain, or psi4 when `is_psi4_model` is set) is
    brought to the frequency domain and multiplied by `correction_factor`,
    which accounts for the difference between an exact time derivative and
    finite (second) differencing; see (A4) and (A5) in arxiv:2205.15029.
    """

    def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None,
                 time_domain_source_model=None, parameters=None,
                 parameter_conversion=None,
                 waveform_arguments=None):
        """
        The Psi4 waveform generator class.

        Parameters
        ==========
        sampling_frequency: float, optional
            The sampling frequency
        duration: float, optional
            Time duration of data
        start_time: float, optional
            Starting time of the time array
        frequency_domain_source_model: func, optional
            A python function taking some arguments and returning the frequency
            domain strain (or psi4, if `is_psi4_model` is True). Note the first
            argument must be the frequencies at which to compute the waveform
        time_domain_source_model: func, optional
            A python function taking some arguments and returning the time
            domain strain (or psi4, if `is_psi4_model` is True). Note the first
            argument must be the times at which to compute the waveform
        parameters: dict, optional
            Initial values for the parameters
        parameter_conversion: func, optional
            Function to convert from sampled parameters to parameters of the
            waveform generator. Default value is the identity, i.e. it leaves
            the parameters unaffected.
        waveform_arguments: dict, optional
            A dictionary of fixed keyword arguments to pass to either
            `frequency_domain_source_model` or `time_domain_source_model`.

            *** To distinguish between a strain model and a psi4 model, the
            key `is_psi4_model` should be included in waveform_arguments;
            it defaults to False if not specified. ***
            The key is removed from a copy of the dictionary before the
            remaining arguments are stored, so it is never forwarded to the
            source model and the caller's dictionary is left unchanged.
            This mechanism preserves the argument list of this waveform
            generator, so that `bilby_pipe` can work properly.

            Note: the arguments of frequency_domain_source_model (except the first,
            which is the frequencies at which to compute the strain) will be added to
            the WaveformGenerator object and initialised to `None`.

        """
        is_psi4_model = False
        if waveform_arguments is not None:
            # Copy before popping: the caller (e.g. bilby_pipe) may reuse the
            # same dict for several generators and must keep `is_psi4_model`.
            waveform_arguments = dict(waveform_arguments)
            is_psi4_model = waveform_arguments.pop('is_psi4_model', False)

        # Hand the cleaned arguments to the base class so that any initial
        # `parameters` are merged with them exactly as in WaveformGenerator.
        super().__init__(duration=duration, sampling_frequency=sampling_frequency, start_time=start_time,
                         frequency_domain_source_model=frequency_domain_source_model,
                         time_domain_source_model=time_domain_source_model, parameters=parameters,
                         parameter_conversion=parameter_conversion,
                         waveform_arguments=waveform_arguments)
        self.is_psi4_model = is_psi4_model

        # The correction factor is built lazily (see `correction_factor`), so
        # construction works without a duration and later changes to
        # duration / sampling_frequency / is_psi4_model are honoured.
        self._correction_factor = None
        self._correction_factor_key = None

        self._cache = dict(parameters=None, FD_waveform=None, TD_waveform=None)

    def _correction_factor_settings(self):
        """Return the settings the correction factor depends on (used as cache key)."""
        return (self.duration, self.sampling_frequency, self.is_psi4_model)

    @property
    def correction_factor(self):
        """
        Frequency-domain multiplier turning the model output into psi4.

        Recomputed whenever duration, sampling_frequency or is_psi4_model
        differ from the values used for the cached array.
        """
        settings = self._correction_factor_settings()
        if self._correction_factor is None or self._correction_factor_key != settings:
            self._correction_factor = self._compute_correction_factor()
            self._correction_factor_key = settings
        return self._correction_factor

    @correction_factor.setter
    def correction_factor(self, value):
        # Allow an explicit override; it is kept until the settings change.
        self._correction_factor = value
        self._correction_factor_key = self._correction_factor_settings()

    def _compute_correction_factor(self):
        """
        Build the correction factor between derivative and differencing.

        Details: arxiv:2205.15029. The argument of the factor is
        x = 2π kΔf Δt = 2π f / sampling_frequency.

        Returns
        =======
        array_like: one value per entry of `self.frequency_array`.
        """
        phase = 2 * np.pi * self.frequency_array / self.sampling_frequency

        if self.is_psi4_model:
            # (A5): the model already returns psi4 as an exact second
            # derivative, -(2πf)^2 h(f); rescale it to the second-difference
            # multiplier, giving (1 - cos x) / (x^2 / 2).
            # x == 0 (DC) is left at 0, which matches the strain branch below.
            # NOTE(review): the analytic x -> 0 limit of this ratio is 1;
            # confirm that zeroing the DC bin is intended.
            correction_factor = np.zeros(len(phase))
            mask = phase != 0
            non_zero_phase = phase[mask]
            correction_factor[mask] = \
                (1 - np.cos(non_zero_phase)) / (0.5 * non_zero_phase * non_zero_phase)
        else:
            # (A4): Fourier multiplier of the centred second difference
            # (h[n+1] - 2 h[n] + h[n-1]) * sampling_frequency**2.
            correction_factor = \
                2 * self.sampling_frequency * self.sampling_frequency * (np.cos(phase) - 1)
        return correction_factor

    @staticmethod
    def _map_waveform(waveform, func):
        """
        Apply `func` to a waveform that is an array, `None`, or a mapping.

        Arrays are transformed directly, `None` is passed through, and any
        other value is treated as a mapping (presumably of polarizations —
        depends on the source model) whose entries are transformed one by one.
        """
        if waveform is None:
            return None
        if isinstance(waveform, np.ndarray):
            return func(waveform)
        return {key: func(waveform[key]) for key in waveform}

    def time_domain_psi4(self, parameters=None):
        """Return the time domain psi4 for `parameters` (or the current ones)."""
        return self._calculate_time_domain_psi4(parameters=parameters)

    def frequency_domain_psi4(self, parameters=None):
        """Return the frequency domain psi4 for `parameters` (or the current ones)."""
        return self._calculate_frequency_domain_psi4(parameters=parameters)

    def time_domain_strain(self, parameters=None):
        """
        Retain this function as most internal processes in Bilby rely on this attribute name.
        But it returns the time domain psi4.
        """
        return self._calculate_time_domain_psi4(parameters=parameters)

    def frequency_domain_strain(self, parameters=None):
        """
        Retain this function as most internal processes in Bilby rely on this attribute name.
        But it returns the frequency domain psi4.
        """
        return self._calculate_frequency_domain_psi4(parameters=parameters)

    def _calculate_frequency_domain_psi4(self, parameters):
        """
        Evaluate the source model and convert its output to frequency domain psi4.

        Results are cached per parameter set. Raises RuntimeError if neither a
        frequency- nor a time-domain source model was given.
        """
        if parameters is not None:
            self.parameters = parameters

        if self.parameters == self._cache['parameters'] and self._cache['FD_waveform'] is not None:
            return self._cache['FD_waveform']

        if self.frequency_domain_source_model is not None:
            model_waveform = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
            to_fd_psi4 = self._FD_psi4_from_FD_waveform
        elif self.time_domain_source_model is not None:
            model_waveform = self.time_domain_source_model(self.time_array, **self.parameters)
            to_fd_psi4 = self._FD_psi4_from_TD_waveform
        else:
            raise RuntimeError("No source model given")

        FD_psi4 = self._map_waveform(model_waveform, to_fd_psi4)

        self._cache['FD_waveform'] = FD_psi4
        # A cached time-domain waveform belongs to the previous parameters;
        # drop it so the TD path cannot return it for the new parameters.
        self._cache['TD_waveform'] = None
        self._cache['parameters'] = self.parameters.copy()

        return FD_psi4

    def _calculate_time_domain_psi4(self, parameters):
        """Inverse-FFT the frequency domain psi4 (cached per parameter set)."""
        if parameters is not None:
            self.parameters = parameters

        if self.parameters == self._cache['parameters'] and self._cache['TD_waveform'] is not None:
            return self._cache['TD_waveform']

        # self.parameters is already up to date; avoid re-running the conversion.
        FD_psi4 = self._calculate_frequency_domain_psi4(None)
        TD_psi4 = self._map_waveform(
            FD_psi4, lambda fd_psi4: utils.infft(fd_psi4, self.sampling_frequency))

        self._cache['TD_waveform'] = TD_psi4
        self._cache['parameters'] = self.parameters.copy()

        return TD_psi4

    def _FD_psi4_from_FD_waveform(self, fd_waveform):
        """Apply the correction factor to a frequency domain model output."""
        return self.correction_factor * fd_waveform

    def _FD_psi4_from_TD_waveform(self, td_waveform):
        """FFT a time domain model output, then apply the correction factor."""
        fd_waveform, _ = utils.nfft(td_waveform, self.sampling_frequency)
        return self.correction_factor * fd_waveform