araffin/sbx

[Question] Why is FPS much lower on GPU than on CPU?

fanliaoooo opened this issue · 1 comment

Question

I ran into a problem: an RL training run reaches about 5000 fps with sbx on CPU, but after installing jaxlib-cuda11 it only reaches about 2000 fps.

Platform: Ubuntu 20.04, x86_64
Python version: 3.9.12
GPU: NVIDIA RTX 4090
CPU: i9-13900KS

nvidia-smi:

Thu Mar 21 14:41:29 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05   Driver Version: 525.147.05   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  Off |
|  0%   48C    P2    68W / 500W |  20952MiB / 24564MiB |     25%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      1281      G   /usr/lib/xorg/Xorg                941MiB |
|    0   N/A  N/A      2137      G   /usr/bin/gnome-shell              183MiB |
|    0   N/A  N/A      4972      G   ...RendererForSitePerProcess        7MiB |
|    0   N/A  N/A      5435      G   ...2gtk-4.0/WebKitWebProcess        6MiB |
|    0   N/A  N/A      6950      G   ...b2020b/bin/glnxa64/MATLAB        6MiB |
|    0   N/A  N/A      8317      G   ...17D222A6D1FB8847155E9F895       19MiB |
|    0   N/A  N/A     89295      G   ...on=20240315-130113.878000      249MiB |
|    0   N/A  N/A     96608      C   ...3/envs/sb3-jax/bin/python    19530MiB |
+-----------------------------------------------------------------------------+

pip list | grep nvidia

nvidia-cublas-cu11            11.11.3.6
nvidia-cublas-cu12            12.4.2.65
nvidia-cuda-cupti-cu11        11.8.87
nvidia-cuda-cupti-cu12        12.4.99
nvidia-cuda-nvcc-cu11         11.8.89
nvidia-cuda-nvcc-cu12         12.4.99
nvidia-cuda-nvrtc-cu11        11.8.89
nvidia-cuda-nvrtc-cu12        12.1.105
nvidia-cuda-runtime-cu11      11.8.89
nvidia-cuda-runtime-cu12      12.4.99
nvidia-cudnn-cu11             8.9.6.50
nvidia-cudnn-cu12             8.9.2.26
nvidia-cufft-cu11             10.9.0.58
nvidia-cufft-cu12             11.2.0.44
nvidia-curand-cu12            10.3.2.106
nvidia-cusolver-cu11          11.4.1.48
nvidia-cusolver-cu12          11.6.0.99
nvidia-cusparse-cu11          11.7.5.86
nvidia-cusparse-cu12          12.3.0.142
nvidia-nccl-cu11              2.20.5
nvidia-nccl-cu12              2.19.3
nvidia-nvjitlink-cu12         12.4.99
nvidia-nvtx-cu12              12.1.105

The training script is:

from stable_baselines3.common.envs import MyCustomEnv  # user's custom environment
from sbx import PPO

env = MyCustomEnv()
model = PPO("MlpPolicy", env, verbose=1)  # SBX PPO, default MLP policy
model.learn(total_timesteps=int(1e8), progress_bar=True)

then I got:

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.

...

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 11.7     |
|    ep_rew_mean        | -34.5    |
| time/                 |          |
|    fps                | 2027     |
|    iterations         | 927      |
|    time_elapsed       | 936      |
|    total_timesteps    | 1898496  |
| train/                |          |
|    clip_range         | 0.2      |
|    explained_variance | 0.34     |
|    n_updates          | 9260     |
|    pg_loss            | -0.0721  |
|    value_loss         | 5.89     |
------------------------------------

However, when using the CPU:

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.

...


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 11.6     |
|    ep_rew_mean        | -35.2    |
| time/                 |          |
|    fps                | 4773     |
|    iterations         | 139      |
|    time_elapsed       | 59       |
|    total_timesteps    | 284672   |
| train/                |          |
|    clip_range         | 0.2      |
|    explained_variance | 0.281    |
|    n_updates          | 1380     |
|    pg_loss            | -0.0598  |
|    value_loss         | 4.52     |
------------------------------------

jax (GPU) was installed following the official "Installing JAX" instructions; the CPU version was installed with pip install jax.
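To confirm which backend jaxlib actually selected, a quick check like the following can be run (a minimal sketch; jax.default_backend() and jax.devices() are standard JAX calls, independent of sbx):

import jax

# Prints "gpu" when a CUDA-enabled jaxlib is active, "cpu" otherwise
print(jax.default_backend())
# Lists the devices JAX can see, e.g. [cuda(id=0)] or [CpuDevice(id=0)]
print(jax.devices())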

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)

See #31 (comment) and probably other related issues on the SB3 repo.

In short, this is expected with PPO and an MLP policy: transferring data between CPU and GPU becomes the bottleneck.
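A minimal workaround sketch, assuming you want to keep the CUDA-enabled jaxlib installed but run small MLP workloads on CPU: setting the JAX_PLATFORMS environment variable before JAX is imported forces the CPU backend.

import os

# Must be set before jax (or sbx, which imports jax) is imported;
# forces JAX onto the CPU backend even with a CUDA jaxlib installed.
os.environ["JAX_PLATFORMS"] = "cpu"

from sbx import PPO  # noqa: E402

The same effect can be had from the shell, e.g. JAX_PLATFORMS=cpu python train.py.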