kd树是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是二叉树,表示对k维空间点集的一个划分。构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,构成一系列的k维超矩形区域。kd树的每个结点对应于一个k维超矩形区域。
构造kd树的方法:构造根结点,使得根结点对应于k维空间中包含所有实例点的超矩形区域(结点)上选择一个坐标轴和在此坐标轴上的一个切分点,确定一个超平面,这个超平面通过选定的切分点并垂直于选定的坐标轴,将当前超矩形区域切分为左右两个子区域(子结点);这时,实例被分到两个子区域。这个过程直到子区域内没有实例时终止(终止时的结点为叶结点)。在此过程中,将实例保存在相应的结点上。

输入: 已构造kd树,目标点x;
输出:x的最近邻;
(1) 初始化当前最近点best_point和最近距离best_dist分别为None ,和 inf,并且声明为全局变量。
(2) 定义一个从当前结点(当前结点可以是任意结点)开始的递归搜索函数,递归终止条件是 当前结点为空,表示调用本函数的上级结点已经是叶子结点,程序直接返回。
(3) 计算目标点x与当前结点的距离,如果小于current_dist,则对当前最近点best_point和最近距离best_dist分别更新为当前结点的点数据和current_dist,否则略过。
(4) 以本结点在构建kd树时存储的划分轴axis为搜索方向。
(5) 对目标结点在axis轴上与本结点axis轴上进行比较,如果小于则递归调用搜索左树,否则递归调用搜索右树,并且调用搜索左树和右树的同时,判断当前结点和目标结点在axis上的距离是否小于已更新的(和目标点)当前最短距离,如果是,表明当前结点在axis上与目标点的距离在以best_dist为半径的超球体内,自然当前结点的另一边子树可能存在最近邻点,因而当前结点的另一边子树需要搜索。
(6) 以上2-6步构成一个递归搜索函数,在此函数外部定义一个包含(1)的直接函数,并在末尾调用从根结点开始的递归搜索函数即可实现搜索目的。并返回最邻近点。
k个最近邻 可以在上述算法的基础上设计一个容量为k的双端队列,先存储k个最邻近的点和距离,当有第k+1个点入队尾时,先删除队列第一个元素,直到程序结尾。
算法实现:
class Node:
def __init__(self, point, left=None, right=None, axis=None):
self.point = point # 节点存储的数据点
self.left = left # 左子树
self.right = right # 右子树
self.axis = axis # 划分轴(维度)
def build_kdtree(points, depth=0):
"""递归构建KD树"""
if not points: # 空节点返回None
return None
k = len(points[0]) # 数据维度
axis = depth % k # 根据深度选择划分轴
# 按当前轴排序并选择中位数作为分割点
points_sorted = sorted(points, key=lambda point: point[axis])
median_idx = len(points_sorted) // 2
# 递归构建子树
return Node(
point=points_sorted[median_idx],
axis=axis,
left=build_kdtree(points_sorted[:median_idx], depth + 1),
right=build_kdtree(points_sorted[median_idx+1:], depth + 1)
)
def squared_distance(p1, p2):
"""计算两点之间的平方距离"""
return sum((x - y) ** 2 for x, y in zip(p1, p2))
def nearest_neighbor(root, target):
"""查找目标点的最近邻"""
best = None # 当前最近点
best_dist = float('inf') # 当前最小距离
def recursive_search(node):
nonlocal best, best_dist
if node is None:
return
# 计算当前节点距离
current_dist = squared_distance(node.point, target)
if current_dist < best_dist:
best = node.point
best_dist = current_dist
# 决定搜索方向
axis = node.axis
if target[axis] < node.point[axis]:
recursive_search(node.left) # 先搜索左子树
# 检查右子树是否需要搜索
if (node.point[axis] - target[axis]) ** 2 < best_dist: #
recursive_search(node.right)
else:
recursive_search(node.right) # 先搜索右子树
# 检查左子树是否需要搜索
if (target[axis] - node.point[axis]) ** 2 < best_dist:
recursive_search(node.left)
recursive_search(root)
return best
# 示例用法
if __name__ == "__main__":
# 示例数据点(二维)
points = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
# 构建KD树
kdtree = build_kdtree(points)
# 查找最近邻
target = (8, 0.5)
nearest = nearest_neighbor(kdtree, target)
print(f"Target: {target} Nearest: {nearest}")
#运行结果:
Target: (8, 0.5) Nearest: (8, 1) 原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。
原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。
如有侵权,请联系 cloudcommunity@tencent.com 删除。