CVMI-Lab/SlotCon

ViT training?


Hello,

Thank you for your very interesting work! I'm currently trying to replicate your results with your provided codebase, and I was wondering whether you also tested a Vision Transformer architecture as the encoder. You compared with DINO in the paper, but I wanted to know whether you were able to obtain properties close to theirs (a kind of saliency map, with the attention map focused on the object of interest).

Thank you again for your response!

Hi @alexcbb, thanks for your attention to our work!

We actually didn't experiment thoroughly with ViTs due to computational constraints. Regarding the object-centric attention maps of DINO, we believe they are a merit of Transformers, and for CNNs we need to find another path. Our method explores doing it via explicit clustering on top of CNN features, which indeed worked.
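To make "explicit clustering on top of CNN features" concrete, here is a minimal NumPy sketch of the general idea: each pixel's feature is softly assigned to a set of learnable prototypes via a softmax over cosine similarities, and each slot is then pooled as an assignment-weighted average of pixel features. This is an illustration under stated assumptions (cosine-similarity assignment, a temperature `tau`, the hypothetical helper `group_features`), not the repository's code.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    # Unit-normalize vectors along `axis`, guarding against zero norms.
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)

def group_features(feats, prototypes, tau=0.07):
    """Hypothetical helper (assumption, not SlotCon's API).
    feats: [P, D] dense features (P = H*W pixels); prototypes: [K, D].
    Returns soft assignments [K, P] and pooled slots [K, D]."""
    z = l2_normalize(feats)
    c = l2_normalize(prototypes)
    logits = c @ z.T / tau                          # [K, P] cosine similarity / temperature
    logits = logits - logits.max(axis=0, keepdims=True)
    assign = np.exp(logits)
    assign = assign / assign.sum(axis=0, keepdims=True)  # softmax over prototypes, per pixel
    # Each slot is the assignment-weighted mean of pixel features (+1e-6 avoids 0/0).
    slots = (assign @ feats) / (assign.sum(axis=1, keepdims=True) + 1e-6)
    return assign, slots
```

With two clearly separated groups of pixels and one prototype per group, each pixel goes almost entirely to its matching prototype and each slot recovers that group's feature.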
Besides that, we also tried to find similar visualizations within CNNs themselves, and we found that PCA on dense feature maps produced plausible results. Due to the hierarchical (downsampling) structure of CNNs, the resulting visualizations have relatively low resolution. We tried tricks like modifying the stride, which did not help much.
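A common way to turn a dense feature map into such a visualization is to project every spatial location onto the top three principal components and display them as RGB. The sketch below assumes a single `[C, H, W]` feature map and per-component min-max scaling; the function name `pca_rgb` is hypothetical and this is not necessarily how the authors produced their figures.

```python
import numpy as np

def pca_rgb(feature_map, n_components=3):
    """Hypothetical helper. feature_map: [C, H, W]; returns [H, W, n_components] in [0, 1]."""
    C, H, W = feature_map.shape
    x = feature_map.reshape(C, H * W).T            # [H*W, C], one row per spatial location
    x = x - x.mean(axis=0, keepdims=True)          # center each channel
    # Rows of vt are principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    proj = x @ vt[:n_components].T                 # [H*W, n_components]
    lo, hi = proj.min(axis=0), proj.max(axis=0)
    proj = (proj - lo) / np.maximum(hi - lo, 1e-12)  # min-max scale each component to [0, 1]
    return proj.reshape(H, W, n_components)
```

Because the output has the same `H x W` grid as the feature map, a stride-32 CNN stage yields a coarse image, which matches the low-resolution issue mentioned above.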
Hope that is helpful for you!

Thank you for your quick answer!

I would be very interested in exploring whether such training would benefit Vision Transformers (even a small one like ViT-S/16). I'm first checking whether I can reproduce your results with ResNet, and then I want to apply it to ViT. I think this could help extract object knowledge to some extent and provide a useful prior for training, especially on scene-centric datasets.

Could I ask for your help with this (mainly on replicating the results)?

Thanks again!

Feel free to leave a message if you run into any trouble.

The paper indicates that pre-training on COCO was performed on 8 NVIDIA 2080 Ti GPUs for 800 epochs. Do you have an estimate of how long such training takes, and possibly some information on memory consumption?

It should take up almost all the memory of the 8x 2080 Ti GPUs, roughly 80 GB in total. I don't remember the precise training time, maybe roughly 2 to 3 days?

I made some small changes to launch the training (I created a PyTorch Lightning module to ease deployment on clusters) and started an 800-epoch training run on COCO. Here is the current evolution of the loss (it is now at around epoch 230 after ~1 day of training). Does this look like a correct convergence curve?
[Screenshot from 2023-12-19 17-30-11: training loss curve]

It seems I'm not able to replicate your Figure 3 after pre-training, and I don't understand why: the prototypes look a bit weird and there is no mask on my final image.

Fig. 3 is simply produced using viz_slots.py with the default configs. The model is the default COCO model, trained for 800 epochs. Please check whether there are any errors in your reimplementation.

Hello, the problem was in my visualization script; I'm now able to obtain well-aligned concepts! I saw in the paper that one would need to scale the loss according to the batch size (when increasing it). Could you tell me more about this? (I trained my model using your default parameters and a batch size of 1536 without any major problem in the results.)

Hi, we scale the learning rate linearly with the batch size, as done in many previous works. This is already implemented in the code, so basically no further modification is needed on your side:

lr=args.batch_size * args.world_size / 256 * args.base_lr,
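The line above is the linear scaling rule: the learning rate grows in proportion to the global batch size, with 256 as the reference batch. A standalone sketch, assuming `batch_size_per_process` plays the role of `args.batch_size` (the per-process batch) and `world_size` the number of processes; the function name is hypothetical.

```python
def scaled_lr(batch_size_per_process, world_size, base_lr):
    """Hypothetical helper: linear scaling rule, lr = global_batch / 256 * base_lr."""
    global_batch = batch_size_per_process * world_size
    return global_batch / 256 * base_lr

# A global batch of 1536 on 8 processes (192 each) gives 6x the base learning rate.
print(scaled_lr(192, 8, 1.0))
```

Under this rule, raising the batch size from 256 to 1536 automatically multiplies the learning rate by 6, which is why no manual change is needed.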

Hello, I have a question about the slot loss. In Equation (5) you mask out the slots that do not occupy a dominant group, and then use the remaining slots to compute an InfoNCE loss. In the ctr_loss_filtered function of SlotCon, why do you use mask_intersection rather than mask_q to select the slots of q? Is it to avoid slots that have no positive pair in k, or is there another reason? Also, did you run any ablations on whether this masking helps training? Thank you in advance.

Your anticipation is correct, this is to make sure they form a positive pair, such that both the query and key slots exist across views. From my memory, we didn't ablate much on that.
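The reasoning can be illustrated with a simplified single-view-pair sketch in NumPy. It is not the repository's `ctr_loss_filtered`: it ignores multi-GPU gathering and the slot predictor, and the helper names are hypothetical. A slot counts as present in a view if it wins the argmax for at least one pixel; queries are restricted to the intersection so that every query's positive (the same slot index in the other view) is among the candidate keys. Selecting with `mask_q` alone could pick a query whose key was filtered out, leaving it without a positive label.

```python
import numpy as np

def dominant_slot_mask(assign):
    """assign: [K, P] soft assignment of P pixels to K slots for one view.
    A slot counts as present if it is the argmax for at least one pixel."""
    K = assign.shape[0]
    winners = assign.argmax(axis=0)                   # [P] winning slot per pixel
    return np.bincount(winners, minlength=K) > 0      # [K] bool

def filtered_slot_infonce(q, k, assign_q, assign_k, tau=0.2):
    """q, k: [K, D] slot embeddings of two views (slot i in q matches slot i in k)."""
    mask_q = dominant_slot_mask(assign_q)
    mask_k = dominant_slot_mask(assign_k)
    idx_q = np.flatnonzero(mask_q & mask_k)           # queries whose positive exists in k
    idx_k = np.flatnonzero(mask_k)                    # every present key is a candidate
    qn = q[idx_q] / np.linalg.norm(q[idx_q], axis=1, keepdims=True)
    kn = k[idx_k] / np.linalg.norm(k[idx_k], axis=1, keepdims=True)
    logits = qn @ kn.T / tau                          # [len(idx_q), len(idx_k)]
    # Column of each query's positive (same slot index) within the filtered keys.
    labels = np.cumsum(mask_k)[idx_q] - 1
    logits = logits - logits.max(axis=1, keepdims=True)   # numerically stable log-softmax
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(len(idx_q)), labels].mean())
```

Note that if the intersection is empty, the final mean is taken over zero elements and evaluates to NaN.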


OK, thank you again for your answer. Did you ever encounter the case where there is no positive pair across the views? While trying to train the model with a ViT backbone, I got a NaN loss, and the issue comes from the slot loss: at some point, the intersection mask is empty. If you encountered this during your experiments, it would help me a lot!

Well, actually I can't recall the details well... You may consider dropping that pair in this case.
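One way to "drop that pair" is to average the slot loss only over images whose two views share at least one slot, and to return zero when no image in the batch qualifies, instead of averaging over an empty set (which yields NaN). The sketch assumes per-image presence masks and per-image losses are available; the helper names are hypothetical. In a PyTorch port one would likely want the zero to stay attached to the autograd graph; that detail is omitted here.

```python
import numpy as np

def has_positive_pair(mask_q, mask_k):
    """mask_q, mask_k: [B, K] bool, True where slot k is present in that view.
    Returns [B] bool: True if the two views of an image share at least one slot."""
    return (mask_q & mask_k).any(axis=1)

def mean_over_valid(per_image_loss, valid):
    """Average the slot loss over images with at least one positive pair.
    If none has one, return 0.0 instead of NaN (the mean of an empty set)."""
    if not valid.any():
        return 0.0
    return float(per_image_loss[valid].mean())
```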

Hi, @alexcbb @xwen99

I have just run a small-scale experiment with ViT-S on COCO for 100 epochs.
The full settings are below.

The prototype visualization makes sense but looks weird. I am checking the code now.
I would appreciate any hints or suggestions.

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --master_port 12348 --nproc_per_node=8 \
    main_pretrain.py \
    --dataset COCO \
    --data-dir ${data_dir} \
    --output-dir ${output_dir} \
    \
    --arch vit_small \
    --dim-hidden 4096 \
    --dim-out 256 \
    --num-prototypes 256 \
    --teacher-momentum 0.99 \
    --teacher-temp 0.07 \
    --group-loss-weight 0.5 \
    \
    --batch-size 256 \
    --optimizer adamw \
    --base-lr 5e-4 \
    --weight-decay 0.04 \
    --warmup-epoch 5 \
    --epochs 100 \
    --fp16 \
    \
    --print-freq 10 \
    --save-freq 50 \
    --auto-resume \
    --num-workers 12

[Image: slotcon_vits_coco_100eps, prototype visualization]

Hello, on my side I was not able to make the training converge properly. The slot loss occasionally returns NaN because of the masking, and I don't know why this is happening. Could you tell me what changes you made to replace the backbone? My hyperparameters are the same as yours (I took the same hyperparameters as in DINO training). Your prototypes look coherent to me; what seems weird to you? I would gladly discuss your re-implementation with you if you agree!

Sure. You can send me an email. ([Update] - my email: kjliu@vision.is.tohoku.ac.jp)
In short, I tried to make minimal changes. Specifically:

  • Borrowed the ViT-S implementation from DINO (v1) and made its output a 4D torch.Tensor of shape [B, C, H, W].
    return x[:, 1:].transpose(-2, -1).reshape(-1, self.embed_dim, h, w)

  • Changed num_channels of SlotCon(nn.Module) to 384 (the ViT-S embedding dimension).
    self.num_channels = 384

  • Used an AdamW optimizer, as in the command above.
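The reshape in the first bullet drops the CLS token and lays the patch tokens back out on the patch grid, so the ViT output looks like a CNN feature map to the rest of SlotCon. A NumPy sketch of the same operation, assuming a single CLS token (no register tokens), patch tokens in row-major order, and image sides divisible by the patch size; `tokens_to_feature_map` is a hypothetical name.

```python
import numpy as np

def tokens_to_feature_map(x, img_h, img_w, patch_size=16):
    """x: [B, 1 + N, C] ViT output tokens (CLS first, then N patch tokens in row-major order).
    Returns a [B, C, h, w] map with h = img_h // patch_size, w = img_w // patch_size."""
    B, T, C = x.shape
    h, w = img_h // patch_size, img_w // patch_size
    assert T == 1 + h * w, "token count does not match the patch grid"
    patches = x[:, 1:]                                     # drop the CLS token -> [B, N, C]
    return patches.transpose(0, 2, 1).reshape(B, C, h, w)  # [B, C, N] -> [B, C, h, w]
```

For ViT-S/16 with 224x224 inputs this gives a 14x14 map with C = 384 channels, which is why num_channels is set to 384.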

Regarding the prototype visualization, I found some empty prototypes, and the 4th column from the right mixes "cat", "cow", and "bear", whereas the ResNet-50-based model produces a pure cat prototype. I would say the semantic consistency is lower in the ViT-S-based model. Again, I am not sure whether this behavior is expected.

@KJ-rc
OK, thank you. Could you provide your email? I'm not able to find it.
Regarding the DINO implementation, you trained it from scratch, right? I will try what you described; maybe there are issues in my code that I didn't see, and I'll let you know if I get the same artifacts as yours. But I'm pretty sure I did the same as you (very few changes).

But yes, you are right, I didn't notice the empty prototypes! This is quite weird indeed. Maybe some hyperparameter changes could improve it a bit (like the student/teacher temperatures?). What about training for longer (e.g., 300 to 400 epochs)? I will let you know as soon as I'm able to train the ViT from scratch on my side.

@KJ-rc Just to be sure: since DINO does not use BatchNorm in its projection head, did you also remove the BatchNorm layers (and consequently the SyncBatchNorm calls in SlotCon)?

Hi,
I only made the modifications listed above.
I consider the projector to depend more on the pre-training method than on the backbone architecture,
so I kept the BatchNorm layers.

It seems that after reducing the batch size to 256 and adding gradient clipping, the training now works. I'll see how it evolves and let you know my final results!
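Gradient clipping by global norm rescales all gradients together whenever their joint L2 norm exceeds a threshold, which limits the occasional huge updates that can precede a NaN. In PyTorch this is `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)`; with `--fp16` (AMP), PyTorch's AMP documentation recommends calling `scaler.unscale_(optimizer)` before clipping so the threshold applies to the true gradients. The NumPy sketch below only illustrates the computation; the threshold used in this thread is not stated, so any `max_norm` value is an assumption.

```python
import numpy as np

def clip_grad_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient arrays so their joint L2 norm is at most max_norm.
    Same idea as torch.nn.utils.clip_grad_norm_: one norm over all parameters."""
    total = float(np.sqrt(sum(float((g ** 2).sum()) for g in grads)))
    coef = min(1.0, max_norm / (total + eps))   # never scale gradients up
    return [g * coef for g in grads], total
```

Clipping by the global norm preserves the direction of the overall update and only shrinks its length, unlike clipping each element independently.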

Hello, it seems I have the same issue as @KJ-rc when training SlotCon with ViT, and I think it is caused by "dead slots" appearing during training, as discussed in Appendix D. I printed out 100 slots to check their semantics, and around 20 of them were dead slots without any meaning. This seems closely related to the discussion in Appendix D. What do you think?
[Image: res3]
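A quick way to quantify dead slots over a sample of images is to count prototypes that never win the per-pixel argmax. This operational definition is an assumption for diagnosis, not necessarily the criterion used in Appendix D, and `count_dead_prototypes` is a hypothetical helper.

```python
import numpy as np

def count_dead_prototypes(assign):
    """assign: [B, K, P] soft assignments of P pixels to K prototypes for B images.
    Here a prototype counts as dead if it never wins the argmax for any pixel."""
    B, K, P = assign.shape
    winners = assign.argmax(axis=1)                       # [B, P] winning prototype per pixel
    used = np.bincount(winners.ravel(), minlength=K) > 0  # [K] bool
    return int(K - used.sum()), np.flatnonzero(~used)     # count and indices of dead ones
```

Tracking this count across checkpoints would show whether prototypes die early in training or gradually.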

See if this paper helps you understand the dead slots: https://openreview.net/forum?id=Z2dVrgLpsF

See if this paper helps you understand the dead slots: https://openreview.net/forum?id=Z2dVrgLpsF

Thank you for sharing. It would be interesting to see whether the problem disappears with such regularization!