Commit 9e98c33e authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'time_domain_strain_implementation' into 'master'

Time domain strain implementation

See merge request Monash/tupak!22
parents 6230a02e d76a715e
Pipeline #18672 failed with stages
in 59 seconds
......@@ -94,7 +94,7 @@ class TestSourceModelSetter(unittest.TestCase):
self.waveform_generator = tupak.waveform_generator.WaveformGenerator(
frequency_domain_source_model=gaussian_frequency_domain_strain)
self.waveform_generator.frequency_domain_source_model = gaussian_frequency_domain_strain_2
self.simulation_parameters = dict(a=1e-21, m=100, s=1,
self.simulation_parameters = dict(amplitude=1e-21, mu=100, sigma=1,
ra=1.375,
dec=-1.2108,
geocent_time=1126259642.413,
......
import inspect
from . import utils
import numpy as np
class WaveformGenerator(object):
""" A waveform generator
......@@ -38,7 +38,12 @@ class WaveformGenerator(object):
if self.frequency_domain_source_model is not None:
return self.frequency_domain_source_model(self.frequency_array, **self.parameters)
elif self.time_domain_source_model is not None:
fft_data, self.frequency_array = utils.nfft(self.time_domain_source_model(self.time_array, **self.parameters), self.sampling_frequency)
fft_data = dict()
time_domain_strain = self.time_domain_source_model(self.time_array, **self.parameters)
if isinstance(time_domain_strain, np.ndarray):
return time_domain_strain
for key in time_domain_strain:
fft_data[key], self.frequency_array = utils.nfft(time_domain_strain[key], self.sampling_frequency)
return fft_data
else:
raise RuntimeError("No source model given")
......@@ -47,7 +52,13 @@ class WaveformGenerator(object):
if self.time_domain_source_model is not None:
return self.time_domain_source_model(self.time_array, **self.parameters)
elif self.frequency_domain_source_model is not None:
return utils.infft(self.frequency_domain_source_model(self.frequency_array, **self.parameters))
ifft_data = dict()
frequency_domain_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
if isinstance(frequency_domain_strain, np.ndarray):
return frequency_domain_strain
for key in frequency_domain_strain:
ifft_data = utils.infft(frequency_domain_strain[key], self.sampling_frequency)
return ifft_data
else:
raise RuntimeError("No source model given")
......@@ -81,32 +92,31 @@ class WaveformGenerator(object):
@parameters.setter
def parameters(self, parameters):
if parameters is None:
parameters = inspect.getargspec(self.frequency_domain_source_model).args
parameters.pop(0)
self.__parameters = dict.fromkeys(parameters)
if self.frequency_domain_source_model is not None:
parameters = inspect.getargspec(self.frequency_domain_source_model).args
parameters.pop(0)
self.__parameters = dict.fromkeys(parameters)
elif self.time_domain_source_model is not None:
parameters = inspect.getargspec(self.time_domain_source_model).args
parameters.pop(0)
self.__parameters = dict.fromkeys(parameters)
elif isinstance(parameters, list):
parameters.pop(0)
self.__parameters = dict.fromkeys(parameters)
elif isinstance(parameters, dict):
if not hasattr(self, '_WaveformGenerator__parameters'):
self.__parameters = parameters
for key in self.__parameters.keys():
if key in parameters.keys():
self.__parameters[key] = parameters[key]
# else:
# raise KeyError('The provided dictionary did not '
# 'contain key {}'.format(key))
else:
raise KeyError('The provided dictionary did not '
'contain key {}'.format(key))
else:
raise TypeError('Parameters must either be set as a list of keys or'
' a dictionary of key-value pairs.')
@property
def frequency_domain_source_model(self):
return self.__source_model
@frequency_domain_source_model.setter
def frequency_domain_source_model(self, source_model):
self.__source_model = source_model
self.parameters = inspect.getargspec(source_model).args
@property
def time_duration(self):
return self.__time_duration
......
......@@ -45,7 +45,7 @@ prior['luminosity_distance'] = tupak.prior.PowerLaw(
# `lal_binary_black_hole model` source model. We also pass other parameters:
# the waveform approximant and reference frequency.
waveform_generator = tupak.waveform_generator.WaveformGenerator(
tupak.source.lal_binary_black_hole,
frequency_domain_source_model=tupak.source.lal_binary_black_hole,
sampling_frequency=interferometers[0].sampling_frequency,
time_duration=interferometers[0].duration,
parameters={'waveform_approximant': 'IMRPhenomPv2', 'reference_frequency': 50})
......
......@@ -16,7 +16,7 @@ def main():
# Create the waveform generator
waveform_generator = tupak.waveform_generator.WaveformGenerator(
tupak.source.lal_binary_black_hole, sampling_frequency=2048, time_duration=4,
frequency_domain_source_model=tupak.source.lal_binary_black_hole, sampling_frequency=2048, time_duration=4,
parameters={'reference_frequency': 50.0, 'waveform_approximant': 'IMRPhenomPv2'})
# Define the prior
......
import tupak
import matplotlib.pyplot as plt
import numpy as np
def frequency_domain_sine_gaussian(f, A, f0, tau, phi0, geocent_time, ra, dec, psi):
arg = -(np.pi * tau * (f-f0))**2 + 1j * phi0
plus = np.sqrt(np.pi) * A * tau * np.exp(arg) / 2.
cross = plus * np.exp(1j*np.pi/2)
return {'plus': plus, 'cross': cross}
def time_domain_sine_gaussian(t, A, t0, f0, tau, phi0, geocent_time, ra, dec, psi):
arg = -(-(t-t0)/tau)**2
plus = A * np.exp(arg) *np.cos(2*np.pi*f0*t + phi0)
cross = plus * np.exp(1j*np.pi/2)
return {'plus': plus, 'cross': cross}
parameters = dict()
parameters['A'] = 10000
parameters['f0'] = 5
parameters['t0'] = 10
parameters['tau'] = 3
parameters['geocent_time'] = 0
parameters['phi0'] = 0
parameters['ra'] = 0
parameters['dec'] = 0
parameters['psi'] = 0
wg = tupak.waveform_generator.WaveformGenerator(time_domain_source_model=time_domain_sine_gaussian, time_duration=2000, sampling_frequency=1000, parameters=parameters)
wg.parameters = parameters
plt.plot(wg.frequency_array, wg.frequency_domain_strain()['plus'])
plt.xlim(4, 6)
plt.show()
plt.plot(wg.frequency_array, wg.frequency_domain_strain()['cross'])
plt.xlim(4, 6)
plt.show()
plt.plot(wg.time_array, wg.time_domain_strain()['plus'])
plt.xlim(0, 20)
plt.show()
plt.plot(wg.time_array, wg.time_domain_strain()['cross'])
plt.xlim(0, 20)
plt.show()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment