/mnist

Primary LanguageC++GNU General Public License v3.0GPL-3.0

/**
*@Author: fengqi
*@Email: 2607546441@qq.com
*/
本程序使用全连接神经网络进行手写数字识别的训练和预测。当然修改一下输入和输出节点数,调整网络层数,也可用于其他多分类或回归问题。
代码结构参考了 yolo(You Only Look Once) 项目源码框架 darknet.

目录文件介绍:
    --mnist/ 存放的是 mnist 数据集原始二进制文件;
    --obj/ 存放的是编译生成的 .obj 文件;
    --backup/ 存放的是每轮训练过程中生成的权重文件
    --testData/ 是运行 ./data 程序读取 mnist 测试数据集生成的测试图片(如 0.jpg, 1.jpg...)和对应的标签文件 testLabel.txt;
    --trainData/ 是运行 ./data 程序读取 mnist 训练数据集生成的训练图片(如 0.jpg, 1.jpg...)和对应的标签文件 trainLabel.txt;
    --layer.cpp/.h 是全连接层的类定义及实现,主要是分配层的计算数据存储空间和前向计算反向传播以及参数更新函数定义;
    --network.cpp/.h 是网络的类定义及实现,主要是定义了网络中,全连接层的添加,网络的前向传播,反向传播等函数;
    --mnist.cpp 是读取 mnist/ 下的二进制文件,生成相应的图片,便于可视化和图片读写;
    --main.cpp 是主函数入口文件,里面实现了网络训练及验证,以及预测功能,并且实现了在训练网络完成后保存网络各层的权重到文件里,
        方便下次训练或预测时随时载入权重,不用重新训练网络;

使用方法介绍:
    --本程序使用 Makefile 进行项目管理构建,只需在终端输入 make 命令,即可生成 data 和 run 两个可执行文件;
    --注意,本程序生成和读写图片使用了 opencv,所以请确保电脑上安装并配置好了 opencv 开发环境。
        1、运行 ./data 可读取 mnist/ 数据文件,生成训练和测试用的图片;
        2、运行 ./run train 可训练网络,网络训练完成后,会生成 mnist.weight 网络权重文件;
        3、运行 ./run test mnist.weight 可进行图片识别预测,只需输入图片文件名,按 ctrl-c 停止即可;

主要参数介绍:
    --程序里的主要需要修改的参数有,训练迭代次数 epoches 和学习率 learning rate, 在构建 network 对象时传入;
    --全连接层的个数及各层神经元数量和激活函数类型,可通过 network->addLayer(int node,ACTIVATION activate) 调整;
    --学习率变化调整,默认调整方式是每次迭代减小 0.01,当小于 0.01 时,固定使用 0.01 作为学习率;