From aae26944e51a5cc347777dfbfd44cdb6037ba53f Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Mon, 11 Jun 2018 16:24:37 +1000 Subject: [PATCH] Adds fixed_arguments to the wfg --- examples/injection_examples/basic_tutorial.py | 9 +++++-- tupak/gw/waveform_generator.py | 24 +++++++++++++++---- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/examples/injection_examples/basic_tutorial.py b/examples/injection_examples/basic_tutorial.py index 39a3d171b..1dccdf758 100644 --- a/examples/injection_examples/basic_tutorial.py +++ b/examples/injection_examples/basic_tutorial.py @@ -29,13 +29,18 @@ np.random.seed(88170235) # spins of both black holes (a, tilt, phi), etc. injection_parameters = dict(mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0, phi_12=1.7, phi_jl=0.3, luminosity_distance=2000., iota=0.4, psi=2.659, phase=1.3, geocent_time=1126259642.413, - waveform_approximant='IMRPhenomPv2', reference_frequency=50., ra=1.375, dec=-1.2108) + ra=1.375, dec=-1.2108) + +# Fixed arguments passed into the source model +fixed_arguments = dict(waveform_approximant='IMRPhenomPv2', + reference_frequency=50.) # Create the waveform_generator using a LAL BinaryBlackHole source function waveform_generator = tupak.WaveformGenerator(time_duration=time_duration, sampling_frequency=sampling_frequency, frequency_domain_source_model=tupak.gw.source.lal_binary_black_hole, - parameters=injection_parameters) + parameters=injection_parameters, + fixed_arguments=fixed_arguments) hf_signal = waveform_generator.frequency_domain_strain() # Set up interferometers. In this case we'll use three interferometers (LIGO-Hanford (H1), LIGO-Livingston (L1), diff --git a/tupak/gw/waveform_generator.py b/tupak/gw/waveform_generator.py index b0aef610c..663ee7a23 100644 --- a/tupak/gw/waveform_generator.py +++ b/tupak/gw/waveform_generator.py @@ -28,6 +28,8 @@ class WaveformGenerator(object): waveform generator non_standard_sampling_parameter_keys: list List of parameter name for *non-standard* sampling parameters. + fixed_arguments: dict + A dictionary of fixed keyword arguments Note: the arguments of frequency_domain_source_model (except the first, which is the frequencies at which to compute the strain) will be added to @@ -37,7 +39,7 @@ class WaveformGenerator(object): def __init__(self, time_duration, sampling_frequency, frequency_domain_source_model=None, time_domain_source_model=None, parameters=None, parameter_conversion=None, - non_standard_sampling_parameter_keys=None): + non_standard_sampling_parameter_keys=None, fixed_arguments={}): self.time_duration = time_duration self.sampling_frequency = sampling_frequency self.frequency_domain_source_model = frequency_domain_source_model @@ -47,8 +49,11 @@ class WaveformGenerator(object): self.parameter_conversion = parameter_conversion self.non_standard_sampling_parameter_keys = non_standard_sampling_parameter_keys self.parameters = parameters + self.fixed_arguments = fixed_arguments self.__frequency_array_updated = False self.__time_array_updated = False + self.__full_source_model_keyword_arguments = {} + self.__full_source_model_keyword_arguments.update(fixed_arguments) def frequency_domain_strain(self): """ Wrapper to source_model """ @@ -56,10 +61,15 @@ class WaveformGenerator(object): added_keys = self.parameter_conversion(self.parameters, self.non_standard_sampling_parameter_keys) if self.frequency_domain_source_model is not None: - model_frequency_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters) + self.__full_source_model_keyword_arguments.update(self.parameters) + model_frequency_strain = self.frequency_domain_source_model( + self.frequency_array, + **self.__full_source_model_keyword_arguments) elif self.time_domain_source_model is not None: model_frequency_strain = dict() - time_domain_strain = self.time_domain_source_model(self.time_array, **self.parameters) + self.__full_source_model_keyword_arguments.update(self.parameters) + time_domain_strain = self.time_domain_source_model( + self.time_array, **self.__full_source_model_keyword_arguments) if isinstance(time_domain_strain, np.ndarray): return utils.nfft(time_domain_strain, self.sampling_frequency) for key in time_domain_strain: @@ -76,10 +86,14 @@ class WaveformGenerator(object): if self.parameter_conversion is not None: added_keys = self.parameter_conversion(self.parameters, self.non_standard_sampling_parameter_keys) if self.time_domain_source_model is not None: - model_time_series = self.time_domain_source_model(self.time_array, **self.parameters) + self.__full_source_model_keyword_arguments.update(self.parameters) + model_time_series = self.time_domain_source_model( + self.time_array, **self.__full_source_model_keyword_arguments) elif self.frequency_domain_source_model is not None: model_time_series = dict() - frequency_domain_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters) + self.__full_source_model_keyword_arguments.update(self.parameters) + frequency_domain_strain = self.frequency_domain_source_model( + self.frequency_array, **self.__full_source_model_keyword_arguments) if isinstance(frequency_domain_strain, np.ndarray): return utils.infft(frequency_domain_strain, self.sampling_frequency) for key in frequency_domain_strain: -- GitLab