决策树思想的来源非常朴素,程序设计中的条件分支结构就是if-else结构,最早的决策树就是利用这类结构分割数据的一种分类学习方法。
决策树算法概述
认识决策树
例如以前的一个游戏,二十个问题的游戏,游戏的规则很简单:参与游戏的一方在脑海里想某个事物,其他参与者向他提问题,只允许提20个问题,问题的答案也只能用对或错回答。问问题的人通过推断分解,逐步缩小待猜测事物的范围。决策树的工作原理与20个问题类似,用户输人一系列数据,然后给出游戏的答案。
我们经常使用决策树处理分类问题,近来的调查表明决策树也是最经常使用的数据挖掘算法。它之所以如此流行,一个很重要的原因就是不需要了解机器学习的知识,就能搞明白决策树是如何工作的。
我们来看一个例子:你母亲要给你介绍男朋友,是这么来对话的:
女儿:多大年纪了?
母亲:26。
女儿:长的帅不帅?
母亲:挺帅的。
女儿:收入高不?
母亲:不算很高,中等情况。
女儿:是公务员不?
母亲:是,在税务局上班呢。
女儿:那好,我去见见。
由此可以看出,如何高效的进行决策,取决于特征选取的先后顺序,那么如何决定特征选取的先后顺序?这就是决策树算法主要解决的问题。
基于信息论确定特征选取顺序
信息熵的定义
在信息论里,用H表示信息熵,表示随机变量不确定性的度量,单位是bit。如果待分类的事务可能划分在多个分类之中,则每个分类xi的信息定义为:

计算信息熵H,需要计算所有类别所有可能值包含的信息期望值(数学期望):

其中n是分类的数目,p(xi)是选择该分类的概率。
我们认为熵越大,随机变量的不确定性就越大。
信息增益
划分数据集的大原则是:将无序的数据变得更加有序。我们可以使用多种方法划分数据集,但是每种方法都有各自的优缺点。组织杂乱无章数据的一种方法就是使用信息论度量信息,信息论是量化处理信息的分支科学。我们可以在划分数据之前或之后使用信息论量化度量信息的内容。在划分数据集之前之后信息发生的变化称为信息增益,知道如何计算信息增益,我们就可以计算每个特征值划分数据集获得的信息增益,获得信息增益最高的特征就是最好的选择。
在可以评测哪种数据划分方式是最好的数据划分之前,我们必须学习如何计算信息增益。集合信息的度量方式称为香农熵或者简称为熵,这个名字来源于信息论之父克劳德·香农。
如果看不明白什么是信息增益( information gain )和熵( entropy )请不要着急。它们自诞生的那一天起,就注定会令世人十分费解。克劳德·香农写完信息论之后,约翰·冯·诺依曼建议使用“熵”这个术语,因为大家都不知道它是什么意思。
决策树的划分依据
当然决策树的原理不止信息增益这一种, 还有其他方法。但是原理都类似,我们就不去举例计算。类似的如下:
- ID3:信息增益最大的准则
- C4.5:信息增益比最大的准则
- CART:分类树——基尼系数最小的准则(在sklearn中可以选择划分的默认原则)。优势:划分更加细致
决策树的一般流程
- 收集数据:可以使用任何方法。
- 准备数据:树构造算法只适用于标称型数据,因此数值型数据必须离散化。
- 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期。
- 训练算法:构造树的数据结构。
- 测试算法:使用经验树计算错误率。
- 使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据的内在含义。
决策树算法特性
优点
- 计算复杂度不高
- 能够实现可视化
- 对数据规格无要求(离散、连续)
- 对中间值的缺失不敏感,可以处理不相关特征数据
- 模型相对固定
- 输出结果易于理解
缺点(不足)
- 数据不能缺失
- 可能会产生过度匹配的问题
- 要求对树的生长有控制
构建决策树
ID3算法的核心是在决策树各个结点上对应信息增益准则选择特征,递归地构建决策树。
- Step1:计算数据集所有的特征的信息增益。
- Step2:选择信息增益最大的特征作为结点的特征。
- Step3:由该特征的不同取值划分数据集;
- Step4:计算对step3中划分的所有子集进行以下操作。
- Step4.1:子集中的样本的目标值属于同一分类或者子集中所有特征的信息增益均很小或没有特征可以选择,算法end并得到一个决策树;
- Step4.2:否则,再对子集递归地调用Step1到Step4。
实施决策树算法
Python核心代码如下:
def majorityCut(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),
key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def createTree(dataSet, labels, featLables):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(labels) == 0:
return majorityCut(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLable = labels[bestFeat]
featLabels.append(bestFeatLable)
myTree = {bestFeatLable: {}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels.copy()
myTree[bestFeatLable][value] = createTree
(splitDataSet(dataSet, bestFeat, value),subLabels, featLables)
return myTree
函数说明:统计classList中出现此处最多的元素(类标签)
Parameters:
classList – 类标签列表
Returns:
sortedClassCount[0][0] – 出现此处最多的元素(类标签)
函数说明:创建决策树
Parameters:
dataSet – 训练数据集
labels – 分类属性标签
featLabels – 存储选择的最优特征标签
Returns:
myTree – 决策树