前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >从零开始学人工智能-Python·决策树(三)·节点

从零开始学人工智能-Python·决策树(三)·节点

作者头像
企鹅号小编
发布2018-01-22 15:21:04
7280
发布2018-01-22 15:21:04
举报
文章被收录于专栏:人工智能

作者:射命丸咲Python 与 机器学习 爱好者

知乎专栏:https://zhuanlan.zhihu.com/carefree0910-pyml

个人网站:http://www.carefree0910.com

本章用到的 GitHub 地址:

https://github.com/carefree0910/MachineLearning/blob/master/Zhihu/CvDTree/one/CvDTree.py

本章用到的数学相关知识:

https://zhuanlan.zhihu.com/p/24501172

上一章我们把 node 的结构搭好了,这一章要做的就是塞东西进去。为此,我们不妨先看看我们需要实现什么:

一个 fit 函数,它能够根据输入的数据递归生成一颗决策树

一个 handle_terminate 函数,它在 node 成为 leaf 时调用,用于更新它爸爸和它爸爸的爸爸和……等等的信息

一个 prune 函数,用于剪枝

一个 view 函数,用于可视化

大提上来说就是这两点,剩下的就是一些细节。我们分开来说说怎么去实现它们

fit 函数

先说一下大概的流程:

接受数据和相应的标签

判断该 node 是否应该被当做 leaf;若是,则 return,否则继续往下走算出各维度的条件熵,记录下最好的条件熵、信息增益和此时关注的数据维度

比如说,如果输入的数据和标签为:

那么通过计算各个维度的条件熵和信息增益可知,此时该 node 关注的数据维度应该是第一维、也就是 A 对应的那一维。直观来说,这意味着 A 提供的信息量最大(事实上在这个栗子中,A 和 Label 是一样的)

根据信息增益判断是否终止(比如在 ID3 中,如果信息增益小于阈值的话就直接终止。这种判断方法会有比较严重的缺陷,观众老爷们可以想一想为什么 ( σ'ω')σ 【提示:异或数据集】)

根据所选的数据维度的各个特征把数据集切分成几份,分别喂给新的 node、递归,同时把这些 node 记录在自己的 children 里面

由于利用了递归,感觉还是一个比较干净利落的实现。下面就贴一些核心的代码,完整的实现可以参见这里

计算各个维度的条件熵和信息增益,这里就要用到准则章节的东西了

递归

handle_terminate 函数

如果童鞋们还记得我们 node 的结构的话,大概就会知道当一个 node 成为 leaf 后、需要做的事情有两个:

只要分别实现它们就好了:

其中

计算该 leaf 属于哪一类

更新它列祖列宗的 leafs 变量

prune 函数

需要指出的是,node 的 prune 函数不是决策树的剪枝算法、而是会在决策树的剪枝算法中被调用。它仅仅是为了该 node 的所有子孙都切了而已(喂

先说说流程,核心思想其实就是把该 node 变成一个 leaf:

判断该 node 应该属于哪一类

把该 node 的 leafs 中的 leaf 从该 node 的列祖列宗中的 leafs 中删除

把该 node 存进列祖列宗的 leafs 中

把该 node 自身及其所有子孙打上“已被剪枝”的标签

接下来是实现:

其中打标签用的 self.mark_pruned 函数的定义如下:

view 函数

基本思路很简单:如果自己是 leaf、就直接输出相关信息,否则在输出自己相关信息的同时、还要调用自己所有 children 的 view 函数。以下是实现:

这一章有点长,稍微总结一下:

决策树的生长关键是靠递归。当 node 接收一个数据和标签时,它会选出数据的某个维度、记录下来,然后会根据该维度的各个特征将数据、标签进行划分,分别喂给新的 node、从而能够递归下去

在 node 被判定应该是 leaf 时,要判断它属于哪一类并更新它列祖列宗的 leafs 变量

node 的 prune 函数是用来把 node 变成 leaf 并更新结构的,它本身不是决策树的剪枝算法、但它会在决策树的剪枝算法中被调用

下一章我们就要说说怎么建立一个框架以利用这些 node 来搭建一颗真正的决策树了。可能有童鞋已经敏锐地发现:不就只剩一个剪枝算法没有实现了吗?

事实上正是如此。下一章的框架确实只额外地实现了剪枝算法,剩下的都是封装的活儿

希望观众老爷们能够喜欢~

本文来自企鹅号 - 数据头条媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文来自企鹅号 - 数据头条媒体

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档