From aae26944e51a5cc347777dfbfd44cdb6037ba53f Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 11 Jun 2018 16:24:37 +1000
Subject: [PATCH] Adds fixed_arguments to the wfg

---
 examples/injection_examples/basic_tutorial.py |  9 +++++--
 tupak/gw/waveform_generator.py                | 24 +++++++++++++++----
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/examples/injection_examples/basic_tutorial.py b/examples/injection_examples/basic_tutorial.py
index 39a3d171b..1dccdf758 100644
--- a/examples/injection_examples/basic_tutorial.py
+++ b/examples/injection_examples/basic_tutorial.py
@@ -29,13 +29,18 @@ np.random.seed(88170235)
 # spins of both black holes (a, tilt, phi), etc.
 injection_parameters = dict(mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0, phi_12=1.7, phi_jl=0.3,
                             luminosity_distance=2000., iota=0.4, psi=2.659, phase=1.3, geocent_time=1126259642.413,
-                            waveform_approximant='IMRPhenomPv2', reference_frequency=50., ra=1.375, dec=-1.2108)
+                            ra=1.375, dec=-1.2108)
+
+# Fixed arguments passed into the source model
+fixed_arguments = dict(waveform_approximant='IMRPhenomPv2',
+                       reference_frequency=50.)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 waveform_generator = tupak.WaveformGenerator(time_duration=time_duration,
                                              sampling_frequency=sampling_frequency,
                                              frequency_domain_source_model=tupak.gw.source.lal_binary_black_hole,
-                                             parameters=injection_parameters)
+                                             parameters=injection_parameters,
+                                             fixed_arguments=fixed_arguments)
 hf_signal = waveform_generator.frequency_domain_strain()
 
 # Set up interferometers.  In this case we'll use three interferometers (LIGO-Hanford (H1), LIGO-Livingston (L1),
diff --git a/tupak/gw/waveform_generator.py b/tupak/gw/waveform_generator.py
index b0aef610c..663ee7a23 100644
--- a/tupak/gw/waveform_generator.py
+++ b/tupak/gw/waveform_generator.py
@@ -28,6 +28,8 @@ class WaveformGenerator(object):
         waveform generator
     non_standard_sampling_parameter_keys: list
         List of parameter name for *non-standard* sampling parameters.
+    fixed_arguments: dict
+        A dictionary of fixed keyword arguments
 
     Note: the arguments of frequency_domain_source_model (except the first,
     which is the frequencies at which to compute the strain) will be added to
@@ -37,7 +39,7 @@ class WaveformGenerator(object):
 
     def __init__(self, time_duration, sampling_frequency, frequency_domain_source_model=None,
                  time_domain_source_model=None, parameters=None, parameter_conversion=None,
-                 non_standard_sampling_parameter_keys=None):
+                 non_standard_sampling_parameter_keys=None, fixed_arguments={}):
         self.time_duration = time_duration
         self.sampling_frequency = sampling_frequency
         self.frequency_domain_source_model = frequency_domain_source_model
@@ -47,8 +49,11 @@ class WaveformGenerator(object):
         self.parameter_conversion = parameter_conversion
         self.non_standard_sampling_parameter_keys = non_standard_sampling_parameter_keys
         self.parameters = parameters
+        self.fixed_arguments = fixed_arguments
         self.__frequency_array_updated = False
         self.__time_array_updated = False
+        self.__full_source_model_keyword_arguments = {}
+        self.__full_source_model_keyword_arguments.update(fixed_arguments)
 
     def frequency_domain_strain(self):
         """ Wrapper to source_model """
@@ -56,10 +61,15 @@ class WaveformGenerator(object):
             added_keys = self.parameter_conversion(self.parameters, self.non_standard_sampling_parameter_keys)
 
         if self.frequency_domain_source_model is not None:
-            model_frequency_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
+            self.__full_source_model_keyword_arguments.update(self.parameters)
+            model_frequency_strain = self.frequency_domain_source_model(
+                self.frequency_array,
+                **self.__full_source_model_keyword_arguments)
         elif self.time_domain_source_model is not None:
             model_frequency_strain = dict()
-            time_domain_strain = self.time_domain_source_model(self.time_array, **self.parameters)
+            self.__full_source_model_keyword_arguments.update(self.parameters)
+            time_domain_strain = self.time_domain_source_model(
+                self.time_array, **self.__full_source_model_keyword_arguments)
             if isinstance(time_domain_strain, np.ndarray):
                 return utils.nfft(time_domain_strain, self.sampling_frequency)
             for key in time_domain_strain:
@@ -76,10 +86,14 @@ class WaveformGenerator(object):
         if self.parameter_conversion is not None:
             added_keys = self.parameter_conversion(self.parameters, self.non_standard_sampling_parameter_keys)
         if self.time_domain_source_model is not None:
-            model_time_series = self.time_domain_source_model(self.time_array, **self.parameters)
+            self.__full_source_model_keyword_arguments.update(self.parameters)
+            model_time_series = self.time_domain_source_model(
+                self.time_array, **self.__full_source_model_keyword_arguments)
         elif self.frequency_domain_source_model is not None:
             model_time_series = dict()
-            frequency_domain_strain = self.frequency_domain_source_model(self.frequency_array, **self.parameters)
+            self.__full_source_model_keyword_arguments.update(self.parameters)
+            frequency_domain_strain = self.frequency_domain_source_model(
+                self.frequency_array, **self.__full_source_model_keyword_arguments)
             if isinstance(frequency_domain_strain, np.ndarray):
                 return utils.infft(frequency_domain_strain, self.sampling_frequency)
             for key in frequency_domain_strain:
-- 
GitLab