Commit 92f71667 authored by Daniel Williams's avatar Daniel Williams 🤖
Browse files

Updated the torch model to have (slightly hacked) cross polarisation support.

parent 1c6b6302
Pipeline #130900 passed with stages
in 6 minutes and 6 seconds
......@@ -23,7 +23,6 @@ from .gw import BBHSurrogate, HofTSurrogate
DATA_PATH = pkg_resources.resource_filename('heron', 'models/data/')
class HeronCUDA(Model, BBHSurrogate, HofTSurrogate):
"""
A GPR BBH waveform model which is capable of using CUDA resources.
......@@ -38,9 +37,10 @@ class HeronCUDA(Model, BBHSurrogate, HofTSurrogate):
"""
# super(HeronCUDA, self).__init__()
self.model, self.likelihood = self.build()
assert torch.cuda.is_available() # This is a bit of a kludge
(self.model_plus, self.model_cross), self.likelihood = self.build()
self.x_dimensions = 8
self.time_factor = 1
self.time_factor = 100
self.strain_input_factor = 1e21
#
self.eval()
......@@ -49,7 +49,8 @@ class HeronCUDA(Model, BBHSurrogate, HofTSurrogate):
"""
Prepare the model to be evaluated.
"""
self.model.eval()
self.model_plus.eval()
self.model_cross.eval()
self.likelihood.eval()
def _process_inputs(self, times, p):
......@@ -99,19 +100,23 @@ class HeronCUDA(Model, BBHSurrogate, HofTSurrogate):
training_x = torch.tensor(data[:, 0:-2]*100).float().cuda()
training_y = torch.tensor(data[:, -2]*1e21).float().cuda()
training_yx = torch.tensor(data[:, -1]*1e21).float().cuda()
likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint=LessThan(10))
model = ExactGPModel(training_x, training_y, likelihood)
model2 = ExactGPModel(training_x, training_yx, likelihood)
state_vector = pkg_resources.resource_filename('heron', 'models/data/gt-gpytorch.pth')
model = model.cuda()
model2 = model2.cuda()
likelihood = likelihood.cuda()
model.load_state_dict(torch.load(state_vector))
model2.load_state_dict(torch.load(state_vector))
return model, likelihood
return [model, model2], likelihood
def _predict(self, times, p):
def _predict(self, times, p, polarisation="plus"):
"""
Query the model for the mean and covariance tensors.
Optionally include the covariance.
......@@ -135,11 +140,16 @@ class HeronCUDA(Model, BBHSurrogate, HofTSurrogate):
Only returned if covariance was True.
"""
if polarisation == "plus":
model = self.model_plus
elif polarisation == "cross":
model = self.model_cross
times_b = times.copy()
points = self._generate_eval_matrix(p, times_b)
points = torch.tensor(points).float().cuda()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
f_preds = self.model(points)
f_preds = model(points)
mean = f_preds.mean/self.strain_input_factor
var = f_preds.variance.detach().double()/(self.strain_input_factor**2)
......@@ -176,13 +186,20 @@ class HeronCUDA(Model, BBHSurrogate, HofTSurrogate):
"""
covariance_flag = covariance
mean, var, covariance = self._predict(times, p)
timeseries = []
covariances = []
for polarisation in ["plus", "cross"]:
mean, var, covariance = self._predict(times, p, polarisation=polarisation)
timeseries.append(
Timeseries(data=mean.cpu().numpy().astype(np.float64),
times=times,
variance=var.cpu().numpy().astype(np.float64)))
covariances.append(covariance)
if covariance_flag:
return Timeseries(data=mean.cpu(),
times=times, variance=var.cpu()), covariance
return timeseries, covariances
else:
return Timeseries(data=mean.cpu(), times=times, variance=var.cpu())
return timeseries
def distribution(self, times, p, samples=100):
"""
......@@ -191,13 +208,15 @@ class HeronCUDA(Model, BBHSurrogate, HofTSurrogate):
times_b = times.copy()
points = self._generate_eval_matrix(p, times_b)
points = torch.tensor(points).float().cuda()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
f_preds = self.model(points)
y_preds = self.likelihood(f_preds)
return_samples = [Timeseries(data=sample.cpu()/self.strain_input_factor,
times=times_b)
for sample in y_preds.sample_n(samples)]
return_samples = []
for polarisation in [self.model_plus, self.model_cross]:
with torch.no_grad(), gpytorch.settings.fast_pred_var():
f_preds = polarisation(points)
y_preds = self.likelihood(f_preds)
return_samples.append([Timeseries(data=sample.cpu()/self.strain_input_factor,
times=times_b)
for sample in y_preds.sample_n(samples)])
return return_samples
def frequency_domain_waveform(self, p, times=np.linspace(-2, 2, 1000)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment