Question about `cls_token` in the code
wscc123 opened this issue · 2 comments
wscc123 commented
Given the code in pit.py, how is `cls_token` calculated? Thanks a lot!
```python
def forward_features(self, x):
    x = self.patch_embed(x)
    pos_embed = self.pos_embed
    x = self.pos_drop(x + pos_embed)
    cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)

    for stage in range(len(self.pools)):
        x, cls_tokens = self.transformers[stage](x, cls_tokens)
        x, cls_tokens = self.pools[stage](x, cls_tokens)
    x, cls_tokens = self.transformers[-1](x, cls_tokens)
    cls_tokens = self.norm(cls_tokens)
    return cls_tokens
```
bhheo commented

Hi,

`cls_token` is the same as the class token in the original vision transformer. You can find it in the timm repository:
https://github.com/rwightman/pytorch-image-models/blob/01a0e2/timm/models/vision_transformer.py#L333

In the original ViT, `cls_token` is concatenated with the other spatial tokens `x`, so there is no separate variable for `cls_token`; it is included in `x`. But in PiT, we handle `cls_token` separately in the pooling layers, so we keep a dedicated variable for it.
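To make the difference concrete, here is a minimal sketch (the shapes and tensor names are illustrative, not taken from the PiT code): ViT expands the learned `cls_token` parameter to the batch size and concatenates it with the spatial tokens, while PiT keeps it as a separate tensor so the spatial tokens can be pooled without it.

```python
import torch
import torch.nn as nn

batch, num_tokens, dim = 2, 16, 64  # illustrative sizes

# A learned class token: a single (1, 1, dim) parameter shared across the batch,
# expanded (broadcast, no copy) to the batch size at forward time.
cls_token = nn.Parameter(torch.zeros(1, 1, dim))
x = torch.randn(batch, num_tokens, dim)            # spatial (patch) tokens
cls_tokens = cls_token.expand(batch, -1, -1)       # (batch, 1, dim)

# ViT style: concatenate, so one tensor of shape (batch, num_tokens + 1, dim)
# flows through all transformer blocks.
x_vit = torch.cat((cls_tokens, x), dim=1)
print(x_vit.shape)  # torch.Size([2, 17, 64])

# PiT style: keep cls_tokens as its own variable, so the spatial tokens can be
# reshaped to a 2D grid and pooled between stages without touching the class token.
x_pit, cls_pit = x, cls_tokens
print(x_pit.shape, cls_pit.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 1, 64])
```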
wscc123 commented
[Automated vacation reply from QQ Mail.]
Hello, I am currently on vacation and cannot reply to your email in person. I will get back to you as soon as possible after the vacation ends.