首页
学习
活动
专区
工具
TVP
发布
精选内容/技术社群/优惠产品,尽在小程序
立即前往

回归树的原理及Python实现

提到回归树,相信大家应该都不会觉得陌生(不陌生你点进来干嘛[捂脸]),大名鼎鼎的 GBDT 算法就是用回归树组合而成的。本文就回归树的基本原理进行讲解,并手把手、肩并肩地带您实现这一算法。

完整实现代码请参考

github:https://github.com/tushushu/Imylu/blob/master/regression_tree.py

1. 原理篇

我们用人话而不是大段的数学公式,来讲讲回归树是怎么一回事。

1.1 最简单的模型

如果预测某个连续变量的大小,最简单的模型之一就是用平均值。比如同事的平均年龄是 28 岁,那么新来了一批同事,在不知道这些同事的任何信息的情况下,直觉上用平均值 28 来预测是比较准确的,至少比 0 岁或者 100 岁要靠谱一些。我们不妨证明一下我们的直觉:

1.2 加一点难度

仍然是预测同事年龄,这次我们预先知道了同事的职级,假设职级的范围是整数1-10,如何能让这个信息帮助我们更加准确的预测年龄呢?

一个思路是根据职级把同事分为两组,这两组分别应用我们之前提到的“平均值”模型。比如职级小于 5 的同事分到A组,大于或等于5的分到 B 组,A 组的平均年龄是 25 岁,B 组的平均年龄是 35 岁。如果新来了一个同事,职级是 3,应该被分到 A 组,我们就预测他的年龄是 25 岁。

1.3 最佳分割点

还有一个问题待解决,如何取一个最佳的分割点对不同职级的同事进行分组呢?

我们尝试所有 m 个可能的分割点 P_i,沿用之前的损失函数,对 A、B 两组分别计算 Loss 并相加得到 L_i。最小的 L_i 所对应的 P_i 就是我们要找的“最佳分割点”。

1.4 运用多个变量

再复杂一些,如果我们不仅仅知道了同事的职级,还知道了同事的工资(貌似不科学),该如何预测同事的年龄呢?

我们可以分别根据职级、工资计算出职级和工资的最佳分割点P_1, P_2,对应的Loss L_1, L_2。然后比较L_1和L2,取较小者。假设L_1

1.5 答案揭晓

如何实现这种1 to 2, 2 to 4, 4 to 8的算法呢?

熟悉数据结构的同学自然会想到二叉树,这种树被称为回归树,顾名思义利用树形结构求解回归问题。

2. 实现篇

本人用全宇宙最简单的编程语言——Python实现了回归树算法,没有依赖任何第三方库,便于学习和使用。简单说明一下实现过程,更详细的注释请参考本人github上的代码。

2.1 创建Node类

初始化,存储预测值、左右结点、特征和分割点

classNode(object):

def __init__(self,score=None):

self.score=score

self.left=None

self.right=None

self.feature=None

self.split=None

2.2 创建回归树类

初始化,存储根节点和树的高度。

classRegressionTree(object):

def __init__(self):

self.root=Node()

self.height=

2.3 计算分割点、MSE

根据自变量X、因变量y、X元素中被取出的行号idx,列号feature以及分割点split,计算分割后的MSE。注意这里为了减少计算量,用到了方差公式:

2.4 计算最佳分割点

遍历特征某一列的所有的不重复的点,找出MSE最小的点作为最佳分割点。如果特征中没有不重复的元素则返回None。

def _choose_split_point(self,X,y,idx,feature):

unique=set([X[i][feature]foriinidx])

iflen(unique)==1:

returnNone

unique.remove(min(unique))

mse,split,split_avg=min(

(self._get_split_mse(X,y,idx,feature,split)

forsplitinunique),key=lambdax:x[])

returnmse,feature,split,split_avg

2.5 选择最佳特征

遍历所有特征,计算最佳分割点对应的MSE,找出MSE最小的特征、对应的分割点,左右子节点对应的均值和行号。如果所有的特征都没有不重复元素则返回None

def _choose_feature(self,X,y,idx):

m=len(X[])

split_rets=[xforxinmap(lambdax:self._choose_split_point(

X,y,idx,x),range(m))ifxisnotNone]

ifsplit_rets==[]:

returnNone

_,feature,split,split_avg=min(

split_rets,key=lambdax:x[])

idx_split=[[],[]]

whileidx:

i=idx.pop()

xi=X[i][feature]

ifxi

