jzhang38/EasyContext

How to get the real whole-sequence training loss for a batch (reduction_mode=mean)?


In train.py, the loss returned on the main process is the loss of one sequence block, not the loss of the whole sequence. The gathered loss is computed as:

```python
gathered_loss = accelerator.reduce(loss.clone().detach(), "mean")
```

Is this the whole-sequence loss?
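A rough way to reason about it: `accelerator.reduce(x, "mean")` averages each process's value, so if every rank first takes the mean over its own sequence block, the result is a mean of per-rank means. That equals the true whole-sequence mean only when every rank holds the same number of loss-contributing tokens. The pure-Python sketch below simulates this with hypothetical per-token losses (no real training code or distributed setup is involved; the function names are illustrative only). In an actual run, one option under the assumption that your loss uses `torch.nn.functional.cross_entropy` with `ignore_index=-100` would be to compute the per-rank loss with `reduction="sum"`, count the non-ignored labels, and divide `accelerator.reduce(loss_sum, "sum")` by `accelerator.reduce(n_tokens, "sum")`. Whether train.py's blocks are equal-sized or contain ignored labels is something to check in the code, not something assumed here.

```python
from typing import List


def mean_of_rank_means(per_rank_token_losses: List[List[float]]) -> float:
    """Mimics reduce(loss, "mean") when each rank's loss is already a mean
    over its own sequence block: average the per-rank means."""
    rank_means = [sum(losses) / len(losses) for losses in per_rank_token_losses]
    return sum(rank_means) / len(rank_means)


def whole_sequence_mean(per_rank_token_losses: List[List[float]]) -> float:
    """Token-weighted mean over the whole sequence: sum every token's loss
    across all ranks, then divide by the total number of tokens."""
    total_loss = sum(sum(losses) for losses in per_rank_token_losses)
    total_tokens = sum(len(losses) for losses in per_rank_token_losses)
    return total_loss / total_tokens


# Hypothetical per-token losses, one inner list per rank (sequence block).
equal_blocks = [[1.0, 3.0], [2.0, 2.0]]      # same token count on each rank
unequal_blocks = [[1.0, 3.0, 2.0], [4.0]]    # e.g. some labels ignored on rank 1

print(mean_of_rank_means(equal_blocks), whole_sequence_mean(equal_blocks))
print(mean_of_rank_means(unequal_blocks), whole_sequence_mean(unequal_blocks))
```

With equal-sized blocks both functions print 2.0, so the "mean" reduction already gives the whole-sequence loss. With unequal blocks the mean of rank means is 3.0 while the token-weighted whole-sequence mean is 2.5, which is why summing losses and token counts separately is the safer general approach.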

Thank you very much!