[PyTorch] Training is very slow on Linux.
haifengl opened this issue · 9 comments
Training 10 epochs of MNIST (the sample code from your project README) takes > 500 seconds on Linux (24 cores, Ubuntu 22.04). It takes only about 50 seconds on an old Mac (4 cores). Both use CPU (no GPU or MPS).
Try reducing the number of threads used by PyTorch to 6 or 12; see https://stackoverflow.com/questions/76084214/what-is-recommended-number-of-threads-for-pytorch-based-on-available-cpu-cores
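For example, one way to cap the thread count is to set OMP_NUM_THREADS before launching the JVM (a sketch; the jar and class names below are placeholders, not from this project):

```shell
# Cap PyTorch's OpenMP thread pool before launching the JVM. A value near the
# physical core count (not the hyper-threaded vCPU count) is a good start.
export OMP_NUM_THREADS=12

# Placeholder launch command (app.jar / MnistTrainer are hypothetical names);
# echo is used here so the line is illustrative rather than executed.
echo java -cp app.jar MnistTrainer
```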
It's most probably related to PyTorch not finding OpenBLAS and/or MKL on your path.
Have you added mkl-platform-redist to your dependencies?
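For reference, the Maven coordinates would look something like this (the version shown is a guess; check the javacpp-presets documentation for the one matching your pytorch preset version):

```xml
<!-- Hypothetical version; pick the mkl redist release that matches your
     org.bytedeco pytorch preset. -->
<dependency>
  <groupId>org.bytedeco</groupId>
  <artifactId>mkl-platform-redist</artifactId>
  <version>2023.1-1.5.9</version>
</dependency>
```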
You can also try to download and use the official libtorch: add the path containing its libs to your library path and set -Dorg.bytedeco.javacpp.pathsFirst. The official binaries are statically built with MKL.
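A minimal sketch of that setup, assuming the official libtorch CPU build from pytorch.org has been unpacked to $HOME/libtorch (a hypothetical path), and with placeholder jar/class names:

```shell
# Put the official libtorch libs ahead on the library path.
export LD_LIBRARY_PATH="$HOME/libtorch/lib:$LD_LIBRARY_PATH"

# -Dorg.bytedeco.javacpp.pathsFirst=true makes JavaCPP load libraries found on
# the library path before its own bundled ones. Placeholder launch command
# (echoed rather than executed here):
echo java -Dorg.bytedeco.javacpp.pathsFirst=true -cp app.jar MnistTrainer
```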
Setting OMP_NUM_THREADS=12 on Linux helps a lot. The training speed is then on par with the Mac (4 threads). Without it, torch.get_num_threads() returns 48, so the slowness may be caused by hyper-threading. According to your link, PyTorch should set the number of threads to half the vCores. If so, we shouldn't have this issue on Linux. However, that is not the case with the JavaCPP build. Are we missing some build configuration for Linux? Thanks!
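One quick way to check whether hyper-threading is inflating the default (a Linux sketch; lscpu comes from util-linux and may not exist on every box):

```shell
# Logical CPUs (vCores, includes hyper-threads):
nproc

# Physical cores: count unique (core id, socket) pairs reported by lscpu.
# Prints 0 if lscpu is unavailable.
lscpu -p=CORE,SOCKET 2>/dev/null | grep -v '^#' | sort -u | wc -l
```

If the first number is double the second, a thread count equal to the physical core count is usually the better default.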
So the default is 24 on that machine, but that doesn't mean it's going to give good results.
The default is 48 with the JavaCPP build, which is too high. It should be 24 in this case.
Have you tried with the official libtorch?
libtorch sets it to 24 by default on my box, and it works well. Why does JavaCPP build libtorch from source instead of packaging the precompiled libtorch library from pytorch.org?
See discussion here
Here is the result of running the sample MNIST code on a machine with 32 vcores and 16 physical cores:
OpenMP lib | Default num threads | Speed
---|---|---
omp | 32 | Very slow
gomp | 32 | Somewhat slow
mkl static (official build) | 16 | Fast
When forcing the number of threads to 16 using OMP_NUM_THREADS or torch.set_num_threads, it's fast in all cases.
I'll try to rationalize that in the PR so that torch is linked with gomp on Linux.
Also, the fact that the presets preload every possible OpenMP lib they find, possibly leading to multiple different libraries being loaded, surely doesn't help.