In this work, we focus on vision transformers (ViTs) for image classification tasks. Transformers have been widely used in NLP and time-series modeling; here we report results on the ability of the attention mechanism to capture and extract visual features. The model architecture is based on vit-pytorch. We also add the possibility to visualize the attention maps.
Following the original vit-pytorch implementation, we use the following parameters for each ViT model.
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
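A forward pass then maps a batch of images to class logits. The sketch below follows standard vit-pytorch usage, with a random tensor standing in for a real image:

img = torch.randn(1, 3, 256, 256)  # (batch, channels, height, width)
preds = v(img)                     # (1, 1000) -- one logit per class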
image_size: int
Image size. If you have rectangular images, make sure your image size is the maximum of the width and height.

patch_size: int
Size of each (square) patch. image_size must be divisible by patch_size. The number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16; with the values above, n = (256 // 32) ** 2 = 64.

num_classes: int
Number of classes to classify (output dimension).

dim: int
Last dimension of the output tensor after the linear transformation nn.Linear(..., dim).

depth: int
Number of Transformer blocks.

heads: int
Number of heads in the Multi-head Attention layer.

mlp_dim: int
Dimension of the MLP (FeedForward) layer.

channels: int, default 3
Number of image channels.

dropout: float between [0, 1], default 0
Dropout rate.

emb_dropout: float between [0, 1], default 0
Embedding dropout rate.

pool: string, either "cls" or "mean"
Pooling method: cls token pooling or mean pooling.
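The attention maps mentioned above can be inspected, for example, with vit-pytorch's Recorder wrapper; whether this repo exposes its own helper on top of it is an assumption here, so this sketch shows the plain vit-pytorch route:

from vit_pytorch.recorder import Recorder

# Wrap the ViT defined above; Recorder stores each layer's attention
# weights during the forward pass
v = Recorder(v)

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)
# attns: (batch, depth, heads, patches + 1, patches + 1) = (1, 6, 16, 65, 65);
# the extra token is the cls token

# Recover the underlying model once you are done recording
v = v.eject()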
Install dependencies using this command:
pip install -r requirements.txt
To train the ViT on image classification, you can use this command:
python main.py --name trial