diff --git a/tupak/gw/waveform_generator.py b/tupak/gw/waveform_generator.py index d97127d07f7bcd398410c9d8ef8a0b8d4f14da87..7b462ad4c90ff1cf622eb9627630539a8356afdc 100644 --- a/tupak/gw/waveform_generator.py +++ b/tupak/gw/waveform_generator.py @@ -78,29 +78,30 @@ class WaveformGenerator(object): RuntimeError: If no source model is given """ - model_strain = None - added_keys = self._setup_conversion() - preferred_model = self.frequency_domain_source_model - preferred_model_data_points = self.frequency_array - alternative_model = self.time_domain_source_model - alternative_model_data_points = self.time_array + model = self.frequency_domain_source_model + model_data_points = self.frequency_array + transformed_model = self.time_domain_source_model + transformed_model_data_points = self.time_array + transformation_function = utils.nfft + added_keys = self._setup_conversion() - if preferred_model is not None: + model_strain = None + if model is not None: self.__full_source_model_keyword_arguments.update(self.parameters) - model_strain = preferred_model( - preferred_model_data_points, + model_strain = model( + model_data_points, **self.__full_source_model_keyword_arguments) - elif alternative_model is not None: + elif transformed_model is not None: model_strain = dict() self.__full_source_model_keyword_arguments.update(self.parameters) - time_domain_strain = alternative_model( - alternative_model_data_points, **self.__full_source_model_keyword_arguments) - if isinstance(time_domain_strain, np.ndarray): - return utils.nfft(time_domain_strain, self.sampling_frequency) - for key in time_domain_strain: - model_strain[key], self.frequency_array = utils.nfft(time_domain_strain[key], - self.sampling_frequency) + transformed_model_strain = transformed_model( + transformed_model_data_points, **self.__full_source_model_keyword_arguments) + if isinstance(transformed_model_strain, np.ndarray): + return transformation_function(transformed_model_strain, self.sampling_frequency) + for key in transformed_model_strain: + model_strain[key], self.frequency_array = transformation_function(transformed_model_strain[key], + self.sampling_frequency) else: raise RuntimeError("No source model given") @@ -131,27 +132,28 @@ class WaveformGenerator(object): RuntimeError: If no source model is given """ - model_strain = None - added_keys = self._setup_conversion() - preferred_model = self.time_domain_source_model - preferred_model_data_points = self.time_array - alternative_model = self.frequency_domain_source_model - alternative_model_data_points = self.frequency_array + model = self.time_domain_source_model + model_data_points = self.time_array + transformed_model = self.frequency_domain_source_model + transformed_model_data_points = self.frequency_array + transformation_function = utils.infft + added_keys = self._setup_conversion() - if preferred_model is not None: + model_strain = None + if model is not None: self.__full_source_model_keyword_arguments.update(self.parameters) - model_strain = preferred_model( - preferred_model_data_points, **self.__full_source_model_keyword_arguments) - elif alternative_model is not None: + model_strain = model( + model_data_points, **self.__full_source_model_keyword_arguments) + elif transformed_model is not None: model_strain = dict() self.__full_source_model_keyword_arguments.update(self.parameters) - frequency_domain_strain = alternative_model( - alternative_model_data_points, **self.__full_source_model_keyword_arguments) - if isinstance(frequency_domain_strain, np.ndarray): - return utils.infft(frequency_domain_strain, self.sampling_frequency) - for key in frequency_domain_strain: - model_strain[key] = utils.infft(frequency_domain_strain[key], self.sampling_frequency) + transformed_model_strain = transformed_model( + transformed_model_data_points, **self.__full_source_model_keyword_arguments) + if isinstance(transformed_model_strain, np.ndarray): + return transformation_function(transformed_model_strain, self.sampling_frequency) + for key in transformed_model_strain: + model_strain[key] = transformation_function(transformed_model_strain[key], self.sampling_frequency) else: raise RuntimeError("No source model given") @@ -169,8 +171,8 @@ class WaveformGenerator(object): """ if self.__frequency_array_updated is False: self.frequency_array = utils.create_frequency_series( - self.sampling_frequency, - self.duration) + self.sampling_frequency, + self.duration) return self.__frequency_array @frequency_array.setter @@ -189,9 +191,9 @@ class WaveformGenerator(object): if self.__time_array_updated is False: self.__time_array = utils.create_time_series( - self.sampling_frequency, - self.duration, - self.start_time) + self.sampling_frequency, + self.duration, + self.start_time) self.__time_array_updated = True return self.__time_array