首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >KD树的构造与Python代码实现

KD树的构造与Python代码实现

原创
作者头像
用户10150864
发布2025-10-26 09:27:33
发布2025-10-26 09:27:33
1610
举报
文章被收录于专栏:机器学习机器学习

一、 KD树的定义

kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间点集的一个划分。构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,构成一系列的k维超矩形区域。kd树的每个结点对应于一个k维超矩形区域。

构造kd树的方法:构造根结点,使得根结点对应于k维空间中包含所有实例点的超矩形区域(结点)上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域(子结点);这时,实例被分到两个子区域。这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。

1.1 构造平衡KD树

1.2 搜索 KD树的的最近邻搜索算法

输入: 已构造kd树,目标点x;

输出:x的最近邻;

(1) 初始化当前最近点best_point和最近距离best_dist分别为None ,和 inf,并且声明为全局变量。

(2) 定义一个从当前结点(当前结点可以是任意结点)开始的递归搜索函数,递归终止条件是 当前结点为空,表示调用本函数的上级结点已经是叶子结点,程序直接返回。

(3) 计算目标点x与当前结点的距离,如果小于current_dist,则对当前最近点best_point和最近距离best_dist分别更新为当前结点的点数据和current_dist,否则略过。

(4) 以本结点在构建kd树时存储的划分轴axis为搜索方向。

(5) 对目标结点在axis轴上与本结点axis轴上进行比较,如果小于则递归调用搜索左树,否则递归调用搜索右树,并且调用搜索左树和右树的同时,判断当前结点和目标结点在axis上的距离是否小于已更新的(和目标点)当前最短距离,如果是,表明当前结点在axis上与目标点的距离在以best_dist为半径的超球体内,自然当前结点的另一边子树可能存在最近邻点,因而当前结点的另一边子树需要搜索。

(6) 以上2-6步构成一个递归搜索函数,在此函数外部定义一个包含(1)的直接函数,并在末尾调用从根结点开始的递归搜索函数即可实现搜索目的。并返回最邻近点。

k个最近邻 可以在上述算法的基础上设计一个容量为k的双端队列,先存储k个最邻近的点和距离,当有第k+1个点入队尾时,先删除队列第一个元素,直到程序结尾。

算法实现:

代码语言:txt
复制
class Node:
    def __init__(self, point, left=None, right=None, axis=None):
        self.point = point    # 节点存储的数据点
        self.left = left      # 左子树
        self.right = right    # 右子树
        self.axis = axis      # 划分轴(维度)

def build_kdtree(points, depth=0):
    """递归构建KD树"""
    if not points:  # 空节点返回None
        return None
    
    k = len(points[0])       # 数据维度
    axis = depth % k         # 根据深度选择划分轴
    
    # 按当前轴排序并选择中位数作为分割点
    points_sorted = sorted(points, key=lambda point: point[axis])
    median_idx = len(points_sorted) // 2
    
    # 递归构建子树
    return Node(
        point=points_sorted[median_idx],
        axis=axis,
        left=build_kdtree(points_sorted[:median_idx], depth + 1),
        right=build_kdtree(points_sorted[median_idx+1:], depth + 1)
    )

def squared_distance(p1, p2):
    """计算两点之间的平方距离"""
    return sum((x - y) ** 2 for x, y in zip(p1, p2))

def nearest_neighbor(root, target):
    """查找目标点的最近邻"""
    best = None          # 当前最近点
    best_dist = float('inf')  # 当前最小距离
    
    def recursive_search(node):
        nonlocal best, best_dist
        
        if node is None:
            return
        
        # 计算当前节点距离
        current_dist = squared_distance(node.point, target)
        if current_dist < best_dist:
            best = node.point
            best_dist = current_dist
        
        # 决定搜索方向
        axis = node.axis
        if target[axis] < node.point[axis]:
            recursive_search(node.left)   # 先搜索左子树
            # 检查右子树是否需要搜索
            if (node.point[axis] - target[axis]) ** 2 < best_dist:  # 
                recursive_search(node.right)
        else:
            recursive_search(node.right)  # 先搜索右子树
            # 检查左子树是否需要搜索
            if (target[axis] - node.point[axis]) ** 2 < best_dist:
                recursive_search(node.left)
    
    recursive_search(root)
    return best

# 示例用法
if __name__ == "__main__":
    
    # 示例数据点(二维)
    points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    
    # 构建KD树
    kdtree = build_kdtree(points)
    
    # 查找最近邻
    target = (8, 0.5)
    nearest = nearest_neighbor(kdtree, target)
    print(f"Target: {target}  Nearest: {nearest}")
    
    
 #运行结果:
    Target: (8, 0.5)  Nearest: (8, 1)    

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
作者已关闭评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、 KD树的定义
    • 1.1 构造平衡KD树
    • 1.2 搜索 KD树的的最近邻搜索算法
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档