- 下载好数据集,代码中默认使用的是花分类数据集,下载地址: https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz, 如果下载不了的话可以通过百度云链接下载: https://pan.baidu.com/s/1QLCTA4sXnQAw_yvxPj9szg 提取码:58p0
- 在
train.py
脚本中将--data-path
设置成解压后的flower_photos
文件夹绝对路径 - 下载预训练权重,在
vit_model.py
文件中每个模型都有提供预训练权重的下载地址,根据自己使用的模型下载对应预训练权重 - 在
train.py
脚本中将--weights
参数设成下载好的预训练权重路径 - 设置好数据集的路径
--data-path
以及预训练权重的路径--weights
就能使用train.py
脚本开始训练了(训练过程中会自动生成class_indices.json
文件) - 在
predict.py
脚本中导入和训练脚本中同样的模型,并将model_weight_path
设置成训练好的模型权重路径(默认保存在weights文件夹下) - 在
predict.py
脚本中将img_path
设置成你自己需要预测的图片绝对路径 - 设置好权重路径
model_weight_path
和预测的图片路径img_path
就能使用predict.py
脚本进行预测了 - 如果要使用自己的数据集,请按照花分类数据集的文件结构进行摆放(即一个类别对应一个文件夹),并且将训练以及预测脚本中的
num_classes
设置成你自己数据的类别数