目录
引言
一、复杂数据的局部性建模
二、连续和离散型特征的树的构建
三、将cart算法用于回归
3.1 构建树
编辑四、树剪枝
4.1 预剪枝
4.2 后剪枝
五、树模型
六、使用python的Tkinter库创建GUI
6.1 用Tkinter创建GUI
6.2 集成Matplotlib和Tkinter
本章将会学习CART(分类回归树)的树构建算法,算法可以用于分类也可以用于回归。与回归树的做法不同,该算法需要在每个叶节点上构建出一个线性模型。树构建的算法还要调整一些参数,故会介绍使用python中tkinter模块建立图形交互界面。并在该界面的辅助下分析参数对回归效果的影响。
决策树通过不断将数据切分成小数据,直到所有目标变量完全相同,或者数据不能再切分为止。决策树是一种贪心算法,要在给定的时间内做出最佳的选择,并不关心能否达到全局最优。
树回归的优缺点:
优点:可以对复杂和非线性的数据建模
缺点:结果不易理解
适用数据类型:数值型和标称型数据
ID3的做法是每次选取当前最佳的特征来分隔数据,并按照该特征的所有可能取值来切分。当特征有4种取值时,那么数据将会被分为4份,当按某种特征切分后,该特征在之后的算法执行过程中将不会再起作用。另一种二元切分法,每次将数据集切成两份。当数据的某特征值等于切分所要求的值时,这些数据就会进入树的左子树,反之进入树的右子树。
ID3算法还存在无法直接处理连续型特征的问题,需要提前将数据替换成离散型。二元切分法则易于对树构建过程进行调整以处理连续型特征。当特征值大于给定值走左子树,否则走右子树。二元切法也节省了树构建的时间。
树回归的一般方法:
1、收集数据:采用任意方法收集数据。
2、准备数据:需要数值型的数据,标称型数据应该映射成二值型数据。
3、分析数据:绘出数据的二维可视化显示结果,以字典方式生成树。
4、训练算法:大部分时间都花费在叶节点树模型的构建上。
5、测试算法:使用测试数据上的R2值来分析模型的效果。
6、使用算法:使用训练出的树做预测,预测结果还可以用来做很多事情。
利用字典来存储树的数据结构,字典将包含以下4个元素:
1、待切分的特征
2、待切分的特征值
3、右子树。当不在需要切分时,也可以是单个值
4、左子树。与右子树类似
实现cart算法:
伪代码:
找到最佳的待切分特征:
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树调用creatTree()方法
再左子树调用creatTree()方法
python代码实现:
def regLeaf(dataSet):
# 该函数对数据最后一列即结果进行求均值,用于在构建叶节点的时候调用,返回节点中样本结果的均值作为回归值
return mean(dataSet[:, -1])
def regErr(dataSet):
# 该函数是计算当前节点中样本的总方差
return var(dataSet[:, -1]) * shape(dataSet)[0]
# 下载并处理数据
def loadDataSet(fileName):
dataMat = []
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
fltLine = list(map(float, curLine)) # map返回结果为迭代器,使用list函数,将其转化成列表
dataMat.append(fltLine)
return dataMat
# 对确定特征进行二元切分
def binSplitDataSet(dataSet,feature,value):
mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0],:]
mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
return mat0,mat1
def createTree(dataSet, leafType=regLeaf, errType=regErr,ops=(1,4)):
feat,val = chooseBestSplit(dataSet,leafType,errType,ops)
if feat == None:
return val # 如果None说明满足了停止条件,这里val就是某类模型的结果
retTree = {}
retTree['spInd'] = feat # 切分的特征
retTree['spVal'] = val # 特征的取值
lSet,rSet = binSplitDataSet(dataSet,feat,val)
retTree['left'] = createTree(lSet,leafType,errType,ops)
retTree['right'] = createTree(rSet,leafType,errType,ops)
return retTree
由此创建了一个简单的矩阵,按指定列的值切分矩阵得到下面的矩阵。
通过chooseBestSplit()函数,实现给定误差计算方法,成功找到数据集上最佳的二元切分方法。函数仅需完成两件事:用最佳方式切分数据集和生成相应的叶节点。同时还使用到了上面创建的leafType、errType、ops三个参数,leafType是对创建叶节点的函数的引用,errType是对前面介绍的总方差计算函数的引用,ops是一个用户定义的参数构成的元组。用于完成树的构建。
伪代码如下:
对每个特征:
对每个特征值
将数据集切分成两份
计算切分的误差
如果当前误差小于当前最小误差,那么将当前切分设为最佳切分并更新最小误差返回最佳切分的特征和阈值
python代码:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
tolS = ops[0] # 如果切分的误差减少量小于这个值就说明切分没什么意义
tolN = ops[1] # 如果切分出的数据很小,小于这个值也没意义
if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
# 集合的长度等于1说明当前节点中的样本都是都一个值,那么就不用在划分了
return None, leafType(dataSet)
m, n = shape(dataSet)
S = errType(dataSet)
bestS = inf # 用来存放最小的误差
bestIndex = 0 # 用来存放最好的切分特征
bestValue = 0 # 用来存放最好特征的切分值
for featIndex in range(n - 1): # 第n-1列为结果列
for splitVal in set(dataSet[:, featIndex].T.tolist()[0]): # 按照现有结果中的取值来划分
# 原文这里写错了,set(dataSet[:,featIndex])是错误的
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): # 判断会不会太小
continue
newS = errType(mat0) + errType(mat1)
if newS < bestS: # 小于则更新
bestIndex = featIndex
bestValue = splitVal
bestS = newS
if (S - bestS) < tolS:
return None, leafType(dataSet)
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
# 这里判断的含义是防止前面根本找不到合适的切分点导致两个best都没有修改
return None, leafType(dataSet)
return bestIndex, bestValue
myDat = cart.loadDataSet('D:\learning\ex00.txt')
maMat = mat(myDat)
result = cart.createTree(maMat)
print(result)
结果如下:
从文档ex0.txt中读取数据并为其构建一棵回归树代码同上,将ex00.txt替换即可:
树剪枝的概念:当树模型的节点过多,就说明了模型可能对数据进行了“过拟合”,通过降低决策树的复杂度来避免过拟合的过程就称为剪枝。
输入代码:
myDat2 = cart.loadDataSet('D:\learning\ex2.txt')
myMat2 = mat(myDat2)
result2 = cart.createTree(myMat2)
print(result2)
输出结果如下图,可以看到构建的树有很多叶节点,产生这个现象的原因在于,停止条件tols对误差的数量级非常敏感。:
修改停止条件:
cart.createTree(myMat2, ops=(10000, 4))
得到结果:
通过修改停止条件得到合理结果并非很好的办法。
后剪枝方法需要把数据集分成测试集和训练集,先指定参数,使得构建出的树足够的大、足够的复杂,以便于剪枝,自上而下的找到叶节点,用测试集来判断将这些叶节点合并是否能降低误差,可以则合并。
伪代码如下:
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算将当前两个叶节点合并后的误差
计算不合并的误差
如果不合并会降低误差,将叶节点合并
python代码:
def isTree(obj):
# 因为创建树的过程中如果是树都对应一个字典来存储,否则是一个常数
return (type(obj).__name__ == 'dict')
def getMean(tree):
# 如果左右节点还是树的话则需要迭代去寻找.直到是叶节点,叶节点就直接返回为数值
if isTree(tree['right']):
tree['right'] = getMean(tree['right'])
if isTree(tree['left']):
tree['left'] = getMean(tree['left'])
return (tree['left'] + tree['right']) / 2.0
def prune(tree, testData):
if shape(testData)[0] == 0:
# 如果没有测试数据都直接将树整个变成一个叶节点
return getMean(tree)
if (isTree(tree['right']) or isTree(tree['left'])): # 左右有一个为树
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']):
tree['left'] = prune(tree['left'], lSet) # 继续迭代下去
if isTree(tree['right']):
tree['right'] = prune(tree['right'], rSet)
if not isTree(tree['left']) and not isTree(tree['right']):
# 左右都是一个叶子节点的情况,那么左右节点字典对应的就是它们其中样本的均值,可以来计算误差
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal']) # 测试集在当前特征的划分情况
errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + sum(power(rSet[:, -1] - tree['right'], 2))
# errorNoMerge就是根据test去划分然后计算误差
treeMean = (tree['left'] + tree['right']) / 2.0 # 如果不划分将直接取均值
# 这里应该是做简化处理,因为整体的均值并不是两个均值相差除以2,但前面也没办法获知每一个叶结点的数目
errorMetge = sum(power(testData[:, -1] - treeMean, 2))
# 不划分的时候的误差
if errorMetge < errorNoMerge:
print("Merging")
return treeMean
else:
return tree
else:
return tree
程序中包含的三个函数:
isTree()用于测试输入变量是否为一棵树,返回布尔类型的结果。
getMean()是一个递归函数,从上往下遍历树直到叶节点为止。当找到两个叶节点则计算其平均值。
prune()主函数,有两个参数:待剪枝的树与剪枝所需要的测试数据。先判断是否为空,一旦非空,反复递归调用函数prune()对测试集进行切分。因为树是由其他数据集生成的,所以测试集上会有一些样本与原数据集样本的取值范围不同。
myDat = cart.loadDataSet('D:\learning\ex2.txt')
myMat = mat(myDat)
mytree = cart.createTree(myMat, ops=(0,1))
mydarTest = cart.loadDataSet('D:\learning\ex2test.txt')
myMat2Test = mat(mydarTest)
result2 = cart.prune(mytree,myMat2Test)
print(result2)
输出结果为:
虽然结果有被剪枝掉,但还是没有像预期那样剪枝成两部分,这就说明后剪枝可能不如预剪枝那般有效。实际过程中为寻求最佳模型可以同时使用两种剪枝技术。
用树来对数据建模,除将叶节点简单设置为常数值外,还有一种将叶节点设定为分段线性函数,所谓分段线性指的是模型由多个线性片段组成。
python实现模型树的叶节点生成函数
def linearSolve(dataSet):
# 该函数用来进行简单的线性拟合,返回拟合出来的权值向量
m, n = shape(dataSet)
X = mat(ones((m,n)))
Y = mat(ones((m, 1)))
X[:, 1:n] = dataSet[:, 0:n - 1] # 第一列是1,其他与dataSet的n-1列相同
Y = dataSet[:, -1]
xTx = X.T * X
if linalg.det(xTx) == 0.0:
raise NameError("This matrix is singular, cannot do inverse,\n try increasing the second value of ops")
ws = xTx.I * (X.T * Y)
return ws, X, Y
def modelLeaf(dataSet):
ws,X,Y = linearSolve(dataSet) # 用来构建叶子节点
return ws
def modelErr(dataSet):
ws,X,Y = linearSolve(dataSet)
yHat = X * ws
return sum(power(Y-yHat,2))
利用GUI对回归树调优
1、收集数据:所提供的文本文件。
2、准备数据:用Python解析上述文件,得到数值型数据。
3、分析数据:用Tkinter构建一个GUI来展示模型和数据。
4、训练算法:训练一棵回归树和一棵模型树,并与数据集一起展示出来。
5、测试算法:这里不需要测试过程。
6、使用算法:gui使得人们可以在预剪枝时测试不同参数的影响,还可以帮助我们选择模型的类型。
首先是熟悉使用tkinter库,这个库是python自带的不需要额外安装,直接导入即可。
root = Tk()
mylabel = Label(root, text='hello world')
mylabel.grid()
root.mainloop()
结果显示为:
创建新的python文件treeExplore.py文件:
from numpy import *
from tkinter import *
import cart
def reDraw(tolS,tolN):
pass
def drawNewTree():
pass
root = Tk()
Label(root,text="Plot Place Holder").grid(row = 0,columnspan=3)
Label(root,text="tolN").grid(row=1,column=0)
tolNentry = Entry(root)
tolNentry.grid(row = 1,column=1)
tolNentry.insert(0,'10')
Label(root,text='tolS').grid(row=2,column=0)
tolSentry = Entry(root)
tolSentry.grid(row = 2,column = 1)
tolSentry.insert(0,'1.0')
Button(root,text='ReDraw',command=drawNewTree).grid(row=1,column = 2,rowspan=3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root,text='Model Tree',variable=chkBtnVar)
chkBtn.grid(row=3,column = 0,columnspan=2)
reDraw.rawDat = mat(cart.loadDataSet("D:\learning\sine.txt"))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
reDraw(1.0,10)
root.mainloop()
创建一组Tkinter模块,并利用网格布局管理器安排位置。使用.grid()的方法设定行和列的位置。通过设定columnspan和rowspan的值来告诉布局管理器是否允许一个小部件跨行或跨列。其他还包括文本输入框、复选按钮和按钮整数值等。entry部件是一个允许单行文本输入的文本框。checkbutton和intvar是为读取checkbutton的状态需要创建一个变量,也就是intvar。
最后输出结果为:
可以通过修改matplotlib后端达到在GUI上绘图的目的;matplotlib的构建程序包括了一个前端也就是面向用户的一些代码,如plot()和scatter()方法等。事实上同时创建了一个后端,用于实现绘图和不同应用之间接口。通过改变后端可以将图像绘制在PNG、PDF、SVG等格式的文件上。设置后端为TkAgg。TkAgg可以在所选GUI框架上调用Agg,将Agg呈现在画布上,可以在Tk的GUI上放置一个画布并用.grid()调整布局:
from numpy import *
from tkinter import *
import cart
import matplotlib
matplotlib.use('TkAgg')
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
def reDraw(tolS,tolN):
reDraw.f.clf()
reDraw.a = reDraw.f.add_subplot(111)
if chkBtnVar.get():
if tolN < 2:
tolN = 2
myTree = cart.createTree(reDraw.rawDat, cart.modelLeaf,
cart.modelErr, (tolS, tolN))
yHat = cart.createForeCast(myTree, reDraw.testDat,
cart.modelTreeEval)
else:
myTree = cart.createTree(reDraw.rawDat, ops = (tolS, tolN))
yHat = cart.createForeCast(myTree, reDraw.testDat)
reDraw.a.scatter(reDraw.rawDat[:,0],reDraw.rawDat[:,1],s=5)
reDraw.a.plot(reDraw.testDat,yHat,linewidth=2.0)
reDraw.canvas.show()
def getInputs():
try:tolN = int(tolNentry.get())
except:
tolN = 10
print("enter Integer for tolN")
tolNentry.delete(0,END)
tolNentry.insert(0,'10')
try:tolS = float(tolSentry.get())
except:
tolS = 1.0
print("enter Float for tolS")
tolSentry.delete(0,END)
tolSentry.insert(0,'1.0')
return tolN,tolS
def drawNewTree():
tolN,tolS = getInputs()
reDraw(tolS,tolN)
root = Tk()
Label(root,text="Plot Place Holder").grid(row = 0,columnspan=3)
Label(root,text="tolN").grid(row=1,column=0)
tolNentry = Entry(root)
tolNentry.grid(row = 1,column=1)
tolNentry.insert(0,'10')
Label(root,text='tolS').grid(row=2,column=0)
tolSentry = Entry(root)
tolSentry.grid(row = 2,column = 1)
tolSentry.insert(0,'1.0')
Button(root,text='ReDraw',command=drawNewTree).grid(row=1,column = 2,rowspan=3)
chkBtnVar = IntVar()
chkBtn = Checkbutton(root,text='Model Tree',variable=chkBtnVar)
chkBtn.grid(row=3,column = 0,columnspan=2)
reDraw.rawDat = mat(cart.loadDataSet("D:\learning\sine.txt"))
reDraw.testDat = arange(min(reDraw.rawDat[:,0]),max(reDraw.rawDat[:,0]),0.01)
reDraw.f = Figure(figsize=(5,4),dpi = 100)
reDraw.canvas = FigureCanvasTkAgg(reDraw.f,master=root)
reDraw.canvas.draw()
reDraw.canvas.get_tk_widget().grid(row = 0,columnspan = 3)
reDraw(1.0,10)
root.mainloop()