From c7f60418f1d3b08778d30fe94ef24aea81e76416 Mon Sep 17 00:00:00 2001 From: Christopher Wipf <wipf@ligo.mit.edu> Date: Sat, 18 Aug 2018 14:18:59 -0700 Subject: [PATCH] Speed up quantum calculation with monolithic matrix multiply --- gwinc/noise/quantum.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/gwinc/noise/quantum.py b/gwinc/noise/quantum.py index 0af7defb..466cc8d8 100644 --- a/gwinc/noise/quantum.py +++ b/gwinc/noise/quantum.py @@ -421,26 +421,23 @@ def getProdTF(lhs, rhs): raise Exception('Matrix size mismatch size(lhs, 2) = %d != %d = size(rhs, 1)' % (lhs.shape[1], rhs.shape[0])) N = lhs.shape[0] M = rhs.shape[1] + if len(lhs.shape) == 3: + lhs = np.transpose(lhs, axes=(2, 0, 1)) + if len(rhs.shape) == 3: + rhs = np.transpose(rhs, axes=(2, 0, 1)) # compute product - if len(lhs.shape) < 3 or lhs.shape[2] == 1: - Nfreq = rhs.shape[2] - rslt = zeros((N, M, Nfreq), dtype=complex) - for n in range(Nfreq): - rslt[:, :, n] = np.dot(np.squeeze(lhs), rhs[:, :, n]) - elif len(rhs.shape) < 3 or rhs.shape[2] == 1: - Nfreq = lhs.shape[2] - rslt = zeros((N, M, Nfreq), dtype=complex) - for n in range(Nfreq): - rslt[:, :, n] = np.dot(lhs[:, :, n], np.squeeze(rhs)) - elif lhs.shape[2] == rhs.shape[2]: - Nfreq = lhs.shape[2] - rslt = zeros((N, M, Nfreq), dtype=complex) - for n in range(Nfreq): - rslt[:, :, n] = np.dot(lhs[:, :, n], rhs[:, :, n]) + if len(lhs.shape) < 3 or lhs.shape[0] == 1: + rslt = np.matmul(lhs, rhs) + elif len(rhs.shape) < 3 or rhs.shape[0] == 1: + rslt = np.matmul(lhs, rhs) + elif lhs.shape[0] == rhs.shape[0]: + rslt = np.matmul(lhs, rhs) else: raise Exception('Matrix size mismatch lhs.shape[2] = %d != %d = rhs.shape[2]' % (lhs.shape[2], rhs.shape[2])) + if len(rslt.shape) == 3: + rslt = np.transpose(rslt, axes=(1, 2, 0)) return rslt -- GitLab