kohjingyu/fromage

The cross entropy loss in training stage

Ziyang412 opened this issue · 2 comments

Hi, thank you for the interesting work.

From the teaser image (https://github.com/kohjingyu/fromage/blob/main/teaser.png), I noticed two cross-entropy (CE) losses: one for image captioning and one for retrieval. However, the total loss in the original paper only includes the CE loss for image captioning. Is the second CE loss also applied during training, or am I missing something?

Thank you in advance!

Thanks for pointing this out! The second CE loss is indeed also applied during training, which allows the model to learn the proper [RET] embeddings. The retrieval CE loss is applied at lines 453-454 here:

fromage/main.py, lines 453 to 454 (commit 92c6d6f):

```python
elif model_mode == 'retrieval':
  ce_loss = ce_loss * args.ret_loss_scale
```
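To make the answer concrete, here is a minimal pure-Python sketch (no PyTorch, chosen only so it runs standalone) of how per-mode CE losses can be scaled and summed into one training loss, using the same `if`/`elif` shape as the snippet above. The helper functions, the `outputs` dictionary, and the `cap_loss_scale` name are hypothetical illustrations, not fromage's actual code. The sketch covers only the scaled CE terms and does not reproduce anything else `main.py` may add to the loss.

```python
import math


def cross_entropy(logits, target):
    """CE for one prediction: -log softmax(logits)[target]."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def mean_ce(batch_logits, targets):
    """Average token-level CE over a sequence or batch."""
    losses = [cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


def total_loss(outputs, cap_loss_scale=1.0, ret_loss_scale=1.0):
    """Sum per-mode CE losses, each multiplied by its own scale.

    `outputs` maps a mode name ('captioning' or 'retrieval') to a
    (logits, targets) pair. `ret_loss_scale` mirrors args.ret_loss_scale;
    `cap_loss_scale` is an assumed counterpart for the captioning branch.
    """
    loss = 0.0
    for model_mode, (logits, targets) in outputs.items():
        ce_loss = mean_ce(logits, targets)
        if model_mode == 'captioning':
            ce_loss = ce_loss * cap_loss_scale
        elif model_mode == 'retrieval':
            ce_loss = ce_loss * ret_loss_scale
        else:
            raise NotImplementedError(model_mode)
        loss += ce_loss
    return loss


# Uniform two-class logits give CE = log 2 for either target.
uniform = [[0.0, 0.0]]
outputs = {'captioning': (uniform, [0]), 'retrieval': (uniform, [1])}
print(total_loss(outputs, cap_loss_scale=1.0, ret_loss_scale=0.5))  # equals 1.5 * math.log(2)
```

Because each mode's loss is multiplied by its own scale before being summed, setting `ret_loss_scale` to 0 would remove the retrieval term from the gradient entirely, while larger values weight it more heavily relative to captioning.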

Gotcha! Thanks!