diff --git a/gwinc/noise/quantum.py b/gwinc/noise/quantum.py
index 0af7defb93b5501d7ac9f36bcfe681e08dcc69bb..466cc8d8d9fcda8e9374be3fa0557bd7ef1e7a3a 100644
--- a/gwinc/noise/quantum.py
+++ b/gwinc/noise/quantum.py
@@ -421,26 +421,23 @@ def getProdTF(lhs, rhs):
         raise Exception('Matrix size mismatch size(lhs, 2) = %d != %d = size(rhs, 1)' % (lhs.shape[1], rhs.shape[0]))
     N = lhs.shape[0]
     M = rhs.shape[1]
+    if len(lhs.shape) == 3:
+        lhs = np.transpose(lhs, axes=(2, 0, 1))
+    if len(rhs.shape) == 3:
+        rhs = np.transpose(rhs, axes=(2, 0, 1))
 
     # compute product
-    if len(lhs.shape) < 3 or lhs.shape[2] == 1:
-        Nfreq = rhs.shape[2]
-        rslt = zeros((N, M, Nfreq), dtype=complex)
-        for n in range(Nfreq):
-            rslt[:, :, n] = np.dot(np.squeeze(lhs), rhs[:, :, n])
-    elif len(rhs.shape) < 3 or rhs.shape[2] == 1:
-        Nfreq = lhs.shape[2]
-        rslt = zeros((N, M, Nfreq), dtype=complex)
-        for n in range(Nfreq):
-            rslt[:, :, n] = np.dot(lhs[:, :, n], np.squeeze(rhs))
-    elif lhs.shape[2] == rhs.shape[2]:
-        Nfreq = lhs.shape[2]
-        rslt = zeros((N, M, Nfreq), dtype=complex)
-        for n in range(Nfreq):
-            rslt[:, :, n] = np.dot(lhs[:, :, n], rhs[:, :, n])
+    if len(lhs.shape) < 3 or lhs.shape[0] == 1:
+        rslt = np.matmul(lhs, rhs)
+    elif len(rhs.shape) < 3 or rhs.shape[0] == 1:
+        rslt = np.matmul(lhs, rhs)
+    elif lhs.shape[0] == rhs.shape[0]:
+        rslt = np.matmul(lhs, rhs)
     else:
         raise Exception('Matrix size mismatch lhs.shape[2] = %d != %d = rhs.shape[2]' % (lhs.shape[2], rhs.shape[2]))
 
+    if len(rslt.shape) == 3:
+        rslt = np.transpose(rslt, axes=(1, 2, 0))
     return rslt