Commit 51a3d0bb authored by Daniel Williams's avatar Daniel Williams 🤖
Browse files

Updated the torch-based base class.

parent bc7fef4b
Pipeline #193833 passed with stages
in 6 minutes and 28 seconds
......@@ -21,10 +21,17 @@ from matplotlib import pyplot as plt
from elk.waveform import Timeseries
from . import Model
from .gw import BBHSurrogate, HofTSurrogate
from .gw import BBHSurrogate, HofTSurrogate, BBHNonSpinSurrogate
DATA_PATH = pkg_resources.resource_filename('heron', 'models/data/')
disable_cuda = False
if not disable_cuda and torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')
def train(model, iterations=1000):
"""
Train the model.
......@@ -60,8 +67,125 @@ def train(model, iterations=1000):
torch.save(model.model_plus.state_dict(), state_vector)
model.eval()
class CUDAModel(Model):
"""
The factory class for all CUDA-based models.
"""
def __init__(self, device=device):
self.device = device
def eval(self):
"""
Prepare the model to be evaluated.
"""
if hasattr(self, "model_plus"):
self.model_plus.eval()
if hasattr(self, "model_cross"):
self.model_cross.eval()
self.likelihood.eval()
def _process_inputs(self, times, p):
times *= self.time_factor
# p['mass ratio'] *= 100 #= np.log(p['mass ratio']) * 100
p = {k: self.time_factor*v for k, v in p.items()}
return times, p
def _predict(self, times, p, polarisation="plus"):
"""
Query the model for the mean and covariance tensors.
Optionally include the covariance.
Parameters
-------------
times : ndarray
The times at which the model should be evaluated.
p : dict
A dictionary of locations in parameter space where the model
should be evaluated.
Returns
-------
mean : torch.tensor
The mean waveform
var : torch.tensor
The variance of the waveform
cov : torch.tensor, optional
The covariance matrix of the waveform.
Only returned if covariance was True.
"""
if polarisation == "plus":
model = self.model_plus
elif polarisation == "cross":
model = self.model_cross
if not isinstance(times, torch.Tensor):
times = torch.tensor(times)
times_b = times.clone()
points = self._generate_eval_matrix(p, times_b)
points = torch.tensor(points, device=self.device).float()#.cuda()
with torch.no_grad(), gpytorch.settings.fast_pred_var(num_probe_vectors=10), gpytorch.settings.max_root_decomposition_size(5000):
f_preds = model(points)
mean = f_preds.mean/self.strain_input_factor
var = f_preds.variance.detach().double()/(self.strain_input_factor**2)
covariance = f_preds.covariance_matrix.detach().double()
covariance /= (self.strain_input_factor**2)
return mean, var, covariance
def mean(self, times, p, covariance=False):
"""
Provide the mean waveform and its variance.
Optionally include the covariance.
Parameters
----------
times : ndarray
The times at which the model should be evaluated.
p : dict
A dictionary of locations in parameter space where the
model should be evaluated.
covariance : bool, optional
A flag to determine if the whole covariance matrix
should be returned.
Returns
-------
mean : torch.tensor
The mean waveform
var : torch.tensor
The variance of the waveform
cov : torch.tensor, optional
The covariance matrix of the waveform.
Only returned if covariance was True.
"""
covariance_flag = covariance
class HeronCUDA(Model, BBHSurrogate, HofTSurrogate):
timeseries = []
covariances = []
if not isinstance(times, torch.Tensor):
times = torch.tensor(times)
if hasattr(self, "model_cross"):
polarisations = ["plus", "cross"]
else:
polarisations = ["plus"]
for polarisation in polarisations:
mean, var, covariance = self._predict(times, p, polarisation=polarisation)
timeseries.append(
Timeseries(data=mean.cpu().numpy().astype(np.float64),
times=times,
variance=var.cpu().numpy().astype(np.float64)))
covariances.append(covariance)
if covariance_flag:
return timeseries, covariances
else:
return timeseries
class HeronCUDA(CUDAModel, BBHSurrogate, HofTSurrogate):
"""
A GPR BBH waveform model which is capable of using CUDA resources.
"""
......@@ -315,14 +439,16 @@ class HeronCUDAMix(CUDAModel, BBHNonSpinSurrogate, HofTSurrogate):
specification: dict
The Heron model specification
"""
super().__init__()
self.specification = specification
self.data = np.genfromtxt(training_data)
super().__init__()
self.x_dimensions = 2
self.time_factor = 100
self.strain_input_factor = 1e21
self.total_mass = specification['total mass']
self.data = np.genfromtxt(training_data)
self.state_vector = f"{self.specification['name']}.pth"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment