How did you train the large-sized models without out-of-memory?
jiang719 opened this issue · 3 comments
I would like to fine-tune the 2B model, but I get an out-of-memory error even with the batch size set to 1 (on a single GPU with 24 GB of memory).
I wonder what devices you used to pre-train the 2B and 16B models. How did you address the memory issue? Did you parallelize the model by layers across different GPUs? Thank you.
Nan
The models were pre-trained in JAX on TPU-v4 hardware and then converted to PyTorch for sampling.
The training code in JAX will be released soon.
You may try to fine-tune the models in PyTorch using DeepSpeed.
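For reference, the converted PyTorch checkpoints can be sampled with the standard `transformers` API. The sketch below assumes the checkpoint is available under the Hugging Face Hub id `Salesforce/codegen-2B-mono`, which is not stated in this thread; substitute whichever converted checkpoint you actually use.

```python
# Minimal sampling sketch for a converted CodeGen checkpoint in PyTorch.
# The model id "Salesforce/codegen-2B-mono" is an assumption, not taken from this thread.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Salesforce/codegen-2B-mono"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

inputs = tokenizer("def hello_world():", return_tensors="pt").to("cuda")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.2)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```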
Training code in JAX has been released: #16 (comment)
@jiang719 Here is DeepSpeed fine-tuning code with CPU parameter offloading, which should let you avoid OOM:
https://github.com/salesforce/jaxformer/blob/main/jaxformer/hf/train.py
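The exact settings live in that script; as a rough illustration, CPU parameter offloading in DeepSpeed is driven by a ZeRO config along these lines. This is a minimal sketch with placeholder hyperparameters and an assumed Hugging Face checkpoint id, not the configuration used in train.py.

```python
# Sketch of DeepSpeed fine-tuning with ZeRO stage-3 CPU parameter/optimizer offloading,
# in the spirit of jaxformer/hf/train.py. All values below are illustrative placeholders.
import deepspeed
from transformers import AutoModelForCausalLM

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 16,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {"device": "cpu", "pin_memory": True},
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
}

# Assumed checkpoint id; replace with the converted checkpoint you are fine-tuning.
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-2B-mono")
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

# Training step sketch: `batch` would come from your fine-tuning dataloader.
# loss = engine(input_ids=batch["input_ids"], labels=batch["labels"]).loss
# engine.backward(loss)
# engine.step()
```

Run under the `deepspeed` launcher (e.g. `deepspeed your_train_script.py`); with ZeRO-3 and CPU offloading, most parameter and optimizer state stays off the GPU, which is what should make a 2B model trainable on a single 24 GB card.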