lucidrains/vit-pytorch

ViViT pos encoding

eyalmazuz opened this issue · 2 comments

In ViViT, positional information is added as follows:

```python
def forward(self, video):
    x = self.to_patch_embedding(video)   # (batch, frame patches, patches per frame, dim)
    b, f, n, _ = x.shape

    x = x + self.pos_embedding           # whole table added, no slicing
```

But in ViT it is defined as:

```python
def forward(self, img):
    x = self.to_patch_embedding(img)      # (batch, patches, dim)
    b, n, _ = x.shape

    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]  # only the first n + 1 positions are used
```
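To illustrate what the slice buys, here is a minimal standalone sketch of the same tensor operations, not the library's module. The sizes (64 patches, `dim` of 32) are hypothetical, and `expand` stands in for the einops `repeat` call so the example only needs PyTorch:

```python
import torch

# Hypothetical sizes: a table built for 64 patches plus one CLS slot.
num_patches, dim = 64, 32
pos_embedding = torch.randn(1, num_patches + 1, dim)
cls_token = torch.randn(1, 1, dim)

def add_pos(x):
    # x: (batch, n, dim) patch embeddings, with n <= num_patches
    b, n, _ = x.shape
    cls_tokens = cls_token.expand(b, -1, -1)   # (b, 1, dim), like repeat('1 1 d -> b 1 d')
    x = torch.cat((cls_tokens, x), dim=1)      # (b, n + 1, dim)
    return x + pos_embedding[:, :(n + 1)]      # keep only the first n + 1 positions

full = add_pos(torch.randn(2, 64, dim))
fewer = add_pos(torch.randn(2, 16, dim))       # fewer patches still broadcasts cleanly
print(full.shape, fewer.shape)
```

Because the table is cut down to `n + 1` rows, any input with up to 64 patches lines up with it.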

The ViT definition allows using fewer patches at inference time, but ViViT doesn't slice the positional embedding, which forces it to accept only inputs of exactly the configured shape (the same number of frames and the same image height and width).
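A minimal sketch of the failure, assuming hypothetical sizes (8 frame patches, 64 patches per frame, `dim` of 32) and a standalone `pos_embedding` tensor with the `(1, frames, patches, dim)` layout implied by the slicing proposed later in this thread. Adding the whole table to a shorter clip breaks PyTorch broadcasting:

```python
import torch

# Hypothetical sizes for a ViViT-style positional table.
num_frames, num_patches, dim = 8, 64, 32
pos_embedding = torch.randn(1, num_frames, num_patches, dim)

def add_pos_unsliced(x):
    # x: (batch, f, n, dim); the full table is added as-is
    return x + pos_embedding

ok = add_pos_unsliced(torch.randn(2, 8, 64, dim))   # matches the configured shape

try:
    add_pos_unsliced(torch.randn(2, 4, 64, dim))    # a clip with only 4 frame patches
    failed = False
except RuntimeError as err:
    failed = True
    print("shape mismatch:", err)
```

Dimension 1 is 4 on one side and 8 on the other, and neither is 1, so the addition raises a `RuntimeError` instead of broadcasting.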

Isn't this a problem?

Wouldn't it be better to change ViViT to the following?

```python
def forward(self, video):
    x = self.to_patch_embedding(video)
    b, f, n, _ = x.shape

    x = x + self.pos_embedding[:, :f, :n]  # slice to the input's frame and patch counts
```
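The proposed slicing can be checked the same way. This is again a sketch with hypothetical sizes and a standalone tensor rather than the model itself:

```python
import torch

# Hypothetical sizes for a ViViT-style positional table.
num_frames, num_patches, dim = 8, 64, 32
pos_embedding = torch.randn(1, num_frames, num_patches, dim)

def add_pos_sliced(x):
    # x: (batch, f, n, dim) with f <= num_frames and n <= num_patches
    b, f, n, _ = x.shape
    return x + pos_embedding[:, :f, :n]   # first f frame rows, first n patch entries

short_clip = add_pos_sliced(torch.randn(2, 4, 64, dim))    # fewer frames
small_frames = add_pos_sliced(torch.randn(2, 8, 16, dim))  # fewer patches per frame
print(short_clip.shape, small_frames.shape)
```

As in ViT, slicing just takes the leading entries of the flattened patch table, so a smaller image reuses the first `n` positions rather than positions matched to its 2D layout; interpolating the table is a common alternative when that matters.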

@eyalmazuz That makes sense! Added it in 1.2.2.

@lucidrains Thanks, closing