- 多目标多分类基础模型,模型采用ResNet_101,ImageNet预训练。
- 训练集与验证集采用MS COCO。
- python
- torch
- torchvision
- numpy
- PIL
- matplotlib(如果需要展示图像)
使用MS COCO数据集。
图像处理方式:
- 随机裁剪
- 随机翻转
- 归一化
训练集位置:
"ip:/home/datasets/qishuo/coco/train2017"
验证集位置:
"ip:/home/datasets/qishuo/coco/val2017"
修改路径参数、超参数等统一:vim xxx.python
- Clone:
git clone https://github.com/MagicianQi/multi_label_image_classification
cd ./multi_label_image_classification
- 进入虚拟环境(在172.30.1.118上):
source /home/qishuo/venv/bin/activate
- 生成MS COCO分类标签:
- 生成标签文件:
python ./scrpits/generate_coco_labels.py
- 生成标签文件:
- 训练:
- 训练 :
python resnet_multi_label.py
- 后台训练:
screen python resnet_multi_label.py
- 训练 :
- 验证:
- 计算准确率(acc):
python evaluate_multi_label_acc.py
- 计算召回率(recall):
python evaluate_multi_label_recall.py
- 计算准确率(acc):
- 测试:
- 测试图像结果:
python testing_multi_label.py 1.jpg 2.jpg 3.jpg
- 测试图像结果:
Image:
Result(Top5):
['person', 0.9997881], ['cell phone', 0.7692987], ['tv', 0.6630157], ['laptop', 0.5872826], ['tie', 0.24580538]