arcee-ai/DALM

Large values of causal loss


Hi! Thank you for the great work.
Could you please share what the expected value of marginalize_casual_loss is at the beginning and at the end of training?
I am getting values around 70-100, and I am not sure if I should be alarmed by this.

Seems like something is wrong! What was the batch size?

I found an error; now the loss is around 10 at the start of training. I had made some modifications to your code, so the mistake was mine. When computing the loss, I used the sum over tokens instead of the mean, following the original training code in the Hugging Face RAG repository. With the sum, a large loss is expected, because its magnitude grows with the number of tokens. So what is the reasoning behind using the sum there? Did I misunderstand something?
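To illustrate how the token reduction changes the scale of the loss, here is a minimal pure-Python sketch of a RAG-sequence-style loss that marginalizes over retrieved documents. It is not DALM's `marginalize_casual_loss` or the Hugging Face implementation; `marginalized_nll`, `logsumexp`, and the `reduction` switch are hypothetical names. Note that the "mean" variant length-normalizes each document's log-likelihood before marginalizing, so it is a normalized score rather than an exact marginal likelihood.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginalized_nll(token_logprobs_per_doc, doc_logprobs, reduction="mean"):
    """Hypothetical sketch of a RAG-sequence-style loss for one example.

    token_logprobs_per_doc: one list per retrieved document, holding
        log p(y_t | x, z, y_<t) for each target token t.
    doc_logprobs: log p(z | x) for each retrieved document (retriever scores).
    reduction: "sum" adds token log-probs, so the loss grows with target
        length; "mean" averages them, so it is roughly length-independent.
    """
    seq_scores = []
    for tok_lps, doc_lp in zip(token_logprobs_per_doc, doc_logprobs):
        seq_lp = sum(tok_lps)
        if reduction == "mean":
            seq_lp /= len(tok_lps)
        # log p(z|x) + log p(y|x,z) for this document
        seq_scores.append(doc_lp + seq_lp)
    # Marginalize over documents: -log sum_z p(z|x) * p(y|x,z)
    return -logsumexp(seq_scores)


# Two equally likely documents, each assigning probability 0.5 to every
# one of 100 target tokens.
docs = [[math.log(0.5)] * 100, [math.log(0.5)] * 100]
doc_lps = [math.log(0.5), math.log(0.5)]
loss_sum = marginalized_nll(docs, doc_lps, reduction="sum")
loss_mean = marginalized_nll(docs, doc_lps, reduction="mean")
```

With these toy numbers the summed loss is 100·ln 2 ≈ 69.3 while the mean-reduced loss is ln 2 ≈ 0.69: the same per-token quality gives a value two orders of magnitude larger once it is summed over a 100-token target.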

By the way, as far as I am aware, your code does not allow for training only the query encoder and not the context encoder, as in the original RAG paper. Are there any plans to add this feature in the future?

@yashkens Yeah, the sum will give you large values and can mess up training.

In this repo, we didn't use a dual encoder for the retriever, so there is no separate context encoder to freeze.
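For comparison, the RAG paper keeps the document encoder (and its index) fixed and fine-tunes only the query encoder and the generator. For a retriever that does have two towers, that setup can be sketched in PyTorch as below. This is not part of DALM: `DualEncoderRetriever` and `freeze_context_encoder` are hypothetical, with `nn.Linear` layers standing in for real encoders, and the in-batch contrastive loss is only a toy objective.

```python
import torch
from torch import nn


class DualEncoderRetriever(nn.Module):
    """Hypothetical dual-encoder retriever with separate query and context towers.

    nn.Linear layers stand in for real encoders (e.g. two BERT models).
    """

    def __init__(self, dim: int = 8):
        super().__init__()
        self.query_encoder = nn.Linear(dim, dim)
        self.context_encoder = nn.Linear(dim, dim)

    def forward(self, queries, contexts):
        q = self.query_encoder(queries)
        c = self.context_encoder(contexts)
        # Inner-product relevance scores, shape (num_queries, num_contexts)
        return q @ c.T


def freeze_context_encoder(model: DualEncoderRetriever) -> None:
    """Disable gradient updates for the context tower only."""
    for p in model.context_encoder.parameters():
        p.requires_grad_(False)


torch.manual_seed(0)
model = DualEncoderRetriever()
freeze_context_encoder(model)

# Hand only the trainable (query-encoder) parameters to the optimizer.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)

queries = torch.randn(4, 8)
contexts = torch.randn(4, 8)
ctx_before = model.context_encoder.weight.detach().clone()
q_before = model.query_encoder.weight.detach().clone()

# Toy in-batch contrastive loss: query i should match context i.
scores = model(queries, contexts)
loss = nn.functional.cross_entropy(scores, torch.arange(4))
loss.backward()
optimizer.step()
```

After one step the query encoder's weights have moved while the context encoder's are unchanged. In a real setup the context embeddings would typically be computed once into an index, since the frozen tower never changes.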