rasbt/LLMs-from-scratch

Best practices for memory efficient weight loading tutorial

mikaylagawarecki opened this issue · 1 comment

Bug description

Thanks for putting together this great tutorial and showing the pros and cons of each of the available options for model loading. I want to add a caveat regarding the section on mmap.

My recommendation on the best practices for loading the model memory efficiently, where I define memory efficiency as

  • Does not materialize the model on the GPU twice
  • Should work regardless of the user's CPU RAM limitations

would be the following

def best_practices():
    # Instantiate the model on the meta device, so no memory is
    # allocated for the parameters yet
    with torch.device("meta"):
        model = GPTModel(BASE_CONFIG)

    # assign=True swaps the meta parameters for the loaded tensors
    # directly, avoiding a second materialization of the model
    model.load_state_dict(
        torch.load("model.pth", map_location=device, weights_only=True, mmap=True),
        assign=True,
    )

    print_memory_usage()

peak_memory_used = memory_usage_in_gb(best_practices)
print(f"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB")

Note that this will print

Maximum GPU memory allocated: 6.4 GB
-> Maximum CPU memory allocated: 6.0 GB

At this point, I expect you to say: hey, 6.0 GB is more than the CPU memory used by the load_sequentially_with_meta example!

Agreed! mmap is a syscall, so we do not have fine-grained control over exactly how much CPU RAM will be used. However, the nice(!) thing about mmap=True is that you should be able to load your model regardless of the user's CPU RAM limitations :)

What actually happens when setting mmap=True + map_location=device is that the checkpoint file is memory-mapped, and then slices of it (one per storage) are sent to the device.

While I don't know of a good way to demonstrate this in an ipynb (resource.setrlimit(rss) doesn't actually work), if you launch a Docker container with limited CPU RAM, you should be able to see this.

So I would not recommend that a user save and load each parameter tensor separately, as is done in section 7.
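For context, the per-parameter pattern being discouraged looks roughly like this; a sketch with a toy Linear model and invented file names, not the notebook's exact section 7 code:

```python
import torch

model = torch.nn.Linear(4, 2)

# Save every parameter tensor to its own file...
for name, param in model.named_parameters():
    torch.save(param.detach(), f"param_{name}.pt")

# ...then rebuild the model and copy the tensors back one at a time,
# so only one parameter's worth of data is in flight at any moment.
fresh = torch.nn.Linear(4, 2)
with torch.no_grad():
    for name, param in fresh.named_parameters():
        param.copy_(torch.load(f"param_{name}.pt", weights_only=True))

print(torch.equal(fresh.weight, model.weight))
```

This keeps peak usage low but requires many files and extra bookkeeping, which a single mmap=True load avoids.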

Let me know what you think :)

Thanks a lot for opening this issue, I really appreciate these insights! I think I now understand better why the mmap approach didn't look so great in practice: it's basically an "on-demand" mechanism, and since this machine has plenty of memory, mmap (smartly) doesn't need to do much here.

I just updated the notebook and flagged the "mmap" method as the recommended one in the section header. Thanks again!