零样本学习(zero-shot learning)是在已知类别上训练物体识别模型,要求模型能够识别来自未知类别的样本。对图像理解、(从已知类别到未知类别的)知识迁移具有重要意义。
- Python 3.5
- PyTorch 0.4
使用 AI Challenger 2018 的图像属性数据集,本数据集共78,017张图片、230个类别、359种属性。
图片可划分为5个超类(super-class),分别是动物(Animals)、水果(Fruits)、交通工具(Vehicles)、电子产品(Electronics)、发型(Hairstyles)。其中,动物和水果属于自然产物,交通工具和电子产品属于人造物,发型属于抽象概念。每个超类分别包含A: 50, F: 50, V: 50, E: 50, H: 30 个类别,总计230个类别。对于每个超类(super-class),分别设计了A: 123, F: 58, V: 81, E: 75, H: 22 个属性,共359个属性。每张图片只包含一个前景物体,标注了标签和物体包围框。对于每个类别,随机挑选了20张图片进行属性标注。
- 训练集(seen classes):80%类别
- 测试集(unseen classes):20%类别
标注示例图:
你可以从这里下载该数据集。
提取78,017张图片及相应的标注文件:
$ python pre-process.py
$ python train.py
如果想可视化训练过程,请执行:
$ tensorboard --logdir path_to_current_dir/logs
各超类训练结束最佳的验证集准确率和损失为:
度量 | 动物 | 水果 | 交通工具 | 电子产品 | 发型 |
---|---|---|---|---|---|
ACCURACY | 97.222 | 82.755 | 92.500 | 90.790 | 53.854 |
LOSS | 0.025 | 0.044 | 0.023 | 0.051 | 0.077 |
下载下列预训练模型放在 models 目录然后执行:
$ python demo.py -s "Animals"
此处超类可以是5个超类中任意一个。