首先我们通过一个简单的数据集来了解一下KNN算法。
data_x数据集是绘制图像需要的几个基本点,data_y是这些点所在的不同类别,分别是0,1。
接下来我们通过matplotlib来展示这些点的分布情况。
由结果我们可以清楚的看到,红色和绿色的两个不同类别分布情况,接下来我们又得到了另一个点,但是我们不知道他是属于哪一个类别,这时我们就需要将它也在图上表示出来,代码如下:
我们将新增的这个点用蓝色表示,结果如下:
我们可以清楚的再图上看到他的位置,和红色部分靠近。
因为KNN算法是根据相近的几个点的位置来判断该点是什么类别,所以我们可以知道,这个点是属于红色的类别。
思路:
1.计算出新的那个点和原来那些点的距离
或直接使用:
结果都是一样的,我们得到了新的点和原来点的距离
2.接下来我们要得到与新的点距离最近的几个点,我们要对生成的dis里面的数据进行排序,得到这些点的位置,我们可以得到最近的点的索引位置,使用numpy内置的argsort方法。
将排序结果存在short里面:
可以知道,最近的点是索引为8的点,其次是7.
3.我们还要得到这些距离最近的几个点属于哪些类别。我们设置k为6,看前六个点属于哪些类别,将y_train中的数据在short中遍历,看哪些符合条件。
结果:[1, 1, 1, 1, 1, 0]
所以,最近的五个点为类别1,还有一个类别0.
或者我们计算结果:
因为我们要获取的是类别,所以通过[0][0]获取。
因此我们就得到了最近的点的类别是1,这个点就是类别1.
领取专属 10元无门槛券
私享最新 技术干货