/Decision-Tree

An implement of Decision Tree

Primary LanguageC++

决策树实现:基于信息增益的ID3算法

实现了决策树的学习以及分类,并采用后向剪枝进行优化,采用信息增益作为学习过程的属性选择依据。

文件目录:

  • encode.txt:对各个属性的编码说明
  • ID3.cpp:源代码
  • TrainData.txt:训练数据
  • TestData.txt:测试数据

编码

把每一种属性都用数字来表示。

青绿:0 乌黑:1 浅白:2
蜷缩:0 稍蜷:1 硬挺:2
浊响:0 沉闷:1 清脆:2
清晰:0 稍糊:1 模糊:2
凹陷:0 稍凹:1 平坦:2
硬滑:0 软粘:1 

训练

第1行:训练样本个数及其属性个数
第2行:各个属性的属性值的个数
第3-12行:训练样本

10 6
3 3 3 3 3 2
0 0 0 0 0 0 1  
1 0 1 0 0 0 1
1 0 0 0 0 0 1
0 1 0 0 1 1 1
1 1 0 1 1 1 1
0 2 2 0 2 1 0
2 1 1 1 0 0 0
1 1 0 0 1 1 0
2 0 0 2 2 0 0
0 0 1 1 1 0 0

测试

第1行:测试样本个数及其属性个数
第2-8行:测试样本

7 6
0 0 1 0 0 0 1
2 0 0 0 0 0 1
1 1 0 0 1 0 1
1 1 1 1 1 0 0
2 2 2 2 2 0 0
2 0 0 2 2 1 0
0 1 0 1 0 0 0

函数功能简介

void InputData(TrainData& data);    //输入训练数据
Node* BuildTree(TrainData data, vector<int> vis);   //建树(核心),即学习过程
int MostValue(TrainData data);      //统计数据中种类最多的那一类
int BestAttribute(TrainData data, vector<int>vis);  //找出信息增益最大的那一个属性
double CalGain(TrainData data, int Attr);   //计算各个属性的信息增益
double CalEntropy(TrainData data);      //计算信息熵
void GetTestData(TestData& Tdata);      //输入测试数据
int Judge(Node* root, TestData Tdata, int state);   //对测试数据分类
void Prune(Node* root, Node* u, TestData Tdata);    //后向剪枝
void PrintTreeStruct(Node* root);       //打印树的结构(测试代码用,可以删去这部分)