/vgg16_classifiter_pytorch

手动一步步搭建vgg16,对图片进行分类,并使用tensorboard进行记录

Primary LanguagePython

Pytorch图片分类器

实验设置:

  • 训练数据为Cifar-10,训练数据50000帧2242243的图片,测试数据10000帧2242243图片
  • 损失函数为交叉熵
  • Optimmizer为SGD,其中lr=0.001,未使用学习率衰减
  • 搭建了VGG16模型
  • 数据增强:未使用数据增强

评价标准:

Precision:TP/(TP+TN)

评估结果:

Train_loss:

Train_precision:

Val_loss:

Val_precision:

实验结论:

  • 本次实验在训练集上precision约99%,在验证集上约69%,即模型的泛化性不好,可以通过训练时使用数据增强来增强泛化性。
  • Cifar-10数据集中图像的尺寸比较小,经过vgg16多次卷积后图像尺寸都变为1,会在全连接层之前就丢失了位置信息,所以可以将模型换为Resnet50等具有跳层链接的网络可以会好些。
  • 若还要对模型进行改进,应该增加评价指标,多分类的recall,f1,map等指标来进行评估,再对误检、漏检数据进行分析才好进一步改进。