This package is a Keras-based implementation of the Vision Transformer (ViT). The ViT model was proposed in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale". The package uses pre-trained weights (in .npz format) trained on the imagenet21k dataset and on the imagenet21k+imagenet2012 datasets.
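As background, a ViT slices an image into fixed-size patches and prepends one class (CLS) token. A minimal sketch of the resulting token count, assuming a square image evenly divided by the patch size:

```python
def num_tokens(image_size: int, patch_size: int) -> int:
    """Patches per axis squared, plus one CLS token."""
    patches_per_axis = image_size // patch_size
    return patches_per_axis * patches_per_axis + 1

# A 224x224 image with 16x16 patches yields 14*14 = 196 patches + 1 CLS token.
print(num_tokens(224, 16))  # 197
```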
Requirements:

- Python >= 3.7
- Keras >= 2.9

Installation:

```shell
pip install keras-vit
```
Features:

- Build pre-trained Vision Transformer (ViT) models with the standard architectures
- Build ViT models with custom parameters to suit different tasks
Quickly build a pre-trained ViT_B16:

```python
from keras_vit.vit import ViT_B16

vit = ViT_B16()
```
Pre-trained ViT comes in four configurations: ViT_B16, ViT_B32, ViT_L16, and ViT_L32.

| Configuration | Patch size | Hidden dim | MLP dim | Attention heads | Encoder depth |
| --- | --- | --- | --- | --- | --- |
| ViT_B16 | 16×16 | 768 | 3072 | 12 | 12 |
| ViT_B32 | 32×32 | 768 | 3072 | 12 | 12 |
| ViT_L16 | 16×16 | 1024 | 4096 | 16 | 24 |
| ViT_L32 | 32×32 | 1024 | 4096 | 16 | 24 |

The pre-trained weights for the "imagenet21k" and "imagenet21k+imagenet2012" datasets correspond to slightly different model parameters, as shown in the table below:
| Weights | Image size | Classes | Pre logits | Known labels |
| --- | --- | --- | --- | --- |
| imagenet21k | 224 | 21843 | True | False |
| imagenet21k+imagenet2012 | 384 | 1000 | False | True |
Build pre-trained ViT_B16 models for the two datasets:

```python
from keras_vit.vit import ViT_B16

vit_1 = ViT_B16(weights="imagenet21k")
vit_2 = ViT_B16(weights="imagenet21k+imagenet2012")
```
The pre-trained weight (.npz) files are downloaded automatically to C:\Users\user_name\.Keras\weights. If a download is interrupted, delete the files in that directory before downloading again; otherwise an error will be raised. If the download is too slow, you can fetch the weights manually (from Baidu Netdisk) and place the files in the directory above.
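A minimal sketch of clearing the cached weight files after an interrupted download; `clear_partial_weights` is a hypothetical helper, not part of the package, and the cache path is the one mentioned above:

```python
from pathlib import Path

def clear_partial_weights(weights_dir: Path) -> list:
    """Delete cached .npz weight files so the package can re-download them."""
    removed = []
    for f in weights_dir.glob("*.npz"):
        f.unlink()
        removed.append(f.name)
    return removed

# Cache location noted above (adjust user_name / OS as needed):
# clear_partial_weights(Path.home() / ".Keras" / "weights")
```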
Build a ViT_B16 without pre-trained weights:

```python
from keras_vit.vit import ViT_B16

vit = ViT_B16(pre_trained=False)
```
Build a pre-trained ViT_B32 with custom parameters:

```python
from keras_vit.vit import ViT_B32

vit = ViT_B32(
    image_size=128,
    num_classes=12,
    pre_logits=False,
    weights="imagenet21k",
)
```
When you change the parameters of a pre-trained model, the shapes of some layers change; those layers no longer load the pre-trained weights and are randomly initialized instead. Layers that are unchanged still load the pre-trained weights as usual. You can inspect each layer's loading status with the `loading_summary()` method:

```python
vit.loading_summary()
```

```
Model: "ViT-B-32-128"
-----------------------------------------------------------------
layers                      load weights inf
=================================================================
patch_embedding             loaded
add_cls_token               loaded - imagenet
position_embedding          not loaded - mismatch
transformer_block_0         loaded - imagenet
transformer_block_1         loaded - imagenet
transformer_block_2         loaded - imagenet
transformer_block_3         loaded - imagenet
transformer_block_4         loaded - imagenet
transformer_block_5         loaded - imagenet
transformer_block_6         loaded - imagenet
transformer_block_7         loaded - imagenet
transformer_block_8         loaded - imagenet
transformer_block_9         loaded - imagenet
transformer_block_10        loaded - imagenet
transformer_block_11        loaded - imagenet
layer_norm                  loaded - imagenet
mlp_head                    not loaded - mismatch
=================================================================
```
Build a fully custom ViT model by instantiating the ViT class directly:

```python
from keras_vit.vit import ViT

vit = ViT(
    image_size=128,
    patch_size=36,
    num_classes=1,
    hidden_dim=128,
    mlp_dim=512,
    atten_heads=32,
    encoder_depth=4,
    dropout_rate=0.1,
    activation="sigmoid",
    pre_logits=True,
    include_mlp_head=True,
)
vit.summary()
```

```
Model: "ViT-CUSTOM_SIZE-36-128"
_________________________________________________________________
Layer (type)                               Output Shape    Param #
=================================================================
patch_embedding (PatchEmbedding)           (None, 9, 128)  497792
add_cls_token (AddCLSToken)                (None, 10, 128) 128
position_embedding (AddPositionEmbedding)  (None, 10, 128) 1280
transformer_block_0 (TransformerEncoder)   (None, 10, 128) 198272
transformer_block_1 (TransformerEncoder)   (None, 10, 128) 198272
transformer_block_2 (TransformerEncoder)   (None, 10, 128) 198272
transformer_block_3 (TransformerEncoder)   (None, 10, 128) 198272
layer_norm (LayerNormalization)            (None, 10, 128) 256
extract_token (Lambda)                     (None, 128)     0
pre_logits (Dense)                         (None, 128)     16512
mlp_head (Dense)                           (None, 1)       129
=================================================================
Total params: 1,309,185
Trainable params: 1,309,185
Non-trainable params: 0
_________________________________________________________________
```
Note that `hidden_dim` must be divisible by `atten_heads`, and `image_size` should ideally be divisible by `patch_size`.
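These constraints can be checked up front before building a custom model; `check_vit_config` below is a hypothetical helper sketch, not part of the package:

```python
def check_vit_config(image_size: int, patch_size: int,
                     hidden_dim: int, atten_heads: int) -> int:
    """Validate the constraints above and return the per-head dimension."""
    if hidden_dim % atten_heads != 0:
        raise ValueError("hidden_dim must be divisible by atten_heads")
    if image_size % patch_size != 0:
        # Not fatal, but pixels at the image edge would be dropped.
        print("warning: image_size is not divisible by patch_size")
    return hidden_dim // atten_heads

check_vit_config(224, 16, 768, 12)  # 64, no warning
```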
Load pre-trained weights into a custom ViT model:

```python
from keras_vit import utils, vit

vit_custom = vit.ViT(
    image_size=128,
    patch_size=8,
    encoder_depth=4,
)
utils.load_imgnet_weights(vit_custom, "ViT-B_16_imagenet21k.npz")
vit_custom.loading_summary()
```

```
Model: "ViT-CUSTOM_SIZE-8-128"
-----------------------------------------------------------------
layers                      load weights inf
=================================================================
patch_embedding             mismatch
add_cls_token               loaded - imagenet
position_embedding          not loaded - mismatch
transformer_block_0         loaded - imagenet
transformer_block_1         loaded - imagenet
transformer_block_2         loaded - imagenet
transformer_block_3         loaded - imagenet
layer_norm                  loaded - imagenet
pre_logits                  loaded - imagenet
mlp_head                    not loaded - mismatch
=================================================================
```
Fine-tuning:

```python
from keras_vit.vit import ViT_B32

# Set parameters
IMAGE_SIZE = ...
NUM_CLASSES = ...
ACTIVATION = ...

# Build ViT
vit = ViT_B32(
    image_size=IMAGE_SIZE,
    num_classes=NUM_CLASSES,
    activation=ACTIVATION,
)

# Compile ViT
vit.compile(
    optimizer=...,
    loss=...,
    metrics=...,
)

# Define train, valid and test data
train_generator = ...
valid_generator = ...
test_generator = ...

# Fine-tune ViT
vit.fit(
    x=train_generator,
    validation_data=valid_generator,
    steps_per_epoch=...,
    validation_steps=...,
)

# Testing
vit.evaluate(x=test_generator, steps=...)
```
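The `...` placeholders above need real data generators. As a stand-in for illustration, here is a sketch of a generator yielding batches in the (images, one-hot labels) shape that `fit` consumes; the name and shapes are assumptions for the sketch, not part of the package:

```python
import numpy as np

def random_batches(n_batches, batch_size, image_size, num_classes):
    """Yield random (images, one-hot labels) batches shaped like real data."""
    rng = np.random.default_rng(0)
    for _ in range(n_batches):
        x = rng.random((batch_size, image_size, image_size, 3), dtype=np.float32)
        y = np.eye(num_classes, dtype=np.float32)[
            rng.integers(0, num_classes, batch_size)]
        yield x, y

x, y = next(random_batches(1, 8, 224, 10))
print(x.shape, y.shape)  # (8, 224, 224, 3) (8, 10)
```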
Image classification:

```python
from keras_vit import vit
from keras_vit import utils

# Get a pre-trained ViT_B16
vit_model = vit.ViT_B16(weights="imagenet21k+imagenet2012")

# Load a picture
img = utils.read_img("test.jpg", resize=vit_model.image_size)
img = img.reshape((1, *vit_model.image_size, 3))

# Classify
y = vit_model.predict(img)
classes = utils.get_imagenet2012_classes()
print(classes[y[0].argmax()])
```
Note that the package currently has no label file for the imagenet21k dataset, so when applying a pre-trained ViT to image classification, set `weights="imagenet21k+imagenet2012"`. For fine-tuning, both `"imagenet21k"` and `"imagenet21k+imagenet2012"` can be used.
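To report more than the single best class, the prediction vector returned by `predict` can be decoded into a top-k list. The helper below is a sketch, assuming `classes` maps indices to label strings as in the example above:

```python
import numpy as np

def top_k_labels(probs, classes, k=3):
    """Return the k highest-scoring (label, score) pairs for one prediction."""
    idx = np.argsort(probs)[::-1][:k]
    return [(classes[int(i)], float(probs[i])) for i in idx]

# e.g. with y = vit_model.predict(img) from above:
# print(top_k_labels(y[0], classes))
```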
The script fine_tuning_on_CIFAR10_demo.py in the project fine-tunes on the CIFAR-10 dataset; before running it, unzip the dataset into the datasets folder.