[Question] Why is fps much lower with GPU than with CPU
fanliaoooo opened this issue · 1 comment
fanliaoooo commented
Question
I ran into a problem: an RL training run reaches about 5000 fps when using sbx on CPU, but after installing jaxlib-cuda11 it only reaches about 2000 fps.
Platform: Ubuntu 20.04, x86_64
Python version: 3.9.12
GPU: NVIDIA RTX 4090
CPU: i9-13900KS
nvidia-smi:
Thu Mar 21 14:41:29 2024
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.147.05 Driver Version: 525.147.05 CUDA Version: 12.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA GeForce ... Off | 00000000:01:00.0 On | Off |
| 0% 48C P2 68W / 500W | 20952MiB / 24564MiB | 25% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| 0 N/A N/A 1281 G /usr/lib/xorg/Xorg 941MiB |
| 0 N/A N/A 2137 G /usr/bin/gnome-shell 183MiB |
| 0 N/A N/A 4972 G ...RendererForSitePerProcess 7MiB |
| 0 N/A N/A 5435 G ...2gtk-4.0/WebKitWebProcess 6MiB |
| 0 N/A N/A 6950 G ...b2020b/bin/glnxa64/MATLAB 6MiB |
| 0 N/A N/A 8317 G ...17D222A6D1FB8847155E9F895 19MiB |
| 0 N/A N/A 89295 G ...on=20240315-130113.878000 249MiB |
| 0 N/A N/A 96608 C ...3/envs/sb3-jax/bin/python 19530MiB |
+-----------------------------------------------------------------------------+
pip list | grep nvidia
nvidia-cublas-cu11 11.11.3.6
nvidia-cublas-cu12 12.4.2.65
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-cupti-cu12 12.4.99
nvidia-cuda-nvcc-cu11 11.8.89
nvidia-cuda-nvcc-cu12 12.4.99
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cuda-runtime-cu12 12.4.99
nvidia-cudnn-cu11 8.9.6.50
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu11 10.9.0.58
nvidia-cufft-cu12 11.2.0.44
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusolver-cu12 11.6.0.99
nvidia-cusparse-cu11 11.7.5.86
nvidia-cusparse-cu12 12.3.0.142
nvidia-nccl-cu11 2.20.5
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.4.99
nvidia-nvtx-cu12 12.1.105
and the training script is:
from stable_baselines3.common.envs import MyCustomEnv
from sbx import PPO
env = MyCustomEnv()
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=int(1e8), progress_bar=True)
then I got:
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
...
------------------------------------
| rollout/ | |
| ep_len_mean | 11.7 |
| ep_rew_mean | -34.5 |
| time/ | |
| fps | 2027 |
| iterations | 927 |
| time_elapsed | 936 |
| total_timesteps | 1898496 |
| train/ | |
| clip_range | 0.2 |
| explained_variance | 0.34 |
| n_updates | 9260 |
| pg_loss | -0.0721 |
| value_loss | 5.89 |
------------------------------------
However, when using the CPU:
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
...
------------------------------------
| rollout/ | |
| ep_len_mean | 11.6 |
| ep_rew_mean | -35.2 |
| time/ | |
| fps | 4773 |
| iterations | 139 |
| time_elapsed | 59 |
| total_timesteps | 284672 |
| train/ | |
| clip_range | 0.2 |
| explained_variance | 0.281 |
| n_updates | 1380 |
| pg_loss | -0.0598 |
| value_loss | 4.52 |
jax (GPU) was installed following the official "Installing JAX" instructions; the CPU version was installed with `pip install jax`.
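As an aside, reinstalling jaxlib is not the only way to switch backends for a comparison like this: recent JAX versions let you pin the platform via the `JAX_PLATFORMS` environment variable (older versions used `JAX_PLATFORM_NAME`). A minimal sketch, assuming a recent jax is installed:

```python
# Sketch: force JAX onto a specific backend without reinstalling jaxlib.
# JAX_PLATFORMS must be set before jax is first imported.
import os
os.environ["JAX_PLATFORMS"] = "cpu"  # or "cuda" to force the GPU

import jax

# Confirm which backend JAX actually selected.
print(jax.default_backend())  # -> "cpu"
```

This makes it easy to run the same script twice (once with `JAX_PLATFORMS=cpu`, once with `cuda`) and compare fps without touching the installed packages.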
Checklist
- I have read the documentation (required)
- I have checked that there is no similar issue in the repo (required)
araffin commented
See #31 (comment)
and probably other related issues on the sb3 repo.
In short, this is expected with PPO and an MLP policy: transferring data to the GPU becomes the bottleneck.
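The effect is easy to see in isolation: for a network as small as an MLP policy, copying each rollout batch from host memory to the device can cost as much as, or more than, the computation itself. A rough timing sketch (the layer size and batch here are illustrative, not taken from the issue; on a CPU-only install the two timings will be close, since no real device copy happens):

```python
# Sketch: compare a tiny jitted "MLP step" fed from host memory (implicit
# host->device copy on every call) vs. fed data already resident on the device.
import time
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (64, 64))  # a tiny layer, stand-in for MlpPolicy

@jax.jit
def step(w, x):
    return jnp.tanh(x @ w)

host_batch = np.random.randn(64, 64).astype(np.float32)  # rollout data on host

# Warm up so compilation time is excluded from the measurement.
step(w, host_batch).block_until_ready()

n = 1000
t0 = time.perf_counter()
for _ in range(n):
    # Each call implicitly copies the NumPy batch to the device first.
    step(w, host_batch).block_until_ready()
t_copy = time.perf_counter() - t0

x_dev = jax.device_put(host_batch)  # move the batch onto the device once
t0 = time.perf_counter()
for _ in range(n):
    step(w, x_dev).block_until_ready()
t_resident = time.perf_counter() - t0

print(f"per-step host copy: {t_copy:.3f}s  |  device-resident: {t_resident:.3f}s")
```

With a GPU backend, the gap between the two timings is the transfer overhead that a single-env, small-MLP PPO run pays every step, which is why the CPU build ends up faster here.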