trapoom555/Language-Model-STS-CFT

Efficient Loss Calculation with `all_gather` to Achieve a Larger Effective Batch Size


With this approach, the effective batch size scales linearly with the number of GPUs (see the sketch after the list below).

  • Reduced training time for 1k steps from 1:30 hrs to 15 mins
  • 4× larger effective batch size (scaling with the number of GPUs in the system)
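
Below is a minimal sketch of the general technique this issue refers to: gathering embeddings from all ranks with `torch.distributed.all_gather` so the contrastive loss sees the global batch, while keeping gradients for the local shard. Function names (`gather_with_grad`, `info_nce_loss`) and the temperature value are illustrative assumptions, not the repository's exact implementation.

```python
# Sketch only, assuming a distributed PyTorch training setup with an
# InfoNCE-style contrastive loss; not the repository's exact code.
import torch
import torch.distributed as dist
import torch.nn.functional as F


def gather_with_grad(local_emb: torch.Tensor) -> torch.Tensor:
    """All-gather embeddings from every rank, keeping gradients for the local shard."""
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(local_emb) for _ in range(world_size)]
    dist.all_gather(gathered, local_emb)      # all_gather itself does not propagate gradients
    gathered[dist.get_rank()] = local_emb     # re-insert the local tensor so its gradient survives
    return torch.cat(gathered, dim=0)


def info_nce_loss(query_emb: torch.Tensor,
                  positive_emb: torch.Tensor,
                  temperature: float = 0.05) -> torch.Tensor:
    """InfoNCE over the gathered global batch; in-batch negatives come from all GPUs."""
    q = F.normalize(gather_with_grad(query_emb), dim=-1)
    p = F.normalize(gather_with_grad(positive_emb), dim=-1)
    logits = q @ p.T / temperature            # (global_batch, global_batch) similarity matrix
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, labels)
```

Because each rank only encodes its local slice but computes the loss against embeddings from every GPU, the set of in-batch negatives grows with the world size, which is what lets the effective batch size scale linearly with the number of GPUs.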