salesforce/CodeGen

How did you train the large-sized models without out-of-memory?

jiang719 opened this issue · 3 comments

I would like to fine-tune the 2B model, but I run out of memory even with the batch size set to 1 (on a single GPU with 24 GB of memory).

What hardware did you use to pre-train the 2B and 16B models, and how did you address the memory issue? Did you parallelize the model by splitting its layers across different GPUs? Thank you.
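A rough estimate shows why dropping the batch size to 1 does not help. In full fine-tuning with Adam and mixed precision, the memory for model states grows with the parameter count, not the batch size. This sketch assumes the common accounting of 16 bytes per parameter (fp16 weights and gradients plus fp32 master weights and two Adam moments). It uses 2 billion as an illustrative parameter count and ignores activations entirely.

```python
def full_finetune_state_bytes(num_params: int) -> int:
    """Approximate bytes for model states in mixed-precision Adam fine-tuning.

    Assumed per-parameter accounting: 2 (fp16 weights) + 2 (fp16 grads)
    + 4 (fp32 master weights) + 4 (Adam momentum) + 4 (Adam variance).
    Activations, temporary buffers and allocator fragmentation are NOT included.
    """
    bytes_per_param = 2 + 2 + 4 + 4 + 4  # = 16
    return num_params * bytes_per_param


if __name__ == "__main__":
    # Illustrative parameter count (assumption): 2 billion, matching the "2B" label.
    n = 2_000_000_000
    gb = full_finetune_state_bytes(n) / 1e9
    print(f"{gb:.1f} GB of model states vs. 24 GB on the GPU")
```

Under these assumptions the model states alone come to 32 GB, which exceeds a 24 GB card before a single activation is stored.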

Nan

The models were pre-trained in JAX on TPU-v4 hardware and later converted to PyTorch for sampling.

The training code in JAX will be released soon.

You may try to fine-tune the models in PyTorch using DeepSpeed:

https://news.ycombinator.com/item?id=32331764
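To show what DeepSpeed's ZeRO optimizer offers here, the following is a minimal sketch of the per-GPU memory formulas for model states given in the ZeRO paper. It assumes mixed-precision Adam: 2 bytes of fp16 parameters, 2 bytes of fp16 gradients and K = 12 bytes of fp32 optimizer state per parameter. The 2-billion parameter count and the 8-GPU comparison are illustrative assumptions. With a single GPU, every stage still needs 16 bytes per parameter, so sharding alone cannot fit the model on one 24 GB card. On one card, offloading to CPU memory is what closes the gap.

```python
def zero_model_state_bytes(num_params: int, num_gpus: int, stage: int) -> float:
    """Per-GPU bytes for model states under ZeRO with mixed-precision Adam.

    Accounting follows the ZeRO paper: 2 bytes fp16 params, 2 bytes fp16 grads,
    K = 12 bytes fp32 optimizer state per parameter. Activations are excluded.
    """
    p, k = num_params, 12
    if stage == 0:  # plain data parallelism: every GPU holds everything
        return (2 + 2 + k) * p
    if stage == 1:  # optimizer states sharded across GPUs
        return (2 + 2) * p + k * p / num_gpus
    if stage == 2:  # optimizer states and gradients sharded
        return 2 * p + (2 + k) * p / num_gpus
    if stage == 3:  # optimizer states, gradients and parameters sharded
        return (2 + 2 + k) * p / num_gpus
    raise ValueError(f"unknown ZeRO stage: {stage}")


# Illustrative comparison (assumed sizes): 2B parameters on 1 GPU vs. 8 GPUs.
for stage in range(4):
    one = zero_model_state_bytes(2_000_000_000, 1, stage) / 1e9
    eight = zero_model_state_bytes(2_000_000_000, 8, stage) / 1e9
    print(f"stage {stage}: {one:.1f} GB on 1 GPU, {eight:.1f} GB/GPU on 8 GPUs")
```

On one GPU all four stages report 32.0 GB. On 8 GPUs, stage 3 drops to 4.0 GB per GPU, because the 16 bytes per parameter are split evenly across devices.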

Training code in JAX has been released: #16 (comment)

@jiang719 Here is DeepSpeed fine-tuning code with CPU parameter offloading, so that you should be able to avoid OOM:

https://github.com/salesforce/jaxformer/blob/main/jaxformer/hf/train.py
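For reference, a DeepSpeed configuration that enables CPU offloading can be expressed as a plain dict or JSON file. The sketch below uses ZeRO stage 3 with both optimizer state and parameters offloaded to CPU, plus fp16 mixed precision. The batch sizes and learning rate are illustrative assumptions, and the linked train.py may use a different stage or different settings.

```python
import json


def zero3_cpu_offload_config(micro_batch: int = 1, grad_accum: int = 8,
                             lr: float = 1e-5) -> dict:
    """Build a DeepSpeed config: ZeRO stage 3, CPU offload, fp16.

    Numeric values are illustrative assumptions, not the settings of
    jaxformer's train.py.
    """
    return {
        "train_micro_batch_size_per_gpu": micro_batch,
        "gradient_accumulation_steps": grad_accum,
        "fp16": {"enabled": True},
        "optimizer": {"type": "AdamW", "params": {"lr": lr}},
        "zero_optimization": {
            "stage": 3,
            # Keep optimizer states (and run the optimizer step) in CPU memory.
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
            # Keep parameters in CPU memory; fetch them to the GPU when needed.
            "offload_param": {"device": "cpu", "pin_memory": True},
        },
    }


if __name__ == "__main__":
    print(json.dumps(zero3_cpu_offload_config(), indent=2))
```

In recent DeepSpeed versions, a training script can pass this dict (or the saved JSON file) as the `config` argument of `deepspeed.initialize`. Offloading trades GPU memory for PCIe traffic and CPU RAM, so each step is slower, but a single 24 GB card no longer has to hold the full optimizer state.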