KVishnuVardhanR/M3T

Token selection of the classification head

Closed this issue · 2 comments

Hello,

Thank you so much for your excellent work on implementing the M3T paper. I've found your implementation extremely helpful in deepening my understanding of the paper.

While reviewing your code, I noticed that the classification head uses the "mean" operator to average across all tokens, and then uses this "mean token" for classification. However, the original ViT paper uses only the "cls_token" for classification. Could you kindly shed some light on the rationale behind choosing a different approach?

I greatly appreciate your time and look forward to your response. Thank you!

Hi,

Thanks for pointing it out. I have updated the code and the README to follow the original ViT paper. But as I understand the https://github.com/FrancescoSaverioZuppichini/ViT implementation, the purpose of using this 'mean token' for classification, instead of relying solely on the cls_token, is that it takes the mean of all the encoded tokens' embeddings. By averaging across all tokens, the model potentially uses a more holistic representation of the image, since each token carries information about a different part of the image.

But PyTorch torchvision's implementation (https://github.com/pytorch/vision/blob/main/torchvision/models/vision_transformer.py#L303) follows the ViT paper and uses the cls_token for classification.
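A minimal sketch of the two pooling choices, written in plain Python instead of PyTorch so it runs without dependencies. It assumes the cls_token sits at index 0 of the encoder output (as in the ViT paper) and that the mean is taken over every token, cls_token included. On a `(batch, tokens, dim)` tensor this roughly corresponds to `x[:, 0]` versus `x.mean(dim=1)`. `pool_tokens` is a hypothetical helper for illustration, not a function from M3T or either linked repository.

```python
from typing import List

Token = List[float]  # one embedding vector


def pool_tokens(tokens: List[Token], mode: str = "cls") -> Token:
    """Reduce the encoder's output tokens to a single vector for the classifier.

    Assumes tokens[0] is the prepended cls_token and tokens[1:] are the
    image-patch tokens, as in the ViT paper.
    """
    if mode == "cls":
        # ViT paper / torchvision style: keep only the cls_token's final embedding.
        return list(tokens[0])
    if mode == "mean":
        # "Mean token" style: average each embedding dimension over all tokens
        # (the cls_token included, since it is part of the sequence).
        n = len(tokens)
        return [sum(column) / n for column in zip(*tokens)]
    raise ValueError(f"unknown pooling mode: {mode!r}")


# Toy encoder output: a cls_token followed by two patch tokens (embedding dim 2).
encoded = [
    [1.0, 0.0],  # cls_token
    [2.0, 4.0],  # patch 1
    [3.0, 2.0],  # patch 2
]
print(pool_tokens(encoded, "cls"))   # [1.0, 0.0]
print(pool_tokens(encoded, "mean"))  # [2.0, 2.0]
```

Either vector would then be passed to the same classification head. Only the readout differs.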

Thank you!

Thank you for your prompt response and detailed explanation!
In my view, if the model uses a "mean token" for classification, including a "cls_token" in the network may be unnecessary, since the "mean token" seems to capture the essential information needed to represent the image features. If the model does include a "cls_token", however, I think using that token for classification is preferable: the "cls_token" attends to all the "image tokens", so it can build a comprehensive representation that is well suited to classification.
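To make the attention argument concrete, here is a toy, single-head attention step in plain Python in which the cls_token is the only query over the image tokens. This is a deliberate simplification and an assumption of the sketch: the image tokens serve directly as keys and values (no learned query/key/value projections), there is one head and one layer with no residual or MLP, and the cls_token does not attend to itself here. `cls_attention` and `softmax` are hypothetical helpers. The output is a softmax-weighted average of the image tokens, and when all scores are equal the weights become uniform and the result reduces to the plain mean token.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max score before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cls_attention(cls_query: Vector, image_tokens: List[Vector]) -> Vector:
    """Scaled dot-product attention with the cls_token as the only query.

    Simplified on purpose: the image tokens are used directly as keys and
    values (no learned projections), with a single head and no residual/MLP.
    """
    scale = 1.0 / math.sqrt(len(cls_query))
    scores = [scale * sum(q * k for q, k in zip(cls_query, tok)) for tok in image_tokens]
    weights = softmax(scores)
    dim = len(image_tokens[0])
    # The result is a weighted average of the image tokens.
    return [sum(w * tok[d] for w, tok in zip(weights, image_tokens)) for d in range(dim)]


image_tokens = [[2.0, 4.0], [3.0, 2.0], [1.0, 0.0]]

# Equal scores (a zero query) give uniform weights, i.e. the plain mean token.
print(cls_attention([0.0, 0.0], image_tokens))  # approximately [2.0, 2.0]

# A non-zero query weights the tokens unequally; here patch 2 gets the most weight.
print(cls_attention([1.0, 0.0], image_tokens))
```

In this toy, the two readouts differ only in where the weights come from: mean pooling fixes them at 1/N, while the cls_token's weights depend on its query and the input.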