/GeoMAN-grain-temperature

The project about the GeoMAN network of grain temperature prediction

Primary LanguagePython

代码文件及模型输入输出的介绍和使用

1. 代码文件夹及其代码文件的介绍

  • 代码文件夹内容介绍:

    代码文件夹**有3个文件夹,分别是data,logs和results。

    第一个data文件夹下存放气象数据mete.xls、粮食温度数据grainterm_1.xlsx、和由这两个数据文件处理过后的.npy文件,即为神经网络的输入数据,数据的.npy文件由create_handle_data.py脚本运行而产生的,其中.npy文件中的数据格式在 Section 5 中会详细介绍。data中的.npy文件总共有17个,有5个训练集数据、5个交叉验证集数据、5个测试集数据、1个全局输入数据和1个全局注意力机制输入数据。17个数据文件中训练集、交叉验证集和测试集数据格式是一致的,都分别有1个局部输入数据、1个外部输入数据、1个全局输入索引、1个全局注意力机制输入索引和1个验证用的标签。data中还有一个scalers文件夹,其中存放的是标准归一化函数器。为了保证训练效果,需要先对数据进行归一化处理,映射到归一化标签的,但最后预测的标签是需要逆变换到原始标签的。因此,需要保存一个标准归一化函数器使得最后预测的标签和原始标签数量级一致,而scalers中的函数器由运行脚本create_pkl.py得以保存。

    第二个logs文件夹下存放的第一个文件GeoMAN-12-6-2-256-0.30-0.001,表示隐藏层神经元数为256的网络训练出来的模型参数结果;第二个文件是GeoMAN-12-6-2-128-0.30-0.001,表示隐藏层有256个神经元的网络训练出来的模型参数。两个文件夹结构一致,文件夹中有一个events文件,其保存的是神经网络的计算图结构,即神经网络可视化的结构图。还有一个文件夹saved_models,其中保存的是训练好以后的模型的参数结果。

    第三个results文件夹中保存的是训练好的隐藏层有128神经元的模型最终的测试结果,其中一个是由Predict_model.py产生的预测结果(PREDICT);另一个是Test_model_training.py脚本产生的数据集误差(MATRIC)。

  • 代码介绍:

    目前还没有介绍的脚本文件有base_model.py,N_network.py,Training_model.py,utils.py。

    其中base_model创建了一个基础神经网络类,是为了构建GeoMAN完整网络准备的。因此,N_network创建的就是GeoMAN的完整网络,包括结构及其需要操作的功能函数。Training_model.py就是用来训练模型的脚本,而utils是以上代码所需要使用的一些零散函数的集合文件。

2. 环境及其Python模块的要求

  • 代码运行环境介绍:

    • OS:Windows 10
    • IDE:PyCharm
    • Tensorflow 1.x
    • CPU or GPU
  • Python模块要求:

    • tensorflow
    • numpy
    • pandas
    • pickle
    • json
    • xlrd
    • sklearn

3. 代码逻辑结构

  • 先运行create_handle_data.py脚本生成数据处理后的.npy数据文件,包括将数据划分为全局数据、局部数据、外部输入数据和标签,并使用 Section 5 介绍的数据格式结构组织数据,最后生成.npy文件,其中数据都经过归一化处理,量纲统一。

    其中脚本会调用utils.py中handle_data函数将处理好的数据保存为.npy数据文件。
  • 再运行create_pkl.py脚本生成标准归一化函数器以恢复标签的归一化映射。

    其中脚本会调用utils.py中split_labels函数将处理好的label数据归一化处理并保存其标准归一化函数器为pkl文件。
  • Training_model.py:脚本中只有main函数,流程如下,先设置种子数并设置运行的GPU标号,再初始化学习率,编码器encoder,解码器decoder,特征数。详细数据见 Section 9。导入.npy已处理过的数据文件。之后,初始化tensorflow计算图,构建GeoMAN网络,计算网络参数个数和设置网络结果保存路径。最后加载训练集数据,填充网络输入接口placeholder占位符。运行计算图,训练网络参数,一旦loss是目前最小的,则保存当前模型,直到训练结束。

  • utils.py:basic_hyperparams初始化网络结构和训练参数。count_total_params计算初始化后的网络参数个数;load_data加载局部输入、全局输入索引、外部输入和标签。shuffle_data打散构建好的数据,减少数据相关性;get_batch_feed_dict构建训练的输入数据以填充placeholder占位符;load_global_inputs加载全局输入和全局注意力机制输入的数据;get_valid_batch_feed_dict构建交叉验证集的输入数据以填充placeholder占位符。

  • Test_model_training.py:创建两个函数,分别计算RMSE和MSE,其中两个性能测试因子详细解释见 Section 8。main函数中先设置运算GPU的标号,同时初始化模型的训练和结构参数。加载数据、标准化归一化函数器和保存的模型参数。运行已训练好的模型,对数据进行标签估计,同时和标签现实进行比较,计算RMSE和MSE,并打印测试结果。

  • Predict_model.py:main函数中先设置运算GPU的标号,同时初始化模型的训练和结构参数。加载标准化归一化函数器和保存的模型参数。构建最后12个布点号109的粮食温度时间序列及其相关数据,运行已训练好的模型,对数据进行标签估计,同时打印出结果。

4. 运行代码及其参数结果的保存方法

  • 训练模型

    • 在Pycharm中运行Training_model.py脚本进行模型训练
    • 在代码目录下使用命令:python Training_model.py运行脚本进行模型训练
  • 测试模型性能

    测试结果说明在 Section 8 中会详细介绍。

    • 在Pycharm中运行Test_model_training.py脚本对训练好的模型进行性能测试
    • 在代码目录下使用命令:python Test_model_training.py运行脚本进行模型性能测试
  • 预测测试

    预测结果为布点号109最后6天的预测温度序列。

    • 在Pycharm中运行Predict_model.py脚本对训练好的模型进行预测测试
    • 在代码目录下使用命令:python Predict_model.py运行脚本进行预测测试
  • 模型参数结果保存

    模型参数结果会在运行Training_model.py脚本时,实时更新loss最小的模型的参数,并将其保存在logs文件夹下的对应的saved_model文件夹中。

    同时,若要进行性能测试或者预测测试的话可以自动加载模型参数。

