Skip to content
Snippets Groups Projects

Update quantum code from matgwinc, and vectorize

Merged Christopher Wipf requested to merge fast-quantum-matmul into master
1 file
+ 12
15
Compare changes
  • Side-by-side
  • Inline
+ 12
15
@@ -421,26 +421,23 @@ def getProdTF(lhs, rhs):
raise Exception('Matrix size mismatch size(lhs, 2) = %d != %d = size(rhs, 1)' % (lhs.shape[1], rhs.shape[0]))
N = lhs.shape[0]
M = rhs.shape[1]
if len(lhs.shape) == 3:
lhs = np.transpose(lhs, axes=(2, 0, 1))
if len(rhs.shape) == 3:
rhs = np.transpose(rhs, axes=(2, 0, 1))
# compute product
if len(lhs.shape) < 3 or lhs.shape[2] == 1:
Nfreq = rhs.shape[2]
rslt = zeros((N, M, Nfreq), dtype=complex)
for n in range(Nfreq):
rslt[:, :, n] = np.dot(np.squeeze(lhs), rhs[:, :, n])
elif len(rhs.shape) < 3 or rhs.shape[2] == 1:
Nfreq = lhs.shape[2]
rslt = zeros((N, M, Nfreq), dtype=complex)
for n in range(Nfreq):
rslt[:, :, n] = np.dot(lhs[:, :, n], np.squeeze(rhs))
elif lhs.shape[2] == rhs.shape[2]:
Nfreq = lhs.shape[2]
rslt = zeros((N, M, Nfreq), dtype=complex)
for n in range(Nfreq):
rslt[:, :, n] = np.dot(lhs[:, :, n], rhs[:, :, n])
if len(lhs.shape) < 3 or lhs.shape[0] == 1:
rslt = np.matmul(lhs, rhs)
elif len(rhs.shape) < 3 or rhs.shape[0] == 1:
rslt = np.matmul(lhs, rhs)
elif lhs.shape[0] == rhs.shape[0]:
rslt = np.matmul(lhs, rhs)
else:
raise Exception('Matrix size mismatch lhs.shape[2] = %d != %d = rhs.shape[2]' % (lhs.shape[2], rhs.shape[2]))
if len(rslt.shape) == 3:
rslt = np.transpose(rslt, axes=(1, 2, 0))
return rslt
Loading