I found PyTorch to be much slower than NumPy when doing complex-valued matrix-vector multiplication on CPU:
A few notes:
Perhaps I have misconfigured something?
Code to produce the above plots:
import torch
import numpy as np
import matplotlib.pyplot as plt
import time

maxn = 3000  # largest matrix/vector size
nrep = 100   # repetitions per timing

def conv(M, latype):
    # Convert a NumPy array to the requested backend ('numpy' or 'torch, <device>').
    if latype == 'numpy':
        return np.array(M)
    if latype.startswith('torch,'):
        return torch.tensor(M, device=latype[7:])

def multtest(A, b):
    # Mean wall-clock time of one matrix-vector product over nrep repetitions.
    t0 = time.time()
    for i in range(nrep):
        b = A@b
    t1 = time.time()
    return (t1-t0)/nrep

ns = np.array(np.linspace(100, maxn, 100), dtype=int)
numpyts = np.zeros(len(ns))
torchts = np.zeros(len(ns))
fig, axes = plt.subplots(1, 2)
for ax, dtype in zip(axes, ['real', 'complex']):
    Aorig = np.random.rand(maxn, maxn)
    borig = np.random.rand(maxn)
    if dtype == 'complex':
        Aorig = Aorig + 1.j*np.random.rand(maxn, maxn)
        borig = borig + 1.j*np.random.rand(maxn)
    for latype in ['numpy', 'torch, cpu']:
        A = conv(Aorig, latype)
        b = conv(borig, latype)
        ts = np.zeros(len(ns))
        for i, n in enumerate(ns):
            ts[i] = multtest(A[:n, :n], b[:n])
        ax.plot(ns, ts, label=latype)
    ax.legend()
    ax.set_title(dtype)
    ax.set_xlabel('vector/matrix size')
    ax.set_ylabel('mean matrix-vector mult time (sec)')
fig.tight_layout()
plt.show()
I can reproduce your problem on Windows 10 with CPython 3.8.1, NumPy 1.24.3 and Torch 49444c3e (the packages are the default ones installed via pip). Torch is set up to use my CPU (i5-9600KF). Here is the result:
I can also see that, in the complex case, Torch uses only 1 core while NumPy uses multiple cores. NumPy uses OpenBLAS internally by default. Torch probably uses another implementation that is not optimized for this case (no parallelism, for an unknown reason). I can see that the real version uses OpenMP internally while the complex one does not. Neither appears to call any (dynamic) BLAS function internally, which tends to confirm that Torch uses its own implementation, unless a BLAS is statically linked.
Assuming Torch also uses a BLAS but the default one is not efficient, you can certainly compile/package Torch so that it links a faster BLAS implementation (possibly OpenBLAS, or another one such as BLIS or the Intel MKL).
If Torch uses its own implementation, you can open an issue asking for OpenMP to be used in the complex version as well.
AFAIK, Torch is optimized for real single-precision computations on GPUs rather than complex double-precision computations on CPUs, so maybe this case has not received attention yet.
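To check on your own machine which BLAS and threading setup each library reports, something like the following may help. It is only a sketch: time_matvec is a hypothetical helper written for this answer (not part of Torch), the size of 1000 is arbitrary, and the printed configuration depends entirely on how your packages were built. If the complex timings do not improve with more threads while the real ones do, that would be consistent with the single-core behaviour described above.

```python
import time
import numpy as np
import torch

def time_matvec(n, dtype, threads, nrep=20):
    """Mean time of one n-by-n matrix-vector product using `threads` intra-op threads.

    Hypothetical helper for illustration only.
    """
    previous = torch.get_num_threads()
    torch.set_num_threads(threads)
    try:
        A = torch.randn(n, n, dtype=dtype)
        b = torch.randn(n, dtype=dtype)
        t0 = time.perf_counter()
        for _ in range(nrep):
            A @ b  # result discarded; only the timing matters
        return (time.perf_counter() - t0) / nrep
    finally:
        # Restore the thread count that was active before the measurement.
        torch.set_num_threads(previous)

# BLAS/LAPACK libraries NumPy was built against.
np.show_config()
# Build configuration of Torch (BLAS backend, OpenMP, MKL, ...).
print(torch.__config__.show())
# How intra-op parallelism is configured.
print(torch.__config__.parallel_info())

# Compare one thread against the default thread count, for real and complex.
for dtype in (torch.float64, torch.complex128):
    for threads in (1, torch.get_num_threads()):
        print(dtype, threads, time_matvec(1000, dtype, threads))
```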
By the way, I can see the following warnings during the execution:
<ipython-input-85-1e20a6760269>:18: RuntimeWarning: overflow encountered in matmul
b = A@b
<ipython-input-85-1e20a6760269>:18: RuntimeWarning: overflow encountered in matmul
b = A@b
<ipython-input-85-1e20a6760269>:18: RuntimeWarning: invalid value encountered in matmul
b = A@b
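These warnings most likely come from the benchmark feeding the result back into b on every repetition: with entries in [0, 1), each product scales the vector by roughly n/2, so after 100 repetitions at n = 3000 the values exceed the double-precision range. A minimal sketch that shows the effect and a timing loop that avoids it, assuming it is acceptable to discard the product (multtest_fixed is a hypothetical replacement for multtest, not the original code):

```python
import time
import numpy as np

def multtest_fixed(A, b, nrep=100):
    """Mean time of A @ b without feeding the result back into b."""
    t0 = time.perf_counter()
    for _ in range(nrep):
        c = A @ b  # b is left unchanged, so the values never grow
    t1 = time.perf_counter()
    return (t1 - t0) / nrep

rng = np.random.default_rng(0)
n = 3000
A = rng.random((n, n))
b = rng.random(n)

# Reproduce the original update b = A@b and check whether the result stays finite.
x = b.copy()
with np.errstate(over='ignore', invalid='ignore'):
    for _ in range(100):
        x = A @ x
print("all finite after repeated products:", np.isfinite(x).all())

# Timing without the feedback loop.
print("mean time per product (s):", multtest_fixed(A, b))
```

Once the values overflow, the timings may no longer reflect ordinary finite arithmetic, so keeping b fixed also makes the comparison between libraries cleaner.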
When I run the code I get a different plot:
torch: 2.3.1
numpy: 1.26.4
cuda: 12.2
NVIDIA-Driver: 535.183.01 (Ubuntu)