PythonOT/POT

Importing POT consumes 22G GPU memory. [ONLY FOR 0.9.1]

HowardZJU opened this issue · 11 comments

Describe the bug

This seems strange to me. When I import POT, GPU memory usage rises to 23172MiB / 24220MiB on GPU:0 and to 814MiB / 24220MiB on each of the other GPUs.

To Reproduce

Steps to reproduce the behavior:

import torch
import ot

Then run nvidia-smi.

2023-08-28 00:06:28.565778: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
/homeold/home/wangh/miniconda3/envs/mvi/lib/python3.10/site-packages/ot/backend.py:2998: UserWarning: To use TensorflowBackend, you need to activate the tensorflow numpy API. You can activate it by running:
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
register_backend(TensorflowBackend())
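To make the before/after comparison repeatable from a script, here is a minimal sketch that queries nvidia-smi's CSV mode around an import. The helper names (parse_memory_csv, gpu_memory_mib, compare_import) are illustrative and not part of POT, and the sketch assumes nvidia-smi is on PATH. Because the importing process is still alive during the second query, memory it holds shows up in the "after" numbers.

```python
import importlib
import subprocess


def parse_memory_csv(text: str) -> list:
    """Parse output of `nvidia-smi --query-gpu=memory.used,memory.total
    --format=csv,noheader,nounits` into (used_mib, total_mib) tuples, one per GPU."""
    rows = []
    for line in text.strip().splitlines():
        # int() tolerates the spaces nvidia-smi puts around each field.
        used, total = (int(field) for field in line.split(","))
        rows.append((used, total))
    return rows


def gpu_memory_mib() -> list:
    """Return (used_mib, total_mib) for every GPU visible to nvidia-smi."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used,memory.total",
         "--format=csv,noheader,nounits"],
        check=True, capture_output=True, text=True,
    )
    return parse_memory_csv(result.stdout)


def compare_import(module_name: str) -> None:
    """Print per-GPU memory usage before and after importing module_name in this process."""
    before = gpu_memory_mib()
    importlib.import_module(module_name)
    after = gpu_memory_mib()
    for i, ((used_before, total), (used_after, _)) in enumerate(zip(before, after)):
        print(f"GPU {i}: {used_before} MiB -> {used_after} MiB (of {total} MiB)")
```

Usage: call compare_import("ot") from a fresh interpreter, so that no framework has touched the GPU beforehand.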

Screenshots

Screenshot_20230828001502, Screenshot_20230828001510, Screenshot_20230828001524

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: Python 3.10 or 3.9
  • How was POT installed (source, pip, conda): pip
  • POT: 0.9.1
  • Build command you used (if compiling from source): pip install pythonot
  • Only for GPU related bugs:
    • CUDA version: 11.4
    • GPU: RTX TITAN
    • Any other relevant information:

When I reinstall pot=0.9.0, the issue is resolved.
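Since the report says only 0.9.1 is affected and 0.9.0 is not, a script can check which POT release is installed without importing it (importing is what triggers the allocation). A minimal sketch using the standard importlib.metadata module; the function names are illustrative, and "POT" is the distribution name on PyPI.

```python
from importlib import metadata
from typing import Optional

# Per this issue, the import-time GPU allocation was reported for 0.9.1 only.
AFFECTED_VERSIONS = {"0.9.1"}


def installed_pot_version() -> Optional[str]:
    """Return the installed POT version string without importing `ot`, or None if absent."""
    try:
        return metadata.version("POT")
    except metadata.PackageNotFoundError:
        return None


def pot_import_may_grab_gpu(version: Optional[str]) -> bool:
    """True if this POT version is one reported here to preallocate GPU memory on import."""
    return version in AFFECTED_VERSIONS
```

If pot_import_may_grab_gpu(installed_pot_version()) is true, pinning the previous release with `pip install POT==0.9.0` matches the workaround described above.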

A few more details on the issue. Here's the memory profile from my machine before and after importing ot.

Before

Sat Sep  2 07:45:06 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   31C    P8    22W / 300W |      1MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

After

Sat Sep  2 07:45:25 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   34C    P2    69W / 300W |  47141MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

I suspect this happens because of the way TensorFlow allocates memory: by default, TensorFlow maps nearly all of the GPU memory of all GPUs visible to the process.
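Besides calling set_memory_growth in code, TensorFlow also honors the TF_FORCE_GPU_ALLOW_GROWTH environment variable, which asks it to allocate GPU memory on demand. A minimal sketch, assuming it runs at the very top of the entry script, before tensorflow is imported directly or indirectly through `import ot`; the helper name request_tf_memory_growth is illustrative.

```python
import os
import sys
import warnings


def request_tf_memory_growth() -> None:
    """Ask TensorFlow to grow GPU memory on demand instead of mapping nearly all of it."""
    if "tensorflow" in sys.modules:
        # TensorFlow may already have set up its GPU allocator, in which case
        # changing the variable now has no effect.
        warnings.warn("tensorflow is already imported; TF_FORCE_GPU_ALLOW_GROWTH may be ignored")
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"


request_tf_memory_growth()
# Only now import libraries that may pull in TensorFlow, for example:
# import ot
```

Setting the variable in the shell (`TF_FORCE_GPU_ALLOW_GROWTH=true python script.py`) has the same effect and avoids import-order concerns entirely.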

The problem is that now I can't use PyTorch, even though TF was not my "target"; it just happened to be installed. PyTorch no longer has access to the GPU memory:

RuntimeError: CUDA out of memory. Tried to allocate 12.21 GiB (GPU 0; 47.54 GiB total capacity; 782.38 MiB already allocated; 753.12 MiB free; 784.00 MiB reserved in total by PyTorch)

One potential solution is to restrict memory allocation with tf.config.experimental.set_memory_growth, though I'm not sure what kind of performance impact such an allocation strategy would have. I'm also not sure there's a good way of restricting the list of backends before the library is loaded. It seems that would require an API change, with some sort of "activate backend" or "enable backend" call, which might not be obvious to users (and would definitely break compatibility with existing setups). WDYT @rflamary?

This is awful behavior by tensorflow! Can you check whether, if you import tensorflow and set this, the memory is available to PyTorch?

Another option would be to prevent POT from using the tensorflow backend (and from importing tensorflow) with an environment variable such as `POT_DISABLE_TENSORFLOW`. We could then use the same mechanism for pytorch/cupy and everything else, and it would speed up POT loading.
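The environment-variable idea could look roughly like this. It is a minimal sketch, not POT's code: the POT_DISABLE_<NAME> naming follows the suggestion above, and backend_enabled and import_if_enabled are hypothetical helpers. A disabled or missing backend module is never imported, so it cannot allocate anything and adds no import time.

```python
import importlib
import importlib.util
import os
from types import ModuleType
from typing import Optional

_TRUTHY = {"1", "true", "yes", "on"}


def backend_enabled(backend_name: str) -> bool:
    """False if the (hypothetical) POT_DISABLE_<NAME> variable is set to a truthy value."""
    value = os.environ.get(f"POT_DISABLE_{backend_name.upper()}", "")
    return value.strip().lower() not in _TRUTHY


def import_if_enabled(module_name: str, backend_name: str) -> Optional[ModuleType]:
    """Import module_name only if its backend is enabled and the module is installed."""
    if not backend_enabled(backend_name):
        return None
    # find_spec locates a top-level module without importing it.
    if importlib.util.find_spec(module_name) is None:
        return None
    return importlib.import_module(module_name)
```

With `POT_DISABLE_TENSORFLOW=1` in the environment, import_if_enabled("tensorflow", "tensorflow") returns None without ever touching TensorFlow.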

Also, it is a mystery why this happens in 0.9.1 and not in 0.9.0; this is something we need to understand.

@rflamary Sure, working on it

So, for TensorFlow we have the following:

Before

Mon Sep  4 09:31:38 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   56C    P8    21W / 300W |      1MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Executing the following code:

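import tensorflow as tf

# Enable on-demand allocation for every visible GPU. This has to run before
# TensorFlow initializes the GPUs, i.e. before any tensor is placed on them.
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(device, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized.
        print(e)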

Check allocation:

Mon Sep  4 09:31:57 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   55C    P8    21W / 300W |      3MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Now, import ot:

import ot

Check allocation once again:

Mon Sep  4 09:32:09 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    Off  | 00000000:00:05.0 Off |                  Off |
| 30%   57C    P2    81W / 300W |    931MiB / 49140MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

So it still allocates almost 1 GB, but that is much better than before.

Because JAX uses the XLA engine under the hood, I also checked JAX. It has the same default behavior. Preallocation can be switched off by setting the environment variable XLA_PYTHON_CLIENT_PREALLOCATE to false. More information on GPU memory allocation here.
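For JAX the knobs are environment variables, so the safest place to set them is before jax is imported (directly, or indirectly via `import ot` on 0.9.1). A minimal sketch; XLA_PYTHON_CLIENT_MEM_FRACTION is shown only as an alternative and the 25% value is an arbitrary example.

```python
import os

# Disable XLA's up-front preallocation; JAX then allocates GPU memory as needed.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Alternative (instead of the line above): keep preallocation but cap it,
# here at 25% of each GPU's memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".25"

# Import jax (or ot, which may import jax) only after the variables are set.
```

On the performance question raised earlier: the JAX documentation notes that disabling preallocation can lead to more memory fragmentation, which is one reason a library should not flip this default on its users' behalf.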

@rflamary It's a rather tricky situation, honestly. Memory allocation settings are process-wide, and I don't think it's desirable to set them from within a library; in many (most?) applications, the preallocation behavior is actually desirable. It seems the right course of action is to find out why ot triggers preallocation on import instead of when I actually use one of those backends. TF allocates (or at least should preallocate) only when the first tensor is moved to the target device. If we make sure that importing the library never creates tensors on the GPU, we can safely assume that by the time the user calls a function from ot, they have already applied whatever settings the application requires.

Going through the history of recent changes, I think the problem first appeared in this commit. Before that change, a backend object was created only by a get_backend call; after it, the library creates all backend objects up front, on import. And we do, in fact, perform allocations in backend constructors (for example, this one for JAX), which kicks off memory preallocation.

Damn it! Now I understand. I really did not anticipate this.

Maybe we should find a middle ground: create each backend on its first call and then reuse the one already created?
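The middle ground suggested here amounts to lazy, cached construction: unlike building every backend at module level on import, keep a mapping from backend name to constructor, instantiate on the first request, and hand back the same object afterwards. A minimal sketch under that assumption; CountingBackend, _BACKEND_FACTORIES and get_backend_instance are hypothetical names, not POT's API, and the counter stands in for constructor side effects such as placing a tensor on the GPU.

```python
import functools


class CountingBackend:
    """Hypothetical backend; the counter records how often the constructor ran."""

    constructed = 0

    def __init__(self) -> None:
        # A real backend might allocate a small tensor on the device here.
        CountingBackend.constructed += 1


# Only constructors are stored, so importing this module runs no backend __init__.
_BACKEND_FACTORIES = {"counting": CountingBackend}


@functools.lru_cache(maxsize=None)
def get_backend_instance(name: str):
    """Create the named backend on first request and return the same object afterwards."""
    return _BACKEND_FACTORIES[name]()
```

functools.lru_cache is enough for memoization by name, but it does not prevent two threads from constructing the same backend concurrently on a cold cache; a lock around the factory call would be needed if that matters.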

And again, thanks very much @kachayev, this is some nice code sleuthing. I still believe it would be useful to be able to disable importing tensorflow/torch with environment variables, to speed up the import process, which can be quite slow when they are all installed.
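One way to check both goals (no device work on import, and no unnecessary framework imports) is to import the package in a fresh interpreter and see which heavy modules ended up in sys.modules. A minimal sketch; the helper name and the default candidate list are illustrative, and a marker prefix keeps the result separate from anything the imported package itself prints.

```python
import subprocess
import sys
from typing import Iterable, List

DEFAULT_CANDIDATES = ("tensorflow", "torch", "jax", "cupy")
_MARKER = "LOADED:"


def heavy_modules_loaded_by(module_name: str,
                            candidates: Iterable[str] = DEFAULT_CANDIDATES) -> List[str]:
    """Import module_name in a fresh interpreter and return which candidate
    top-level modules ended up in sys.modules as a side effect."""
    names = list(candidates)
    script = (
        "import importlib, sys\n"
        f"importlib.import_module({module_name!r})\n"
        f"print({_MARKER!r} + ','.join(m for m in {names!r} if m in sys.modules))\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", script], check=True, capture_output=True, text=True
    )
    for line in result.stdout.splitlines():
        if line.startswith(_MARKER):
            payload = line[len(_MARKER):]
            return payload.split(",") if payload else []
    raise RuntimeError("marker line not found in child interpreter output")
```

Running heavy_modules_loaded_by("ot") in an environment with several frameworks installed lists which of them `import ot` pulled in, which makes it easy to confirm that an environment-variable switch actually skips the disabled ones.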

Absolutely, @rflamary! I will create PR with all necessary changes.

Thank you very much. We encountered the same problem as well. Downgrading POT to 0.9.0 works fine, so we will not investigate the cause for now. Thank you.