ViViT pos encoding
eyalmazuz opened this issue · 2 comments
eyalmazuz commented
In ViViT, adding positional information is defined as follows:
```python
def forward(self, video):
    x = self.to_patch_embedding(video)
    b, f, n, _ = x.shape

    x = x + self.pos_embedding
```
but in ViT it is defined as:
```python
def forward(self, img):
    x = self.to_patch_embedding(img)
    b, n, _ = x.shape

    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
    x = torch.cat((cls_tokens, x), dim=1)
    x += self.pos_embedding[:, :(n + 1)]
```
The ViT definition allows using fewer patches at inference time, but in ViViT we don't slice the positional embedding, forcing it to only accept inputs whose shape `(batch, num_frames, channels, width, height)` matches the configuration it was built with. Isn't that a problem?
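For example, here is a minimal sketch of the failure; the shapes (16 frames, 64 patches, dim 128) are assumed for illustration and are not the library's defaults:

```python
import torch

# Positional embedding created for a fixed grid of 16 frames x 64 patches.
pos_embedding = torch.randn(1, 16, 64, 128)  # (1, frames, patches, dim)

x_train = torch.randn(2, 16, 64, 128)        # matches the training shape
out = x_train + pos_embedding                # broadcasts fine over batch

x_short = torch.randn(2, 8, 64, 128)         # only 8 frames at inference
try:
    out = x_short + pos_embedding            # dim 1: 8 vs 16, cannot broadcast
except RuntimeError as e:
    print(e)
```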
Wouldn't it be better to change ViViT to the following?
```python
def forward(self, video):
    x = self.to_patch_embedding(video)
    b, f, n, _ = x.shape

    x = x + self.pos_embedding[:, :f, :n]
```
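With the slice in place, a shorter clip broadcasts cleanly (same assumed shapes as in the sketch above):

```python
import torch

pos_embedding = torch.randn(1, 16, 64, 128)  # (1, frames, patches, dim)

x_short = torch.randn(2, 8, 32, 128)         # fewer frames and patches
b, f, n, _ = x_short.shape
out = x_short + pos_embedding[:, :f, :n]     # sliced to (1, 8, 32, 128)
print(out.shape)                             # torch.Size([2, 8, 32, 128])
```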
lucidrains commented
@eyalmazuz that makes sense! added it in 1.2.2
eyalmazuz commented
@lucidrains Thanks, closing