Commit d3d425e0 authored by Daniel Williams's avatar Daniel Williams 🤖
Browse files

Updated the likelihood calculation

Started work to allow the full covariance to be used in the inner product.
parent 7af633e2
Pipeline #194218 failed with stages
in 26 seconds
"""
Code for matched filtering using CUDA and pytorch.
"""
import warnings
from copy import copy

import torch

import elk
import elk.waveform
from lal import antenna

from .utils import Complex
from .utils import diag_cuda, Complex
......@@ -51,37 +51,40 @@ class InnerProduct():
>>> ip(a, b)
"""
def __init__(self, psd, duration, signal_psd=None, signal_cov=None, f_min=None, f_max=None):
    """
    Build the inner-product metric from the noise PSD and, optionally,
    a signal PSD or a full signal covariance matrix.

    Parameters
    ----------
    psd : torch.Tensor
        The noise power spectral density.
    duration : float
        The duration of the data segment.
    signal_psd : torch.Tensor, optional
        A PSD describing the waveform-model uncertainty; folded into the
        (diagonal) metric alongside the noise PSD.
    signal_cov : torch.Tensor, optional
        A full signal covariance matrix; takes precedence over signal_psd.
    f_min, f_max : float, optional
        Not yet implemented; a RuntimeWarning is raised if supplied.
    """
    self.noise = psd
    self.noise2 = signal_psd
    self.signal_cov = signal_cov
    self.duration = duration
    if self.signal_cov is not None:
        # NOTE(review): this is an elementwise reciprocal of the covariance,
        # not a matrix inverse — confirm that is the intended weighting.
        self.metric = 1. / self.signal_cov
        self.metric += diag_cuda(1. / (self.duration * self.noise))
    elif self.noise2 is not None:
        # Diagonal metric combining noise and signal PSDs.
        self.metric = diag_cuda(1. / (self.duration * self.noise + self.duration * self.noise2.abs()))
    else:
        # Noise-only diagonal metric.
        self.metric = diag_cuda(1. / (self.duration * self.noise))
    if f_min or f_max:
        warnings.warn("f_min and f_max are not yet implemented. The full frequency series will be used.",
                      RuntimeWarning)
    self.f_min = f_min
    self.f_max = f_max
def __call__(self, a, b):
    """
    Return the noise-weighted inner product of frequency-domain series a and b.

    The segment duration is already folded into ``self.metric`` at
    construction time, so it is no longer passed through to ``inner``.
    """
    return self.inner(a, b)
def inner(self, a, b):
    """
    Calculate the noise-weighted inner product of a and b.

    Both inputs are complex frequency-domain series; the product is
    contracted through ``self.metric``, so any signal PSD or full signal
    covariance supplied at construction time is included automatically.

    Parameters
    ----------
    a, b : torch.Tensor
        Complex frequency-domain series of equal length.

    Returns
    -------
    torch.Tensor
        The real-valued inner product 4 * Re( a† M b ).
    """
    # The DC and Nyquist bins are excluded from the contraction.
    # :todo:`Remove the minimum / maximum frequency assumption from here`
    c = a[1:-1].conj() @ self.metric[1:-1, 1:-1] @ b[1:-1]
    return 4.0 * c.real
class Likelihood():
......@@ -185,57 +188,53 @@ class CUDALikelihood(Likelihood):
self.model = model
self.window = window
self.device = device
self.f_min = f_min
self.f_max = f_max
self.gen_args = generator_args
if isinstance(data, elk.waveform.Timeseries):
self.data = Complex((self.window*torch.tensor(data.data*1e19, device=device)).rfft(1)[1:])
self.data = self.window*torch.view_as_complex(torch.tensor(data.data, device=device).rfft(1)[1:])
self.times = data.times
elif isinstance(data, elk.waveform.FrequencySeries):
self.data = data.data.clone()
self.data.tensor = self.data.tensor.clone() * 1e19
#self.data.tensor = self.data.tensor.clone()
freqs = data.frequencies
self.times = torch.linspace(0, 1/freqs[1], 2*(len(freqs)-1)) + start
nt = 2*(len(freqs)-1)
self.times = torch.linspace(0, nt/freqs[-1], nt) + start
self.frequencies = data.frequencies
self.start = start
self.data *= self.model.strain_input_factor
self.duration = self.times[-1] - self.times[0]
if not isinstance(psd, type(None)):
self.psd = psd * 1e38
self.psd = psd * self.model.strain_input_factor**2
else:
if not isinstance(asd, type(None)):
self.asd = asd
self.asd.tensor = self.asd.tensor.clone() * 1e19
#self.asd.tensor = self.asd.tensor.clone()
else:
self.asd = torch.ones(len(self.data), 2)
self.psd = self.asd * self.asd
self.psd = (self.asd * self.asd) * self.model.strain_input_factor**2
def _call_model(self, p):
    """
    Evaluate the waveform model at parameters ``p``, rescaled to
    physical strain units.

    Parameters
    ----------
    p : dict
        Waveform parameters; merged over ``self.gen_args`` defaults.

    Returns
    -------
    dict-like
        Mapping of polarisation name to frequency-domain waveform, with
        data and variance scaled by ``self.model.strain_input_factor``.
    """
    # Merge the call-specific parameters over the generator defaults
    # without mutating the caller's dictionary.
    args = copy(self.gen_args)
    args.update(p)
    p = args
    if self._cache_location == p:
        waveform = self._cache
    else:
        waveform = self.model.frequency_domain_waveform(p, window=self.window, times=self.times)
        # Rescale each polarisation out of the model's internal strain units.
        for pol in waveform:
            waveform[pol].data *= self.model.strain_input_factor
            waveform[pol].variance *= self.model.strain_input_factor**2
        # :todo:`Fix the caching...` — NOTE(review): self._cache and
        # self._cache_location are never updated here, so the cache
        # branch above can only hit a value stored elsewhere.
    return waveform
def snr(self, signal):
    # NOTE(review): merge residue — the body below is the superseded
    # implementation followed by the commit's replacement (`pass`); the
    # early `return` makes the trailing `pass` unreachable, so the old
    # behavior still runs.  Confirm whether this method was meant to be
    # gutted to `pass` or kept.
    if isinstance(signal, elk.waveform.Timeseries):
        # `device` is not defined in this scope and `1e19` is a magic
        # scaling constant — this branch would raise NameError if hit.
        signal_f = Complex((self.window*signal*1e19).rfft(1).double())
    else:
        signal_f = signal
    bracket = signal_f.data/self.psd
    bracket = bracket * bracket
    return torch.sqrt(torch.abs(torch.sum(
        (bracket.real * torch.log(torch.tensor(signal.frequencies, device=device)).double())[1:-1]
    )))
    pass
def _products(self, p, model_var):
"""
......@@ -243,10 +242,12 @@ class CUDALikelihood(Likelihood):
Notes
-----
The way that the waveform variance is currently treated assumes that the two polarisations are statistically independent.
The way that the waveform variance is currently treated
assumes that the two polarisations are statistically
independent.
"""
polarisations = self._call_model(p)
polarisations = self._call_model(p)
if "ra" in p.keys():
response = self._antenna_reponse(detector=p['detector'],
......@@ -256,20 +257,30 @@ class CUDALikelihood(Likelihood):
time=p['gpstime'])
waveform_mean = polarisations['plus'].data * response['plus'] + polarisation['cross'].data * response['cross']
waveform_variance = polarisations['plus'].variance * response['plus']**2 + polarisations['cross'].variance * response['cross']**2
waveform_variance = polarisations['plus'].variance * response['plus']**2 + polarisations['cross'].variance * response['cross']**2
else:
waveform_mean = polarisations['plus'].data
waveform_variance = polarisations['plus'].variance
if model_var:
inner_product = InnerProduct(self.psd.clone(), waveform_variance, duration=self.duration, f_min=self.f_min, f_max=self.f_max)
inner_product = InnerProduct(self.psd.clone(),
signal_psd=waveform_variance,
duration=self.duration,
f_min=self.f_min,
f_max=self.f_max)
factor = torch.logdet(inner_product.metric.abs()[1:-1, 1:-1])
else:
inner_product = InnerProduct(self.psd.clone(), duration=self.duration, f_min=self.f_min, f_max=self.f_max)
products = -0.5 * (inner_product(self.data, self.data))
inner_product = InnerProduct(self.psd.clone(),
duration=self.duration,
f_min=self.f_min, f_max=self.f_max)
factor = torch.sum(torch.log(1./(self.duration*self.psd.abs())[1:-1]))
products = 0
products = -0.5 * (inner_product(self.data, self.data))
products += -0.5 * (inner_product(waveform_mean, waveform_mean))
products += inner_product(self.data.clone(), waveform_mean)
products *= factor
return products
......@@ -278,12 +289,21 @@ class CUDALikelihood(Likelihood):
Calculate the normalisation.
"""
waveform = self._call_model(p)
psd = self.psd.real[1:-1]
psd = self.psd.real[1:-1] / self.model.strain_input_factor**2
if "ra" not in p.keys():
waveform = waveform['plus']
if model_var:
variance = waveform.variance.modulus[1:-1]
normalisation = torch.sum(torch.log(psd)) - torch.log(torch.prod(psd / psd.max()) + torch.prod(variance / variance.max())) + torch.log(psd.max())*len(psd) + torch.log(variance.max())*len(variance)
variance = waveform.variance.abs()[1:-1]
normalisation = (torch.sum(torch.log(psd))
- torch.log(torch.prod(psd / psd.max()))
+ torch.log(psd.max())*len(psd))
normalisation -= torch.sum(torch.log(variance))
#normalisation += torch.logdet(variance)
else:
normalisation = torch.sum(torch.log(psd)) - torch.log(torch.prod(psd / psd.max())) + torch.log(psd.max())*len(psd)
normalisation = (torch.sum(torch.log(psd))
- torch.log(torch.prod(psd / psd.max()))
+ torch.log(psd.max())*len(psd))
return normalisation
......@@ -292,7 +312,7 @@ class CUDALikelihood(Likelihood):
"""
Calculate the overall log-likelihood.
"""
return self._products(p, model_var) + self._normalisation(p, model_var)
return self._products(p, model_var) #- self._normalisation(p, model_var)
def __call__(self, p, model_var=True):
"""Calculate the log likelihood for a given set of model parameters.
......
......@@ -22,11 +22,10 @@ from elk.waveform import Timeseries, FrequencySeries
from elk.catalogue import PPCatalogue
from . import Model
from ..utils import Complex
from ..utils import diag_cuda
from .gw import BBHSurrogate, HofTSurrogate, BBHNonSpinSurrogate
from heron.models import Model
from heron.models.torchbased import CUDAModel
import matplotlib.pyplot as plt
import heron
......@@ -135,9 +134,12 @@ class CUDAModel(Model):
points = self._generate_eval_matrix(p, times_b)
points = torch.tensor(points, device=self.device).float()#.cuda()
with torch.no_grad(), gpytorch.settings.fast_pred_var(num_probe_vectors=10), gpytorch.settings.max_root_decomposition_size(5000):
f_preds = model(points)
if polarisation == "plus":
f_preds = self.model_plus(points)
elif polarisation == "cross":
f_preds = self.model_cross(points)
mean = f_preds.mean/self.strain_input_factor
mean = f_preds.mean.double()/self.strain_input_factor
var = f_preds.variance.detach().double()/(self.strain_input_factor**2)
covariance = f_preds.covariance_matrix.detach().double()
covariance /= (self.strain_input_factor**2)
......@@ -315,23 +317,26 @@ class HeronCUDA(CUDAModel, BBHSurrogate, HofTSurrogate):
"""
mean, _, cov = self._predict(times, p)
strain_f = Complex((window*mean).rfft(1))
strain_f = torch.view_as_complex((window.double()*mean.double()).rfft(1))
# :todo:`Check windowing`
# :todo:`Check the effect of the windowing on the covariance`
# :todo:`Check the evaluation of the cv matrix with windowing on the mean`
cov_f = cov.rfft(2)
cov_f = torch.view_as_complex(cov.rfft(2))
# :todo:`This should always be real.`
# It is the expectation of x_i·x_i^*
# This might not be true if need to use off-diagonal components
uncert_f = Complex(torch.stack([torch.diag(cov_f[:, :, 0]), torch.diag(cov_f[:, :, 1])]).T)
uncert_f = torch.diag(cov_f)#torch.view_as_complex(torch.stack([torch.diag(cov_f[:, :, 0]),
# torch.diag(cov_f[:, :, 1])]).T)
#frequencies = np.linspace(0, (1/times[-1])/2, len(strain_f))
df = 1/(times[-1]-times[0])
# :todo:`83?`
frequencies = torch.arange(0, 83)*df
if np.any(times):
srate = 1/np.diff(times).mean()
nf = int(np.floor(len(times/2)))+1
frequencies = np.linspace(0, srate, nf)
return FrequencySeries(data=strain_f, frequencies=frequencies, variance=uncert_f)
return FrequencySeries(data=strain_f*self.model.strain_input_factor,
variance=uncert_f*self.model.strain_input_factor**2,
frequencies=frequencies)
def time_domain_waveform(self, p, times=np.linspace(-2, 2, 1000)):
"""
......@@ -495,9 +500,10 @@ class HeronCUDAIMR(CUDAModel, BBHNonSpinSurrogate, HofTSurrogate):
f_preds = model(points)
mean = f_preds.mean/(distance * self.strain_input_factor)
var = f_preds.variance.detach().double()/(self.strain_input_factor**2)
var = f_preds.variance.detach().double()/((distance*self.strain_input_factor**2))
covariance = f_preds.covariance_matrix.detach().double()
covariance /= (((distance*self.strain_input_factor)**2))
covariance /= (((self.strain_input_factor)**2))
covariance /= distance**2
return mean, var, covariance
......@@ -570,14 +576,17 @@ class HeronCUDAIMR(CUDAModel, BBHNonSpinSurrogate, HofTSurrogate):
for polarisation in ["plus", "cross"]:
mean, _, cov = self._predict(times, p, polarisation=polarisation)
strain_f = Complex((window*mean).rfft(1))
cov_f = cov.rfft(2)
strain_f = torch.view_as_complex((window*mean.double()).rfft(1))
cov_f = torch.view_as_complex(cov.rfft(2))
dt = times[-1] - times[-2]
uncert_f = Complex(torch.stack([torch.diag(cov_f[:, :, 0]), torch.diag(cov_f[:, :, 1])]).T)
df = 1/(times[-1]-times[0])
frequencies = torch.arange(0, len(strain_f.real)*df)
uncert_f = diag_cuda(cov_f)
#Complex(torch.stack([torch.diag(cov_f[:, :, 0]), torch.diag(cov_f[:, :, 1])]).T)
if not isinstance(times, type(None)):
srate = 1/np.diff(times).mean()
nf = int(np.floor(len(times)/2))+1
frequencies = np.linspace(0, srate, nf)
data[polarisation] = FrequencySeries(data=strain_f,
frequencies=frequencies,
......
......@@ -5,6 +5,16 @@ Various utilities and non-GPR based stuff.
import torch
import numpy as np
def diag_cuda(a):
    """
    Diagonal helper for complex tensors.

    Mirrors ``torch.diag`` for complex inputs: a complex vector is
    turned into the complex diagonal matrix with that vector on its
    diagonal, and a complex square matrix yields its diagonal as a
    complex vector.

    Parameters
    ----------
    a : torch.Tensor
        A complex vector (1D) or complex square matrix (2D).

    Returns
    -------
    torch.Tensor
        A complex diagonal matrix (from a vector) or a complex vector
        (from a matrix).

    Raises
    ------
    ValueError
        If ``a`` is neither a complex vector nor a complex matrix.
    """
    # Work on the (real, imag) view so torch.diag can be applied to each
    # component separately.
    parts = torch.view_as_real(a)
    if parts.dim() == 2:
        # (n, 2): build diagonal matrices from the real and imaginary parts.
        stacked = torch.stack([torch.diag(parts[:, 0]), torch.diag(parts[:, 1])], dim=-1)
    elif parts.dim() == 3:
        # (n, n, 2): extract the diagonals of the real and imaginary parts.
        stacked = torch.stack([torch.diag(parts[:, :, 0]), torch.diag(parts[:, :, 1])], dim=-1)
    else:
        # Previously `b` was left unbound here, raising UnboundLocalError.
        raise ValueError("diag_cuda expects a complex vector or square matrix, "
                         f"got a tensor with {a.dim()} dimension(s).")
    return torch.view_as_complex(stacked)
class Complex():
"""
Complex numbers in torch and CUDA.
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment