Pytorch实现的简单的Vision Transformer(ViT)分类任务。
- ViT_base_patch16_384模型pth格式的预训练权重如下。
- 链接:https://pan.baidu.com/s/1y1kOvlR9-OUUrZpRSkww3w
- 提取码:ampe
--JpegImages
--category_1
--001.jpg
--002.jpg
...
--category_2
--001.jpg
--002.jpg
...
...
新建main_txt文件夹,用于存放划分后的训练集与测试集信息。修改generate_txt.py文件中的图片集指向及txt存放位置指向,运行:
python generate_txt.py
运行完毕后,生成train.txt与test.txt, 存放至main_txt文件夹下。
修改train.py中的batch_size参数及类别数量,调整学习率及相关参数,运行:
python train.py
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py