Importing POT consumes 22G GPU memory. [ONLY FOR 0.9.1]
HowardZJU opened this issue · 11 comments
Describe the bug
This is quite strange to me. When I import POT, GPU memory usage jumps to 23172MiB / 24220MiB on GPU:0 and 814MiB / 24220MiB on the other GPUs.
To Reproduce
Steps to reproduce the behavior:
import torch
import ot
then run nvidia-smi
2023-08-28 00:06:28.565778: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/homeold/home/wangh/miniconda3/envs/mvi/lib/python3.10/site-packages/ot/backend.py:2998: UserWarning: To use TensorflowBackend, you need to activate the tensorflow numpy API. You can activate it by running:
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
register_backend(TensorflowBackend())
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Linux
- Python version: Python 3.10 or 3.9
- How was POT installed (source, pip, conda): pip
- POT version: 0.9.1
- Build command you used (if compiling from source): pip install pythonot
- Only for GPU related bugs:
- CUDA version: 11.4
- GPU: RTX TITAN
- Any other relevant information:
When I reinstall POT 0.9.0, the issue is resolved.
A few more details on the issue. Here's the memory profile from my machine before and after importing ot.
Before
Sat Sep 2 07:45:06 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04 Driver Version: 525.116.04 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA RTX A6000 Off | 00000000:00:05.0 Off | Off |
| 30% 31C P8 22W / 300W | 1MiB / 49140MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
After
Sat Sep 2 07:45:25 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04 Driver Version: 525.116.04 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA RTX A6000 Off | 00000000:00:05.0 Off | Off |
| 30% 34C P2 69W / 300W | 47141MiB / 49140MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
+-----------------------------------------------------------------------------+
I guess the reason this happens is the way TF allocation works. By default, TensorFlow maps nearly all of the GPU memory of all GPUs visible to the process.
The problem is that now I can't use PyTorch, even though TF was not my "target"; it just happened to be installed. PyTorch no longer has access to the GPU memory:
RuntimeError: CUDA out of memory. Tried to allocate 12.21 GiB (GPU 0; 47.54 GiB total capacity; 782.38 MiB already allocated; 753.12 MiB free; 784.00 MiB reserved in total by PyTorch)
One potential solution to this problem is to restrict memory allocation with tf.config.experimental.set_memory_growth, though I'm not sure what kind of performance impact such an allocation strategy would have. I'm also not sure if there's a good way of restricting the list of backends prior to loading the library. It seems like that would require an API change with some sort of "activate backend" or "enable backend" call, which might not be obvious for users (and would definitely break compatibility with existing setups). WDYT @rflamary?
This is awful behavior by tensorflow! Can you check whether, if you import tensorflow and set this, the memory is available for pytorch?
Another way would be to force POT not to use the tensorflow backend (and not import tensorflow) with an environment variable POT_DISABLE_TENSORFLOW (we could then use the same mechanism for pytorch/cupy and everything else, and it would speed up POT loading).
Also, there is a mystery as to why this happens in 0.9.1 and not in 0.9.0; this is something we need to understand.
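To make the environment-variable idea above concrete, here is a hypothetical sketch. POT_DISABLE_TENSORFLOW does not exist in POT as of this thread, and the helper below is illustrative only, not POT's actual import logic:

import os

def _tensorflow_backend_enabled():
    # Any non-empty value of the (hypothetical) variable disables the backend.
    return not os.environ.get("POT_DISABLE_TENSORFLOW")

if _tensorflow_backend_enabled():
    try:
        import tensorflow as tf  # heavy import; may initialize the GPU
        # ... register the TensorflowBackend here ...
    except ImportError:
        pass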
So, for TensorFlow we have the following:
Before
Mon Sep 4 09:31:38 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04 Driver Version: 525.116.04 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA RTX A6000 Off | 00000000:00:05.0 Off | Off |
| 30% 56C P8 21W / 300W | 1MiB / 49140MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
Executing the following code:
import tensorflow as tf

physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(device, True)
    except Exception:
        pass
Check allocation:
Mon Sep 4 09:31:57 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04 Driver Version: 525.116.04 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA RTX A6000 Off | 00000000:00:05.0 Off | Off |
| 30% 55C P8 21W / 300W | 3MiB / 49140MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
Now, importing ot
import ot
Check allocation once again:
Mon Sep 4 09:32:09 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04 Driver Version: 525.116.04 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA RTX A6000 Off | 00000000:00:05.0 Off | Off |
| 30% 57C P2 81W / 300W | 931MiB / 49140MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
+-----------------------------------------------------------------------------+
So, it does allocate almost a GB, but that's still much better than before.
Because JAX uses the XLA engine under the hood, I also checked JAX. It has the same default behavior. The preallocation strategy can be switched off by setting the env variable XLA_PYTHON_CLIENT_PREALLOCATE to false. More information on GPU memory allocation here.
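As a small sketch of that workaround: XLA_PYTHON_CLIENT_PREALLOCATE is a real JAX/XLA variable, but it must be set before JAX initializes its GPU backend, i.e. before the first computation touches the device.

import os
# Disable XLA's default preallocation (~75% of GPU memory) before JAX
# creates its GPU client; setting it later has no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax.numpy as jnp
x = jnp.ones(3)  # now allocates only what this array needs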
@rflamary It's a rather tricky situation, honestly. Memory allocation settings are process-wide, and I don't think it's desirable to set them from within the library; in many (most?) applications the default behavior would be what the user wants. It seems like the right course of action here is to find out why ot forces preallocation at import time instead of doing this when one of those backends is actually used. TF allocates (or at least it should preallocate) only when the first tensor is moved to the target device, so if we make sure that importing the library never creates tensors on the GPU, we can safely assume that the user has already taken care of whatever settings their application requires by the time they call a function from ot.
Going through the history of recent changes... I guess the problem first appeared in this commit. Prior to the change, the backend object was created only from a get_backend call. After it, the library creates backend objects up-front (on import). And we do, in fact, perform allocations in backend constructors, for example this one for JAX, which kicks off memory preallocation.
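To illustrate the pattern described above with a minimal, hypothetical example (this is not the actual POT source, just the shape of the problem):

class JaxBackend:
    def __init__(self):
        import jax.numpy as jnp
        # Creating even a tiny device array initializes the XLA GPU client,
        # which by default preallocates most of the GPU memory.
        self.zero_ = jnp.zeros(1)

_BACKENDS = []

def register_backend(backend):
    _BACKENDS.append(backend)

# In the eager scheme this runs at "import ot" time, so the allocation
# happens on import rather than on first use:
register_backend(JaxBackend())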
Damn it! Now I understand. I really did not anticipate this.
Maybe we should try to find a middle way that creates the backend on the first call and then reuses the one already created?
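A minimal sketch of what that middle way could look like (hypothetical names, not POT's actual API): register backend classes cheaply at import time and only instantiate each one, cached, on first use.

_BACKEND_CLASSES = {}    # name -> class, filled at import time (cheap)
_BACKEND_INSTANCES = {}  # name -> instance, created lazily (may allocate)

def register_backend_class(name, cls):
    _BACKEND_CLASSES[name] = cls

def get_backend_instance(name):
    # The constructor may create tensors on the GPU, so defer it until a
    # user actually requests this backend, then reuse the cached instance.
    if name not in _BACKEND_INSTANCES:
        _BACKEND_INSTANCES[name] = _BACKEND_CLASSES[name]()
    return _BACKEND_INSTANCES[name]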
And again, thanks very much @kachayev, this is some nice code sleuthing. I still believe it would be interesting to disable importing tensorflow/torch with env variables, to speed up the import process, which can be quite slow when they are all installed.
Thank you very much. We also encountered the same problem. Downgrading POT to 0.9.0 works fine, so we will not investigate the cause further for now. Thank you.