idx_split[].append(i)

else:

idx_split[1].append(i)

returnfeature,split,split_avg,idx_split

2.6 规则转文字

将规则用文字表达出来,方便我们查看规则。

def _expr2literal(self,expr):

feature,op,split=expr

op=">="ifop==1else"

return"Feature%d %s %.4f"%(feature,op,split)

2.7 获取规则

将回归树的所有规则都用文字表达出来,方便我们了解树的全貌。这里用到了队列+广度优先搜索。有兴趣也可以试试递归或者深度优先搜索。

def _get_rules(self):

que=[[self.root,[]]]

self.rules=[]

whileque:

nd,exprs=que.pop()

literals=list(map(self._expr2literal,exprs))

self.rules.append([literals,nd.score])

ifnd.left:

rule_left=copy(exprs)

rule_left.append([nd.feature,-1,nd.split])

que.append([nd.left,rule_left])

ifnd.right:

rule_right=copy(exprs)

rule_right.append([nd.feature,1,nd.split])

que.append([nd.right,rule_right])

2.8 训练模型

仍然使用队列+广度优先搜索,训练模型的过程中需要注意:

控制树的最大深度max_depth;

控制分裂时最少的样本量min_samples_split;

叶子结点至少有两个不重复的y值;

至少有一个特征是没有重复值的。

def fit(self,X,y,max_depth=5,min_samples_split=2):

self.root=Node()

que=[[,self.root,list(range(len(y)))]]

whileque:

depth,nd,idx=que.pop()

ifdepth==max_depth:

break

iflen(idx)

set(map(lambdai:y[i],idx))==1:

continue

feature_rets=self._choose_feature(X,y,idx)

iffeature_retsisNone:

continue

nd.feature,nd.split,split_avg,idx_split=feature_rets

nd.left=Node(split_avg[])

nd.right=Node(split_avg[1])

que.append([depth+1,nd.left,idx_split[]])

que.append([depth+1,nd.right,idx_split[1]])

self.height=depth

self._get_rules()

2.9 打印规则

模型训练完毕,查看一下模型生成的规则

def print_rules(self):

fori,ruleinenumerate(self.rules):

literals,score=rule

print("Rule %d: "%i,' | '.join(

literals)+' => split_hat %.4f'%score)

2.10 预测一个样本

def _predict(self,row):

nd=self.root

ifrow[nd.feature]

nd=nd.left

else:

nd=nd.right

returnnd.score

2.11 预测多个样本

def predict(self,X):

return[self._predict(Xi)forXiinX]

3 效果评估

3.1 main函数

使用著名的波士顿房价数据集,按照7:3的比例拆分为训练集和测试集,训练模型,并统计准确度。

@run_time

def main():

print("Tesing the accuracy of RegressionTree...")

# Load data

X,y=load_boston_house_prices()

# Split data randomly, train set rate 70%

X_train,X_test,y_train,y_test=train_test_split(

X,y,random_state=10)

# Train model

reg=RegressionTree()

reg.fit(X=X_train,y=y_train,max_depth=4)

# Show rules

reg.print_rules()

# Model accuracy

get_r2(reg,X_test,y_test)

3.2 效果展示

最终生成了15条规则,拟合优度0.801,运行时间1.74秒,效果还算不错~

3.3 工具函数

本人自定义了一些工具函数,可以在github上查看https://github.com/tushushu/Imylu/blob/master/utils.py1. run_time – 测试函数运行时间 2. load_boston_house_prices – 加载波士顿房价数据 3. train_test_split – 拆分训练集、测试机 4. get_r2 – 计算拟合优度

总结

回归树的原理:

损失最小化,平均值大法。 最佳行与列,效果顶呱呱。

回归树的实现:

一顿操作猛如虎,加减乘除二叉树。

【关于作者】

李小文:先后从事过数据分析、数据挖掘工作,主要开发语言是Python,现任一家小型互联网公司的算法工程师。Github:https://github.com/tushushu

  • 发表于:
  • 原文链接https://kuaibao.qq.com/s/20180907A21W3I00?refer=cp_1026
  • 腾讯「腾讯云开发者社区」是腾讯内容开放平台帐号(企鹅号)传播渠道之一,根据《腾讯内容开放平台服务协议》转载发布内容。
  • 如有侵权,请联系 cloudcommunity@tencent.com 删除。

扫码

添加站长 进交流群

领取专属 10元无门槛券

私享最新 技术干货

扫码加入开发者社群
领券