UKPLab/sentence-transformers

Inefficient loss calculation in cached losses

Marcel256 opened this issue · 2 comments

Hello, during my tests I discovered a problem with my PR that enables combining the Matryoshka loss with the cached losses (#3068). I moved the loss.backward() call out of the mini-batch loop. With backward() inside the loop, each mini-batch's graph is freed before the next one is built. With it outside, the function produces one big computation graph containing the loss calculations of all mini-batches, which defeats the purpose of doing this computation in mini-batches. This drastically increases memory consumption and can easily lead to out-of-memory errors. I am already working on a fix for this issue.
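As a minimal sketch of how the Matryoshka combination and the per-mini-batch backward() can coexist, the losses for all truncated dimensions of one mini-batch can be summed before a single backward() call inside the loop. This is not the sentence-transformers implementation, and the source does not describe how the fix is structured. The function name, the in-batch-negatives cross-entropy loss, the scale of 20, and the dimensions/weights are assumptions for illustration. Gradients accumulate in the `.grad` of the detached, cached embeddings, and each mini-batch's graph is released after its backward() call.

```python
import torch
import torch.nn.functional as F


def matryoshka_cached_grads(anchors, candidates, dims, weights, mini_batch_size, scale=20.0):
    """Hypothetical helper: gradients of a weighted multi-dimension loss with
    respect to cached embeddings, calling backward() once per mini-batch."""
    # Detached leaf copies stand in for the cached (no-grad) embeddings.
    anchors = anchors.detach().requires_grad_()
    candidates = candidates.detach().requires_grad_()
    batch_size = anchors.size(0)
    labels = torch.arange(batch_size)  # anchor i matches candidate i
    total_loss = 0.0
    for start in range(0, batch_size, mini_batch_size):
        end = min(start + mini_batch_size, batch_size)
        loss_mb = torch.zeros(())
        for dim, weight in zip(dims, weights):
            # Truncate to the first `dim` dimensions, then re-normalize (cosine similarity).
            a = F.normalize(anchors[start:end, :dim], dim=-1)
            c = F.normalize(candidates[:, :dim], dim=-1)
            scores = a @ c.T * scale
            # Weight by mini-batch size so the mini-batch losses sum to the full-batch mean.
            ce = F.cross_entropy(scores, labels[start:end])
            loss_mb = loss_mb + weight * ce * (end - start) / batch_size
        # Backward inside the loop: the graph for this mini-batch (all dims) is freed here.
        loss_mb.backward()
        total_loss += loss_mb.item()
    return total_loss, anchors.grad, candidates.grad


torch.manual_seed(0)
emb_a, emb_c = torch.randn(8, 16), torch.randn(8, 16)
dims, weights = [16, 8, 4], [1.0, 0.5, 0.25]
loss, grad_a, grad_c = matryoshka_cached_grads(emb_a, emb_c, dims, weights, mini_batch_size=3)
```

Because the gradient of a sum is the sum of the gradients, summing the per-dimension losses inside each mini-batch gives the same gradients as one full-batch backward(), while only one mini-batch graph is alive at a time.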

I am sorry for any inconvenience this will cause.

Well spotted, thank you @Marcel256. I didn't notice this when reviewing; I thought the section removed from calculate_loss_and_cache_gradients was identical to the one in calculate_loss. Perhaps we can add a parameter to calculate_loss that controls whether backward is called per mini-batch, although I'll gladly await your fix.
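The suggested parameter could look roughly like the following sketch. The function name `calculate_loss_sketch`, the `backward_per_minibatch` flag, and the simplified in-batch-negatives loss are hypothetical stand-ins, not the library's actual `calculate_loss`. With the flag set, each mini-batch's graph is freed right after its backward() call; without it, every mini-batch loss is stored together with its graph until a single backward() at the end, which is the memory problem described in the issue. Both settings produce the same gradients.

```python
import torch
import torch.nn.functional as F


def calculate_loss_sketch(anchors, candidates, mini_batch_size, scale=20.0, backward_per_minibatch=False):
    """Hypothetical helper: in-batch-negatives loss computed over mini-batches of anchors.

    `anchors` and `candidates` are expected to be leaf tensors with requires_grad=True
    (e.g. detached, cached embeddings).
    """
    batch_size = anchors.size(0)
    labels = torch.arange(batch_size)  # anchor i matches candidate i
    losses = []
    for start in range(0, batch_size, mini_batch_size):
        end = min(start + mini_batch_size, batch_size)
        scores = anchors[start:end] @ candidates.T * scale
        loss_mb = F.cross_entropy(scores, labels[start:end]) * (end - start) / batch_size
        if backward_per_minibatch:
            # Free this mini-batch's graph now; gradients accumulate in .grad.
            loss_mb.backward()
            loss_mb = loss_mb.detach()
        # Without the flag, each stored loss keeps its whole graph alive.
        losses.append(loss_mb)
    return torch.stack(losses).sum()


torch.manual_seed(0)
a = F.normalize(torch.randn(8, 4), dim=-1)
c = F.normalize(torch.randn(8, 4), dim=-1)

# Per-mini-batch backward: the returned loss carries no graph.
a1, c1 = a.clone().requires_grad_(), c.clone().requires_grad_()
loss1 = calculate_loss_sketch(a1, c1, mini_batch_size=3, backward_per_minibatch=True)

# Deferred backward: one large graph spanning all mini-batches.
a2, c2 = a.clone().requires_grad_(), c.clone().requires_grad_()
loss2 = calculate_loss_sketch(a2, c2, mini_batch_size=3)
loss2.backward()
```

Defaulting the flag to False keeps the function usable by a wrapper that wants to combine several losses before calling backward() itself, at the cost of the larger graph.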

Also, don't stress about it: the faulty commit from #3068 hasn't been included in any release yet, so nothing bad has happened. In my time with this project, I've probably introduced 20 bugs like this 😋

  • Tom Aarsen

I created a first draft PR: #3114.
I will also run some more tests today with bigger batch sizes.