基于 OCaml 语言实现 BP 神经网络。
最终,使用 y=x^2 函数模型 和 mnist 手写数字数据集进行训练和预测。
bin/
: 目标文件目录data/
: mnist手写数字数据集目录src/
: 源代码目录BP.ml
: BP神经网络模块BP.mli
: BP模块的接口文件main.ml
: 运行入口文件,主要用于测试BP和mnist模块mnist.ml
: 解析mnist数据集模块mnist.mli
: mnist模块的接口文件util.ml
: 工具模块,提供打印 Array 数组等接口util.mli
: 工具模块的接口文件
README.md
: 说明文档Makefile
: 工程脚本文件
- 实现BP神经网络算法。完成者:dounine
- 初始化
- 训练
- 预测
- 使用 y=x^2 模型进行测试
- 解析 mnist 手写数字数据集,并进行测试。完成者:TayMarine
- 解析 mnist 数据集
- 显示手写数字图片
- 使用 BP 神经网络进行预测
BP神经网络模块中的关键数据结构,如下:
x_count : int
: 输入层神经元个数mid_count : int
: 中间层神经元个数y_count : int
: 输出层神经元个数x2mid_weight : float[][]
: 输入层到中间层的权值mid2y_weight : float[][]
: 中间层到输出层的权值mid_theta : float[]
: 中间层阈值y_theta : float[]
: 输出层阈值eta : float
: 学习率precision : float
: 误差精度max_train_count : int
: 最大训练次数
image_info: type
: 用于存储图片数据的记录类型,有四个分量label_info : type
: 用于存储标签数据的记录类型,有两个分量pixel: int[][]
: 二维数组存储所有图片的28*28的像素值,每张图片的像素描述归为一个一维数组single : int[][]
: 存储一张图片的28*28的像素值label_num : int[][]
: 存储所有标签数据
-
BP
模块(详情见BP.mli):init(x_count, mid_count, y_count, eta=03, max_train_count=100, precision=0.0001)
: 初始化神经网络train(X_list, Y_list)
: 训练神经网络predict(X_list)
: 对输入数据进行预测
-
mnist
模块(详情见mnist.mli):get_data(num)
: 获取mnist数据集中的数据draw (pixel_ary, color)
: 画图
-
util
模块(详情见util.mli):print_row(row)
: 打印一维 Array 数组print_matrix(matr)
: 打印二维 Array 数组get_subMatrix(matr, start_pos, end_pos)
: 获取二维数组的子数组scatter(x_arr, y_arr, color)
: 画散点图
主要对train()
函数进行算法设计:
train(X_list, Y_list):
for 1 -> max_train_count: # 重复训练
float E[] # 每一个样例的均方误差
for X,Y in X_list,Y_list: # 遍历所有样例
Y_out, mid_out = compute_Y(X) # 根据 X 计算 Y
E.append(compute_e(Y, Y_out)) # 根据预测值和实际值,计算均方误差
mid2y_grad = compute_mid2y_grad(Y, Y_out) # 计算中间层到输出层的梯度项
x2mid_grad = compute_x2mid_grad(mid2y_grad) # 计算输入层到中间层的梯度项
update_mid2y(mid2y_grad) # 更新中间层到输出层的权值和输出层的阈值
update_x2mid(x2mid_grad) # 更新输入层到中间层的权值和中间层的阈值
E_mean = sum(E)/len(E) # 训练完所有样例后,计算累计误差
if (E_mean < precision) # 如果误差小于精度,则不进行重复训练
break
图片文件中前十六个字节为标记字节,前四个字节为魔数(magic number),接下来的四个字节为图片数量(值为60,000),在之后的八个字节分别是每张图片像素点的行数和列数,都是四字节数据。剩下的都是图片像素数据,每个像素点为1个字节,属于灰度图,784个字节为1张图片(28x28=784)。
标签文件中只有前八个字节是标记,与图片文件的前八个字节意义相同。剩下的为标签数据,每个标签1字节,其有效数据的范围在0~9之间。
using in linux recommend.
compile:
make
run:
make run
clean:
make clean
git commit -m "type: description"
- type:
- feat:新功能(feature)
- fix:修补bug
- docs:文档(documentation)
- style:格式(不影响代码运行的变动)
- refactor:重构(即不是新增功能,也不是修改bug的代码变动)
- test:增加测试
- chore:构建过程或辅助工具的变动
- description: 详细描述