Pints-AI/1.5-Pints

Is there any parameter to speed up pretraining?

Closed this issue · 1 comment

I am reproducing the pretrain/main.py stage on 8× H800 80G. Here is what I have tried (a minimal sketch of these settings follows the list):

  • --micro_batch_size 12
  • torch.set_float32_matmul_precision('medium')
  • make sure get_default_supported_precision returns bf16-mixed
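
For reference, here is a minimal sketch of how these three knobs are typically set. The exact wiring inside pretrain/main.py may differ; the re-implementation of get_default_supported_precision below is an assumption based on the lit-gpt-style helper of the same name, and the actual CLI flag handling is not shown.

```python
import torch

# 1. Trade float32 matmul precision for throughput (enables TF32/bf16 matmul paths
#    on Ampere/Hopper GPUs). The PyTorch default is "highest".
torch.set_float32_matmul_precision('medium')

# 2. Choose the training precision; bf16-mixed is preferred when the GPU supports bf16.
#    (Hypothetical re-implementation of the helper referenced above.)
def get_default_supported_precision(training: bool) -> str:
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return 'bf16-mixed' if training else 'bf16-true'
    return '16-mixed' if training else '16-true'

# 3. A larger micro batch size raises per-GPU memory use but improves throughput;
#    this mirrors the --micro_batch_size 12 mentioned above.
micro_batch_size = 12
```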

With these settings, memory usage is 70 GB and the run takes 3.5 days.

Could you please give any more advice on speeding up pretraining?

@tpoisonooo Yes, that would be it. Almost all the optimisations are already baked in. If you can afford it, don't use torch.set_float32_matmul_precision('medium'); I believe the default is 'highest'.
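
If it helps, a quick way to confirm what the process is currently using (a minimal sketch; this only inspects and resets the global setting, it does not change anything else in the training script):

```python
import torch

# Print the current float32 matmul precision; in a fresh process this is "highest"
# unless something has already changed it.
print(torch.get_float32_matmul_precision())

# Revert to the default: full-precision float32 matmuls, slower but numerically
# safer than the "medium"/"high" reduced-precision paths.
torch.set_float32_matmul_precision('highest')
```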