kohjingyu/fromage

How are the inputs arranged for in-context retrieval evaluation?

ys-zong opened this issue · 5 comments

Hi, thanks for the nice work and code! I wonder how the input tokens are arranged for the in-context retrieval evaluation. E.g., for the "5 captions, 4 images" setting in Table 1, is the input to the model something like ([IMG][Caption])x4+[Caption], or [IMG]x4+[Caption]x5, or something else? I guess it's the former?
Also, how are the in-context prompts selected: randomly, or as the pairs whose semantic features are most similar to the query caption? It would be great if you could provide a code snippet for this. Many thanks!

Hi, thanks for your interest! For Table 1, the input is: ([Caption][IMG])x4+[Caption]+[RET]. We want to format the in-context examples the same way as the retrieval query, so we use [Caption][IMG] throughout.
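For concreteness, here is a minimal sketch of how such a prompt could be assembled as an interleaved list of captions and PIL images. The variable names and placeholder file paths are illustrative only; the exact model call (and how [RET] is appended) is shown in the notebook linked below.

from PIL import Image

# Illustrative sketch (not the exact FROMAGe API): arrange the in-context
# examples as ([Caption][IMG]) x 4, followed by the query caption. The [RET]
# token is appended after the final caption for the retrieval step.
captions = ["cap 1", "cap 2", "cap 3", "cap 4", "cap 5"]        # 5 story captions
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg", "img4.jpg"]  # first 4 images

prompt = []
for caption, path in zip(captions[:4], image_paths):
    prompt.append(caption)           # [Caption]
    prompt.append(Image.open(path))  # [IMG]
prompt.append(captions[4])           # query caption; [RET] is appended after it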

I've uploaded the code for reproducing Table 1 here: https://github.com/kohjingyu/fromage/blob/main/evals/VIST_Contextual_Image_Retrieval.ipynb (there's also a .py script in the same folder, which does the same thing).

Note that there was a bug in a previous version of the code, so the results this script produces (R@1 of 18.2) are actually better than the ones reported in the current version of our paper (R@1 of 15.6). This script also shows the correct way to call the model for retrieval; we'll update the paper shortly to reflect this.

Hope that helps!

Got it. Thank you very much!

Hi, just a quick follow-up question. The weights of ret_input_embeddings in this line don't seem to be in model.state_dict() when saving the model. I wonder if I should manually extract the ret_input_embeddings from the embedding layer with the ret token index?

Ah yes, sorry about that. We pruned the checkpoint to remove the pretrained weights/embeddings so that it was small enough to upload to GitHub :)

If you are using a model saved with L379 in main.py, you should be able to load the checkpoint by removing these two lines:

fromage/fromage/models.py, lines 678 to 679 (in 7d735cd):

with torch.no_grad():
    model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())

Simply put, if you did not prune the checkpoint, torch.load should be sufficient to load it:

checkpoint = torch.load(model_ckpt_path)
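From there, pushing the weights into an already-constructed model might look like the following (a hedged sketch: the 'state_dict' key and the use of strict=False are assumptions based on the loading code quoted above, and model / model_ckpt_path are placeholders):

import torch

# Hedged sketch: load an unpruned checkpoint saved by main.py into an
# already-constructed FROMAGe model. strict=False tolerates extra or
# missing keys rather than raising an error.
checkpoint = torch.load(model_ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()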

I wonder if I should manually extract the ret_input_embeddings from the embedding layer with the ret token index?

You can also do this if you would like to prune the checkpoint. I've uploaded the script we used for this here (although it's untested). I think it would probably be easier if you just commented out the above lines for a custom checkpoint, though.
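For reference, the idea behind pruning is roughly the following (an untested, hedged sketch, not the actual script: the state_dict key names, the [RET] token index, the file paths, and the rule for which keys to keep are all assumptions you should check against your own checkpoint):

import torch

# Rough sketch of pruning a custom checkpoint: copy the [RET] row out of the
# input embedding matrix into a 'ret_input_embeddings.weight' entry (the key
# that models.py reads back, per the lines quoted above) and drop the large
# frozen pretrained weights to shrink the file.
ckpt = torch.load('ckpt_best.pth.tar', map_location='cpu')
state_dict = ckpt['state_dict']

ret_idx = 50266  # placeholder: use model.model.retrieval_token_idx from your model
ret_row = state_dict['model.input_embeddings.weight'][ret_idx, :].clone()  # assumed key name

# Which keys count as "pretrained" is an assumption; adjust the filter to match
# how your model was wrapped when it was saved.
pruned = {k: v for k, v in state_dict.items()
          if 'lm.' not in k and 'visual_model' not in k}
pruned['ret_input_embeddings.weight'] = ret_row

torch.save({'state_dict': pruned}, 'pruned_ckpt.pth.tar')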

Hope that helps!

Thank you for the quick reply!