MichaelTMatthews/Craftax

High GPU memory on simple import


When I import Craftax functions such as the one below, a large part of my GPU memory fills up immediately. For example, when I run this single line:

from craftax.craftax_env import make_craftax_env_from_name

and then check GPU memory usage, I see that 60 GB is already in use on an H100 machine:

[Screenshot of GPU memory usage, taken 2024-06-20]

Is this expected? If so, what is causing such a large memory allocation?
Thanks a lot for your help!

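This is default JAX behaviour: JAX preallocates 75% of total GPU memory when the first JAX operation runs, and importing Craftax is evidently enough to trigger that. On an 80 GB H100, 75% is exactly 60 GB.
This behaviour can be changed through environment variables, as described here: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html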
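A minimal sketch of turning off or capping preallocation from Python, using the `XLA_PYTHON_CLIENT_PREALLOCATE` and `XLA_PYTHON_CLIENT_MEM_FRACTION` variables from the linked JAX page. The 30% fraction is an arbitrary example value. The only assumption about Craftax is that it imports JAX, so the variables have to be set before the Craftax import:

```python
import os
import sys
import warnings

# Set these before importing jax or any library built on it (such as craftax),
# because importing such a library can already initialize the GPU backend.
if "jax" in sys.modules:
    warnings.warn("jax is already imported; these settings may have no effect")

# Option 1: disable preallocation so JAX allocates GPU memory on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Option 2 (use instead of option 1): keep preallocation, but reserve only
# 30% of total GPU memory instead of the default 75%.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".30"

# Import JAX-based code only after the variables are set, e.g.:
# from craftax.craftax_env import make_craftax_env_from_name
```

You can also set the same variables in the shell, e.g. `XLA_PYTHON_CLIENT_PREALLOCATE=false python your_script.py` (`your_script.py` is a placeholder). The JAX docs note that disabling preallocation can lead to more memory fragmentation. If you only need to leave room for other processes, capping the fraction may be the better option.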

I had no idea, thanks for the pointer!