5. 模型输入格式介绍

ps:训练集是用于训练的数据;交叉验证集是用于训练中进行性能验证的数据;测试集是用于训练好的模型进行性能测试的数据。而由于训练集、交叉验证集和测试集数据结构一致,因此以下只分为7种数据结构进行介绍。

  • 局部输入数据格式 (Local inputs)

    局部输入因素有26个类型,即按照要求的109号布号点附近的粮食温度。

    由于RNN编码器encoder要求连续12个时间序列输入,因此可以将现有数据按照时间选择连续12天的时间序列,进行数据扩充和数据增强,帮助监督学习训练模型,这种数据增强方法同时使用在外部输入数据、测试预测标签和全局注意力机制输入数据中。

    因此,最后局部输入格式为(?, 12, 26),其中?表示数据增强以后能使用的数据组合对的数量。

  • 全局输入数据索引 (Global inputs index)

    全局输入因素为200个布号点种剩下的174个布号点。

    因此,索引所保存的就是剩下174个布号点在表格数据中的列索引。

  • 全局注意力机制输入数据索引 (Global attention states index)

    同全局输入数据索引,索引内容与其一致。

  • 外部输入数据格式 (Externel inputs)

    外部输入因素有8个类型,即按照要求的不同气象因素数据。

    由于RNN编码器decoder要求连续6个时间序列输入,因此可以将现有数据按照时间选择连续6天的时间序列,进行数据扩充和数据增强。

    因此,最后外部输入格式为(?, 6, 8),其中?表示数据增强以后能使用的数据组合对的数量。

  • 测试预测标签输入数据格式 (Labels)

    标签输入只有1种类型,即预测标签。

    因此,最后外部输入格式为(?, 6, 1),其中?表示数据增强以后能使用的数据组合对的数量。

  • 全局输入数据格式 (Global inputs)

    全局输入数据根据全局输入索引取数据,共有174个全局输入数据因素。

    因此,最后全局输入格式为(?, 174, 12),其中?表示数据增强以后能使用的数据组合对的数量。

  • 全局注意力机制输入数据格式 (Global attention states)

    全局注意力机制输入数据根据全局注意力机制输入索引取数据,因为全局注意力机制希望获得不同传感器之间的关系,而对于每个传感器,其局部输入影响其数值。因此对于每个传感器,需要将局部输入包含进来,这里,将每个布号点传感器附近的26个布号点作为局部输入。

    因此,最后全局注意力机制输入格式为(?, 174, 26, 12),其中?表示数据增强以后能使用的数据组合对的数量。

6. 模型输出格式介绍

  • 训练模型脚本的输出为训练过程中loss值和模型保存进程。

  • 测试性能模型的输出为训练好的模型在测试数据下的RMSE和MSE,具体的误差值介绍在 Section 8

  • 预测脚本的输出为布点号109最后6个时间序列的粮食温度值的估计。

7. 模型输入的使用介绍

若需要改变训练数据,自定义训练和测试数据,可以依照 Section 5 介绍的格式进行数据组织和修改create_handle_data.py脚本中的数据处理代码,最后重新运行以生成自定义的.npy数据文件。

8. 模型结果介绍及其分析

  • 模型结果介绍

    • RMSE 均方根误差:又称标准误差,均方根误差亦称标准误差,其定义为 ,i=1,2,3,…n。在有限测量次数中,均方根误差常用下式表示:√[∑di^2/n]=Re,式中:n为测量次数;di为一组测量值与真值的偏差。如果误差统计分布是正态分布,那么随机误差落在±σ以内的概率为68%。

      均方根误差是预测值与真实值偏差的平方与观测次数n比值的平方根,在实际测量中,观测次数n总是有限的,真值只能用最可信赖(最佳)值来代替。标准误差 对一组测量中的特大或特小误差反映非常敏感,所以,标准误差能够很好地反映出测量的精密度。这正是标准误差在工程测量中广泛被采用的原因。因此,标准差是用来衡量一组数自身的离散程度,而均方根误差是用来衡量观测值同真值之间的偏差,它们的研究对象和研究目的不同,但是计算过程类似。
    • MSE 均方误差:又称方差,即RMSE的平方。

  • 结果分析

    • 经过多次训练得出,隐藏层最优结果的数量为128,平均loss在1000个episodes训练t个step中,最低可以达到0.005,通过归一化逆变换后,标签误差大概在1 ~ 2单位量级。

    • 由results文件夹中两个图片可以得出,最后6个布号点109的温度预测误差在2以内;RMSE和MSE结果为0.2以下,属于误差较小,即表示最终的模型预测准确率较高。其中RMSE和MSE具体解释如上。

9. 模型结构参数和训练参数

  • 模型结构参数

    • 全局特征数: 174
    • 局部特征数:26
    • 被使用的时间序列长度:12
    • 编码器的隐藏层数:128
    • 解码器个数:1
    • 外部输入特征数:8
    • 被预测的时间序列长度:6
    • 解码器的隐藏层数:128
    • 解码器输出的个数:1
  • 训练参数

    • 学习率:0.001
    • lammda:0.001
    • 避免梯度爆炸的系数:2.5
    • dropout比率:0.3
    • stacked_layer数:2
    • 注意力机制flag:2
    • 是否需要外部输入网络的参数:True