naver-ai/pit

Question about 'cls_token' in the code

wscc123 opened this issue · 2 comments

Given the code in pit.py, how is 'cls_token' calculated? Thanks a lot!

def forward_features(self, x):
    x = self.patch_embed(x)

    pos_embed = self.pos_embed
    x = self.pos_drop(x + pos_embed)
    cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)

    for stage in range(len(self.pools)):
        x, cls_tokens = self.transformers[stage](x, cls_tokens)
        x, cls_tokens = self.pools[stage](x, cls_tokens)
    x, cls_tokens = self.transformers[-1](x, cls_tokens)

    cls_tokens = self.norm(cls_tokens)

    return cls_tokens
bhheo commented

Hi

cls_token is the same as the class token in the original Vision Transformer (ViT).
You can find it in the timm repository:
https://github.com/rwightman/pytorch-image-models/blob/01a0e2/timm/models/vision_transformer.py#L333

In the original ViT, cls_token is concatenated with the spatial tokens x.
So there is no separate variable for cls_token; it is part of x.
In PiT, however, the pooling layers handle cls_token separately from the spatial tokens.
That is why we keep it in its own variable.
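
To illustrate the difference, here is a minimal sketch, not the actual pit.py or timm code. It assumes the class token is a learnable parameter of shape (1, 1, dim) and that PiT keeps the spatial tokens as a (batch, channels, height, width) feature map. The names `spatial_pool` and `cls_proj` are hypothetical stand-ins for a pooling layer, chosen only for this example.

```python
import torch
import torch.nn as nn

batch, dim, h, w = 2, 8, 4, 4

# Assumed: one learnable class token per model, shape (1, 1, dim).
cls_token = nn.Parameter(torch.zeros(1, 1, dim))
# Broadcast it to one token per sample: (batch, 1, dim).
cls_tokens = cls_token.expand(batch, -1, -1)

# ViT style: the class token is concatenated into the token sequence,
# so it travels inside x and needs no variable of its own.
vit_spatial = torch.randn(batch, h * w, dim)
vit_x = torch.cat((cls_tokens, vit_spatial), dim=1)  # (batch, 1 + h*w, dim)

# PiT style (simplified): spatial tokens are pooled as a 2D feature map,
# which a single class token cannot take part in, so it is carried separately.
pit_x = torch.randn(batch, dim, h, w)

# Hypothetical pooling layer: a strided convolution for the spatial tokens
# and a linear projection so the class token matches the new channel width.
spatial_pool = nn.Conv2d(dim, dim * 2, kernel_size=3, stride=2, padding=1)
cls_proj = nn.Linear(dim, dim * 2)

pooled_x = spatial_pool(pit_x)      # (batch, 2*dim, h/2, w/2)
pooled_cls = cls_proj(cls_tokens)   # (batch, 1, 2*dim)

print(vit_x.shape, pooled_x.shape, pooled_cls.shape)
```

In this sketch the spatial resolution halves and the channel width doubles at the pooling step, while the class token only changes its channel width. That is the reason forward_features passes x and cls_tokens as a pair through every stage.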