# ViViT: A Video Vision Transformer
import timm
import torch

from vivit import (
    SpatTempoAttnViViT,  # 'Model 1: Spatio-temporal attention'
    FactorEncViViT,  # 'Model 2: Factorised encoder'
    FactorSelfAttnViViT,  # 'Model 3: Factorised self-attention'
)
# Example usage: build the factorised-encoder ViViT ('Model 2') on top of a
# timm "vit_base_patch16_224" backbone and run it on a random batch of clips.
num_frames = 16
img_size = 224
# Random input laid out as (batch, channels, frames, height, width): here a
# batch of 4 three-channel clips of `num_frames` frames, `img_size` square.
video = torch.randn((4, 3, num_frames, img_size, img_size))
# NOTE: `pretrained=True` presumably fetches weights on first use, so network
# access may be required — confirm against the timm docs for your version.
vit = timm.create_model("vit_base_patch16_224", pretrained=True)
tempo_patch_size = 4
spat_patch_size = 16
num_classes = 1000
pooling_mode = "cls"
model = FactorEncViViT(
    vit=vit,
    num_frames=num_frames,
    img_size=img_size,
    tempo_patch_size=tempo_patch_size,
    spat_patch_size=spat_patch_size,
    num_classes=num_classes,
    pooling_mode=pooling_mode,
)
# Fall back to the CPU when no CUDA device is present; an unconditional
# torch.device("cuda") makes the `.to(device)` calls below fail on such machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
video = video.to(device)
out = model(video)  # (B, `num_classes`)
@misc{arnab2021vivit,
title={ViViT: A Video Vision Transformer},
author={Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lučić and Cordelia Schmid},
year={2021},
eprint={2103.15691},
archivePrefix={arXiv},
primaryClass={cs.CV}
}