kohjingyu/fromage

The reproduction of FROMAGe training

Ziyang412 opened this issue · 6 comments

Hi! I am trying to reproduce the training of the FROMAGe model on the CC3M dataset, and the final CC3M validation output looks normal:

Computing similarity between torch.Size([12856, 256]) and torch.Size([12856, 256]).
 * Time 9.645 Loss 3.445 Acc@1 46.411 Acc@5 69.867 BLEU@4 0.063
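For context on the log above: Acc@1 and Acc@5 appear to be retrieval accuracies computed from the similarity matrix between the two 12856 × 256 embedding matrices. The sketch below is a generic way to compute such numbers, not the FROMAGe implementation. It assumes that row i of both matrices belongs to the same image-caption pair (so the correct match sits on the diagonal) and uses cosine similarity; the function name `topk_retrieval_accuracy` is hypothetical.

```python
import torch


def topk_retrieval_accuracy(query: torch.Tensor, keys: torch.Tensor, ks=(1, 5)) -> dict:
    """Percentage of queries whose paired key (same row index) is among the top-k most similar keys.

    Hypothetical helper: assumes query[i] and keys[i] form a matching pair.
    max(ks) must not exceed the number of keys.
    """
    # L2-normalise so the dot product equals cosine similarity.
    query = torch.nn.functional.normalize(query, dim=-1)
    keys = torch.nn.functional.normalize(keys, dim=-1)
    sim = query @ keys.t()  # (N, N) similarity matrix
    targets = torch.arange(sim.shape[0])  # correct key index for each query
    topk_idx = sim.topk(max(ks), dim=1).indices  # most similar keys first
    hits = topk_idx == targets.unsqueeze(1)  # True where the paired key was retrieved
    return {k: hits[:, :k].any(dim=1).float().mean().item() * 100 for k in ks}
```

For example, with `q = torch.eye(4)`, `topk_retrieval_accuracy(q, q, ks=(1, 2))` gives 100.0 for both k, while swapping the first two rows of the keys drops Acc@1 to 50.0.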

When I tried to evaluate on the VisDial dataset (after pruning the checkpoint), I got the error below:
[Screenshot: error_fromage]
I printed the dimensions of these tensors, and it seems that the saved "ret_input_embeddings.weight" has shape [1, 4096] instead of [4096].

To work around this, I changed the following line:

model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())

to the line below (adding a squeeze(0)):

model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].squeeze(0).cpu().detach())
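To see why the squeeze is needed: indexing a 2-D embedding matrix with a single integer row index returns a 1-D view, so the tensor copied into it should also be 1-D. The toy sketch below uses 8 tokens and an embedding size of 4 instead of 4096; `embeddings`, `ret_idx` and `saved_row` are hypothetical stand-ins for the model's input embeddings, `retrieval_token_idx`, and the checkpoint entry.

```python
import torch

vocab_size, dim = 8, 4  # toy sizes; the embedding size in the issue is 4096
ret_idx = 7  # hypothetical stand-in for model.model.retrieval_token_idx
embeddings = torch.nn.Embedding(vocab_size, dim)

# Checkpoint entry saved with a leading singleton dimension, as reported in the issue.
saved_row = torch.randn(1, dim)

print(embeddings.weight[ret_idx, :].shape)  # torch.Size([4]): integer indexing drops the row dim
print(saved_row.shape)  # torch.Size([1, 4])

with torch.no_grad():
    # squeeze(0) turns [1, dim] into [dim] so it matches the target row.
    # (Per the issue, copying the [1, dim] tensor directly raised an error.)
    embeddings.weight[ret_idx, :].copy_(saved_row.squeeze(0).cpu().detach())

assert torch.equal(embeddings.weight[ret_idx], saved_row[0])
```

Because `squeeze(0)` only removes dimension 0 when its size is 1, the same line also works unchanged for a checkpoint that already stores a [4096]-shaped tensor.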

I evaluated on IT2T and got reasonable results; however, in the T2I setting the results are even worse than random guessing:
top-k, k=1, acc=0.00000 top-k, k=5, acc=0.00097 top-k, k=10, acc=0.00242

Could you help me with this? Thank you so much!

Hi,

That's strange! It suggests that the rest of the model is loaded correctly, except for the ret_embeddings weight (which computes the text embeddings used in retrieval). Could you check two things?

  1. If you load the pretrained model in this repo to run the VisDial evals, do the numbers look reasonable?
  2. If you use the saved checkpoint without pruning, does it work? You can probably do this by commenting out the following two lines:

fromage/fromage/models.py, lines 678 to 679 (commit 92c6d6f):

with torch.no_grad():
    model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach())

I found several differences between the model.args in my run's log directory and the one in the /fromage_model directory. I don't know whether this affects loading the weights for the [RET] token. If so, is there any way I can fix it?

Thank you for the reply.

For 1, yes, I can reproduce reasonable results using the pretrained model.

For 2, no: I tried it, but got the same results as with the pruned checkpoint.

I hope this can be fixed soon. Thank you in advance!

BTW, here is the training script I used, in case it helps. The only change I made to main.py is the number of GPUs (https://github.com/kohjingyu/fromage/blob/92c6d6f6ea9cea38f0b0a12bcdb0cf3915d0e774/main.py#L190C1-L190C1): I set ngpus_per_node to 4.

export NCCL_P2P_DISABLE=1
randport=$(shuf -i8000-9999 -n1)  # Generate a random port number
python -u main.py \
    --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
    --multiprocessing-distributed --world-size 1 --rank 0 \
    --dataset=cc3m --val-dataset=cc3m \
    --opt-version='facebook/opt-6.7b' --visual-model='openai/clip-vit-large-patch14' \
    --exp_name='fromage_train_exp_6.29_bs120_lr2_valbs80' --image-dir='/data/cc3m_dl/conceptual_caption/' --log-base-dir='runs/' \
    --batch-size=120 --val-batch-size=80 --learning-rate=0.0002 --precision='bf16' --print-freq=100

Thanks for sharing that! I managed to reproduce the problem. It happens because, when training with DDP, every key in the state_dict carries a "module." prefix. Those names don't match the model's parameter names, and since we set strict=False (so that the pretrained OPT/CLIP weights can be loaded separately), the mismatched keys were silently skipped and the trained weights were never restored. I've just pushed a commit that updates the fromage/prune_model_ckpt.py script to remove the prefix:

stripped_state_dict = {
k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items() if
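The failure mode described above can be reproduced in a few lines. The sketch below uses a toy `nn.Linear`, not the FROMAGe model, and simulates a DDP-style checkpoint by prefixing every key with "module.". With `strict=False`, `load_state_dict` does not raise; it returns the mismatched names in `missing_keys` and `unexpected_keys` and loads nothing. The helper `strip_prefix` is hypothetical and removes "module." only at the start of a key.

```python
from collections import OrderedDict

import torch
from torch import nn


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: drop `prefix` from the start of each key, leaving other keys unchanged."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


torch.manual_seed(0)
trained = nn.Linear(3, 2)
# Simulate a checkpoint saved from a DDP-wrapped model: every key gains a "module." prefix.
ddp_state_dict = {"module." + k: v for k, v in trained.state_dict().items()}

fresh = nn.Linear(3, 2)
# strict=False does not raise; it reports the mismatch and leaves `fresh` untouched.
first = fresh.load_state_dict(ddp_state_dict, strict=False)
print(first.missing_keys)  # 'weight' and 'bias': nothing was restored
print(first.unexpected_keys)  # 'module.weight' and 'module.bias' were ignored

# After stripping the prefix, the names match and the weights are actually loaded.
second = fresh.load_state_dict(strip_prefix(ddp_state_dict), strict=False)
print(second.missing_keys, second.unexpected_keys)  # both empty
assert torch.equal(fresh.weight, trained.weight)
```

A small design note: `str.replace` removes every occurrence of "module.", so a key such as "module.a.submodule.b" would become "a.subb", whereas stripping only the leading prefix gives "a.submodule.b".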

If you rerun the script and the evals, I think it should work as expected now, but please let me know if it doesn't!

Yeah, it works! Thank you so much!