Skip to content
Snippets Groups Projects
Commit 0c2f014c authored by Samson Leong's avatar Samson Leong :stuck_out_tongue:
Browse files

Fix psi4 generator.

parent d117922b
No related branches found
Tags 1.1.4
No related merge requests found
......@@ -56,8 +56,8 @@ class Psi4_WaveformGenerator(WaveformGenerator):
waveform_arguments=None)
if waveform_arguments is not None:
self.waveform_arguments = waveform_arguments
self.is_psi4_model = waveform_arguments.pop('is_psi4_model', False)
self.waveform_arguments = waveform_arguments
else:
self.waveform_arguments = dict()
self.is_psi4_model = False
......@@ -70,10 +70,16 @@ class Psi4_WaveformGenerator(WaveformGenerator):
## This correction_factor depends whether the data is strain or psi4,
## the first is for psi4, (A5) in the article; and
## the second is for strain, see (A4).
self.correction_factor = \
(1 - np.cos(phase)) / (0.5 * phase * phase) \
if self.is_psi4_model else \
2 * self.sampling_frequency * self.sampling_frequency * (np.cos(phase) - 1)
if self.is_psi4_model:
mask = phase != 0
self.correction_factor = np.zeros(len(phase))
non_zero_phase = phase[mask]
self.correction_factor[mask] = \
(1 - np.cos(non_zero_phase)) / (0.5 * non_zero_phase * non_zero_phase)
else:
self.correction_factor = \
2 * self.sampling_frequency * self.sampling_frequency * (np.cos(phase) - 1)
self._cache = dict(parameters=None, FD_waveform=None, TD_waveform=None)
......@@ -124,8 +130,6 @@ class Psi4_WaveformGenerator(WaveformGenerator):
FD_psi4 = self._FD_psi4_from_FD_waveform(model_waveform) \
if is_FD_model else \
self._FD_psi4_from_TD_waveform(model_waveform)
elif model_waveform is None:
FD_psi4 = None
else:
FD_psi4 = dict()
for key in model_waveform:
......@@ -150,8 +154,6 @@ class Psi4_WaveformGenerator(WaveformGenerator):
if isinstance(FD_psi4, np.ndarray):
TD_psi4 = utils.infft(FD_psi4, self.sampling_frequency)
elif model_waveform is None:
TD_psi4 = None
else:
TD_psi4 = dict()
for key in FD_psi4:
......@@ -169,5 +171,4 @@ class Psi4_WaveformGenerator(WaveformGenerator):
def _FD_psi4_from_TD_waveform(self, td_waveform):
fd_waveform, _ = utils.nfft(td_waveform, self.sampling_frequency)
fd_psi4 = self.correction_factor * fd_waveform
return utils.infft(fd_psi4, self.sampling_frequency)
return self.correction_factor * fd_waveform
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment