神经网络与机器学习实验代码


任务

  • Swin Transformer预训练模型处理花分类任务

小组成员

  • 谢胜达
  • 曾旭翔
  • 沈翁宇
  • 王建煜

代码使用简介

  1. 下载好数据集,代码中默认使用的是花分类数据集,下载地址: https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
  2. train.py脚本中将--data-path设置成解压后的flower_photos文件夹绝对路径
  3. 下载预训练权重,在model.py文件中每个模型都有提供预训练权重的下载地址,根据自己使用的模型下载对应预训练权重
  4. train.py脚本中将--weights参数设成下载好的预训练权重路径
  5. 设置好数据集的路径--data-path以及预训练权重的路径--weights就能使用train.py脚本开始训练了(训练过程中会自动生成class_indices.json文件)
  6. predict.py脚本中导入和训练脚本中同样的模型,并将model_weight_path设置成训练好的模型权重路径(默认保存在weights文件夹下)
  7. predict.py脚本中将img_path设置成你自己需要预测的图片绝对路径
  8. 设置好权重路径model_weight_path和预测的图片路径img_path就能使用predict.py脚本进行预测了