- Swin Transformer预训练模型处理花分类任务
- 谢胜达
- 曾旭翔
- 沈翁宇
- 王建煜
- 下载好数据集,代码中默认使用的是花分类数据集,下载地址: https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
- 在
train.py
脚本中将--data-path
设置成解压后的flower_photos
文件夹绝对路径 - 下载预训练权重,在
model.py
文件中每个模型都有提供预训练权重的下载地址,根据自己使用的模型下载对应预训练权重 - 在
train.py
脚本中将--weights
参数设成下载好的预训练权重路径 - 设置好数据集的路径
--data-path
以及预训练权重的路径--weights
就能使用train.py
脚本开始训练了(训练过程中会自动生成class_indices.json
文件) - 在
predict.py
脚本中导入和训练脚本中同样的模型,并将model_weight_path
设置成训练好的模型权重路径(默认保存在weights文件夹下) - 在
predict.py
脚本中将img_path
设置成你自己需要预测的图片绝对路径 - 设置好权重路径
model_weight_path
和预测的图片路径img_path
就能使用predict.py
脚本进行预测了