最后在test.py文件中测试一下识别准确率怎么样?
正确率可达50% 及以上
CIFAR-10 数据集是一个广泛用于图像分类任务的数据集。它包含 10 个类别的 60,000 张 32x32 彩色图像,每个类别有 6,000 张图像。其中 50,000 张图像用于训练,10,000 张图像用于测试。
运行代码时,数据集将自动下载到 ../data 目录中。
模型架构在 TrainModel 类中定义,该类应在 model.py 文件中实现。该模型旨在将 CIFAR-10 数据集中的图像分类到 10 个类别之一。
训练过程由主脚本控制。模型使用随机梯度下降(SGD)优化器和交叉熵损失函数进行训练。
要开始训练,只需运行以下脚本:
python train.py
训练过程将记录损失和准确率等指标到 TensorBoard,可以通过以下命令可视化这些指标:
tensorboard --logdir=../log_train
每个 epoch 结束后,模型将在测试数据集上进行评估。评估指标(包括损失和准确率)也会记录到 TensorBoard 中。
最终训练好的模型将保存为根目录下的 train_model.pth 文件。
您可以加载并使用此模型进行推理或进一步训练:
import torch
model = torch.load("train_model.pth")
model.eval()