Best practices for memory efficient weight loading tutorial
mikaylagawarecki opened this issue · 1 comment
Bug description
Thanks for putting together this great tutorial and showing the pros and cons of each of the available options for model loading. I want to add a caveat regarding the section on mmap.
My recommendation for the best practice for loading the model memory-efficiently, where I define memory efficiency as
- does not materialize the model on GPU twice
- works regardless of the user's limitations on CPU RAM

would be the following:
def best_practices():
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)
    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True
    )
    print_memory_usage()

peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")
Note that this will print
Maximum GPU memory allocated: 6.4 GB
-> Maximum CPU memory allocated: 6.0 GB
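For context, here is a minimal sketch of how a helper like the memory_usage_in_gb used above might measure peak CPU memory. This is an assumption about the helper, not its actual implementation; the tutorial's real version may differ, e.g. it may run the function in a fresh subprocess so the reported peak reflects only that function.

```python
import resource
import sys

def memory_usage_in_gb(func):
    # Run the function, then read this process's peak resident set size.
    # Caveat: ru_maxrss is the peak since process start, not a delta, which
    # is one reason a real helper might fork a fresh process instead.
    func()
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # Linux reports ru_maxrss in KiB; macOS reports bytes.
    divisor = 1024**3 if sys.platform == "darwin" else 1024**2
    return peak / divisor
```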
At which point, I expect you to say: hey! 6.0 GB is more than the CPU memory used by the load_sequentially_with_meta example. Agreed! mmap is a syscall, and hence we do not have fine-grained control over exactly how much CPU RAM will be used. However, the nice(!) thing about mmap=True is that you should be able to load your model regardless of the user's limitations on CPU RAM :)
What actually happens when setting mmap=True + map_location=device is that the checkpoint file will be mmapped, and then slices of it (corresponding to each storage) will be sent to device.
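To illustrate the idea with a self-contained sketch (using the stdlib mmap module on a dummy file, not PyTorch's actual loader): mapping a file reserves address space without reading the whole file into RAM, and slicing the mapping touches only the bytes in that range, which the OS pages in on demand.

```python
import mmap
import os
import tempfile

# Create a 1 MiB stand-in for a checkpoint file.
path = os.path.join(tempfile.mkdtemp(), "checkpoint.bin")
with open(path, "wb") as f:
    f.write(bytes(range(256)) * 4096)  # 1 MiB of repeating byte values

with open(path, "rb") as f:
    mapped = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    # Slicing copies out just this range; only the touched pages are read
    # from disk, analogous to sending one storage's slice to the device.
    storage_slice = mapped[4096:8192]
    print(len(storage_slice))  # 4096 bytes for this "storage"
    mapped.close()
```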
While I don't know of a good way to demonstrate this in an ipynb (resource.setrlimit(rss) doesn't actually work), if you launch a docker container with CPU RAM limited, I expect you should be able to see this.
So I would not recommend that a user save and load each parameter tensor separately, as is the case in section 7.
Let me know what you think :)
Thanks a lot for opening this issue, I really appreciate these insights! I think I now understand better why the mmap approach didn't look so great in practice... it's basically an "on-demand" thing, and the machine has too much memory here, so the mmap function (smartly) doesn't do much.
I just updated the notebook and flagged the "mmap" method as the recommended one in the section header. Thanks again!