CNN预测图片旋转角度
在下文提到的数据集上训练35个epoch(耗时30min)得到的平均预测误差为19.98°
,模型文件大小3.8MB
,可以轻松破解百度旋图验证码
测试效果如下
其中百度验证码图片来自RotateCaptchaBreak
-
一张显存4G以上的GPU
-
确保你的
Python
版本>=3.7
-
确保你的
PyTorch
版本>=1.11
-
拉取代码并安装依赖库
git clone https://github.com/Starry-OvO/Rotate-Captcha-Crack.git
cd ./Rotate-Captcha-Crack
pip install -r requirements.txt
-
我这里直接扒的
Landscape-Dataset
,你也可以自己收集一些风景照放到任意一个文件夹里,因为是自监督学习,所以不限制图像尺寸也不需要标注 -
在
config.yaml
里配置dataset.root
字段指向装有图片的文件夹 -
运行
prepare.py
准备数据集
python prepare.py
python train.py
python evaluate.py
Linux环境需要配置GUI或者自己把debug方法从显示图像改成保存图像
python test.py
-
搞这个项目的主要目的是练手
torchdata
,数据集的预处理与构建流程是本项目的精髓所在 -
现有的旋图验证码破解方法大多基于
RotNet (ICLR2018)
,其backbone为ResNet50
,将角度预测视作360分类问题,并计算交叉熵损失,本项目的RotationNet
是对RotNet
的简单改进 -
backbone为
regnet (CVPR2020)
的RegNetX 1.6GFLOPs
-
RotNet
中使用的交叉熵损失会令1°
和359°
之间的度量距离接近一个类似358°
的较大值,这显然是一个违背常识的结果,它们之间的度量距离应当是一个类似2°
的极小值。而RotNet
仓库(并未写入论文)给出的angle_error_regression
损失函数效果较差,因为该损失函数在应对离群值时梯度方向存在明显问题,你可以在后续的损失函数图像比对中轻松看出这一点 -
本人设计的损失函数
RotationLoss
和angle_error_regression
的思路相近,我使用最后的全连接层预测出一个角度值并与ground-truth
作差,然后在MSELoss
的基础上加了个余弦约束项来缩小真实值的±k*360°
与真实值之间的度量距离 -
为什么这里使用
MSELoss
,因为自监督学习生成的label
可以保证不含有任何离群值,因此损失函数设计不需要考虑离群值的问题,同时MSELoss
不破坏损失函数的可导性 -
该损失函数在整个实数域可导且几乎为凸,为什么说是几乎,因为当
lambda_cos>0.25
时在predict=±1
的地方会出现局部极小 -
最后直观比较一下
RotationLoss
和angle_error_regression
的函数图像
- angle_error_regression
- RotationLoss