Commit d3603f17 authored by Daniel Williams's avatar Daniel Williams 🤖
Browse files

Changes to complex number handling.

parent d3d425e0
Pipeline #200128 failed with stages
in 46 seconds
......@@ -229,8 +229,12 @@ class CUDALikelihood(Likelihood):
waveform = self.model.frequency_domain_waveform(p, window=self.window, times=self.times)
for pol, wf in waveform.items():
waveform[pol].data *= self.model.strain_input_factor
waveform[pol].variance *= self.model.strain_input_factor**2
# I've just changed these to division; need to check.
waveform[pol].data /= self.model.strain_input_factor
waveform[pol].variance /= self.model.strain_input_factor**2
waveform[pol].covariance /= self.model.strain_input_factor**2
return waveform
def snr(self, signal):
......@@ -261,7 +265,7 @@ class CUDALikelihood(Likelihood):
else:
waveform_mean = polarisations['plus'].data
waveform_variance = polarisations['plus'].variance
waveform_variance = polarisations['plus'].covariance
if model_var:
inner_product = InnerProduct(self.psd.clone(),
......
......@@ -326,8 +326,7 @@ class HeronCUDA(CUDAModel, BBHSurrogate, HofTSurrogate):
# :todo:`This should always be real.`
# It is the expectation of x_i·x_i^*
# This might not be true if need to use off-diagonal components
uncert_f = torch.diag(cov_f)#torch.view_as_complex(torch.stack([torch.diag(cov_f[:, :, 0]),
# torch.diag(cov_f[:, :, 1])]).T)
uncert_f = torch.diag(cov_f.real)
if np.any(times):
srate = 1/np.diff(times).mean()
......@@ -336,6 +335,7 @@ class HeronCUDA(CUDAModel, BBHSurrogate, HofTSurrogate):
return FrequencySeries(data=strain_f*self.model.strain_input_factor,
variance=uncert_f*self.model.strain_input_factor**2,
covariance=cov_f*self.model.strain_input_factor**2,
frequencies=frequencies)
def time_domain_waveform(self, p, times=np.linspace(-2, 2, 1000)):
......@@ -590,6 +590,7 @@ class HeronCUDAIMR(CUDAModel, BBHNonSpinSurrogate, HofTSurrogate):
data[polarisation] = FrequencySeries(data=strain_f,
frequencies=frequencies,
covariance=cov_f,
variance=uncert_f)
return data
......
......@@ -10,6 +10,6 @@ george
nestle
elk-waveform
gpytorch==1.0.1
torch==1.4.0
torch==1.7.1
torchvision==0.5.0
lalsuite
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment