前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >knn之构造kd树和最近邻求取c++实现

knn之构造kd树和最近邻求取c++实现

作者头像
用户7886150
修改2021-02-14 15:32:16
5720
修改2021-02-14 15:32:16
举报
文章被收录于专栏:bit哲学院

参考链接: C++ fdim()

这份代码测试样例为 

6

7 2

2 3

5 4

4 7

9 6

8 1

8 2

这样,通过中位数来选取根节点(这样的方法其实在一定程度上是有很大问题的,因为根节点的选取方法不同,会导致整棵树的结构不同,这里由于数据的关系,不能构成完全二叉树,所以在对于特殊的样例来说是会出错的,比如说(10,10)这个测试样例,根本无法找到包含他的子节点(区域),所以会导致出错))。 

#include<iostream>

#include<algorithm>

#include<cstring>

#include<vector>

#include<cmath>

#include<queue>

using namespace std;

struct node{

    pair<int,int>x;

    int dim;

    node*left;

    node*right;

    node*father;

    node(pair<int,int>p=make_pair(0,0),int dim=0,node*left=0,node*right=0,node*father=0)

    :dim(dim),left(left),right(right),father(father)

    {

     x=p;    

    }

};

bool cmp1(node*a,node* b)

{

    return a->x.first<b->x.first;

}

bool cmp2(node*a,node*b)

{

    return a->x.second<b->x.second;

}

vector<node*>vec;

node* buildtree(vector<node*>temp,int cnt)

if(temp.size()==0)

return 0;

else if(temp.size()==1)

return temp[0];

else{

    if(cnt==1)

     sort(temp.begin(),temp.end(),cmp1);

     else

     sort(temp.begin(),temp.end(),cmp2);

     int mid=temp.size()/2;

    vector<node*>p;

    for(int i=0;i<mid;i++)

    {

        p.push_back(temp[i]);

    }

    vector<node*>q;

    for(int i=mid+1;i<temp.size();i++)

    {

        q.push_back(temp[i]);

    }

    node*left=buildtree(p,(cnt+1)%2);

    node*right=buildtree(q,(cnt+1)%2);

    node*fat=new node(make_pair(temp[mid]->x.first,temp[mid]->x.second),cnt,left,right,0);

    if(left!=0)

    left->father=fat;

    if(right!=0)

    right->father=fat;

    //cout<<fat->x.first<<" "<<fat->x.second<<endl;

    return fat;

}

}

void traverse(node*root) 

{

    if(root==0)

    {

    }

    else

    {

        cout<<root->x.first<<" "<<root->x.second<<endl;

        traverse(root->left);

        traverse(root->right);

    }

}

node*find_first_belong(node*key,node*root)

{

    node*temp=root; 

    while(true) //遍历找到其归属的叶节点 

    {

        if(temp->left==0&&temp->right==0)

        {

            break;

        }

        else

        {

            int dim=temp->dim;//选择维度比较 

            if(dim==1)//选择x1比较 

            {

                if(key->x.first<=temp->x.first)

                temp=temp->left;

                else

                temp=temp->right;

            }

            else //选择x2比较 

            {

                if(key->x.second<=temp->x.second)

                 temp=temp->left;

                 else

                 temp=temp->right;

            }

        }

    }

    return temp;

}

double distance(node*a,node*b)

{

    double ax1=a->x.first;

    double ax2=a->x.second;

    double bx1=b->x.first;

    double bx2=b->x.second;

    return sqrt(pow(ax1-bx1,2)+pow(ax2-bx2,2));

}

node*query(node*key,node*root,double mindis)

//这里就是最不明白的一点,当另一区域跟圆相交,书上说是递归进行最近邻搜索,

//没搞懂到底怎么递归搜索,所以这里就直接用了很简单的遍历比较,希望以后能搞懂 

{

    node*rec=root;

    double mind=mindis;

    queue<node*>q;

    q.push(root);

    while(!q.empty())

    {

        node*temp=q.front();

        double dis=distance(key,temp);

        if(dis<mind)

        {

            mind=dis;

            rec=temp;

        }

        q.pop();

        if(temp->left!=0)

        q.push(temp->left);

        if(temp->right)

        q.push(temp->right);

    }

    return rec;

}

node*find_nearest(node*key,node*belong)

{

    node *nearest=belong;

     double mindis=distance(key,belong);

     //cout<<mindis<<" mindis"<<endl;

     while(true)

     {

     //cout<<belong->x.first<<" "<<belong->x.second<<endl;

     node*fat=belong->father;

     if(fat==0)

     break;

     int fdim=fat->dim;

     if(distance(fat,key)<mindis)

     {

         mindis=distance(fat,key);

         nearest=fat;

     }

     if(fdim==1) //判断圆是否与x1=fat->x.first相交 

     {

         int fx1=fat->x.first;

         int kx1=key->x.first;

         if(abs(fx1-kx1)<mindis)

         {

             node*res=query(key,fat->right,mindis);

             if(res!=0&&distance(res,key)<mindis)

             {

                 nearest=res;

                 mindis=distance(res,key);

             }

         }

     }

     else //反之 

     {

          int fx2=fat->x.second;

          int kx2=key->x.second;

          if(abs(fx2-kx2)<mindis)

          {

              node*res=query(key,fat->right,mindis);

             if(res!=0&&distance(res,key)<mindis)

             {

                 nearest=res;

                 mindis=distance(res,key);

             }

          }

     }

     belong=fat;

     if(belong==0)

     break;

  }

  return nearest;

}

node*search(node*key,node*root)

{

    node* belong=find_first_belong(key,root);

    //cout<<belong->x.first<<" "<<belong->x.second<<endl;

    node* nearest=find_nearest(key,belong);

}

int main()

{

    int n;

    cin>>n;

    for(int i=0;i<n;i++)

    {

        int x,y;

        cin>>x>>y;

        node* temp=new node(make_pair(x,y));

        vec.push_back(temp);

    }

    node*root=buildtree(vec,1);

    //traverse(root);

    int x,y;

    cin>>x>>y;

    node *key=new node(make_pair(x,y));

    node*near=search(key,root);

    cout<<near->x.first<<" "<<near->x.second<<endl;

} 以上代码,经过测试,除了(10,10)这种类似的特殊数据会出错,别的基本正确,代码写的很乱。。。。 

这里还有一个很大的问题在于,我不知道一旦判定了圆和其他区域相交之后该怎么进行递归搜索,所以这里直接用了遍历。。。。   

总算搞懂了什么递归搜索: 

下面的是第二个版本: 

#include<iostream>

#include<algorithm>

#include<cstring>

#include<vector>

#include<cmath>

#include<queue>

using namespace std;

struct node{

    pair<int,int>x;

    int dim;

    node*left;

    node*right;

    node*father;

    node(pair<int,int>p=make_pair(0,0),int dim=0,node*left=0,node*right=0,node*father=0)

    :dim(dim),left(left),right(right),father(father)

    {

     x=p;    

    }

};

bool cmp1(node*a,node* b)

{

    return a->x.first<b->x.first;

}

bool cmp2(node*a,node*b)

{

    return a->x.second<b->x.second;

}

vector<node*>vec;

node* buildtree(vector<node*>temp,int cnt)

if(temp.size()==0)

return 0;

else if(temp.size()==1)

return temp[0];

else{

    if(cnt==1)

     sort(temp.begin(),temp.end(),cmp1);

     else

     sort(temp.begin(),temp.end(),cmp2);

     int mid=temp.size()/2;

    vector<node*>p;

    for(int i=0;i<mid;i++)

    {

        p.push_back(temp[i]);

    }

    vector<node*>q;

    for(int i=mid+1;i<temp.size();i++)

    {

        q.push_back(temp[i]);

    }

    node*left=buildtree(p,(cnt+1)%2);

    node*right=buildtree(q,(cnt+1)%2);

    node*fat=new node(make_pair(temp[mid]->x.first,temp[mid]->x.second),cnt,left,right,0);

    if(left!=0)

    left->father=fat;

    if(right!=0)

    right->father=fat;

    //cout<<fat->x.first<<" "<<fat->x.second<<endl;

    return fat;

}

}

void traverse(node*root) 

{

    if(root==0)

    {

    }

    else

    {

        cout<<root->x.first<<" "<<root->x.second<<endl;

        traverse(root->left);

        traverse(root->right);

    }

}

node*find_first_belong(node*key,node*root)

{

    node*temp=root; 

    while(true) //遍历找到其归属的叶节点 

    {

        if(temp->left==0&&temp->right==0)

        {

            break;

        }

        else

        {

            int dim=temp->dim;//选择维度比较 

            if(dim==1)//选择x1比较 

            {

                if(key->x.first<=temp->x.first)

                temp=temp->left;

                else

                temp=temp->right;

            }

            else //选择x2比较 

            {

                if(key->x.second<=temp->x.second)

                 temp=temp->left;

                 else

                 temp=temp->right;

            }

        }

    }

    return temp;

}

double distance(node*a,node*b)

{

    double ax1=a->x.first;

    double ax2=a->x.second;

    double bx1=b->x.first;

    double bx2=b->x.second;

    return sqrt(pow(ax1-bx1,2)+pow(ax2-bx2,2));

}

node*query(node*key,node*root,double mindis)//没有用的函数

{

    node*rec=root;

    double mind=mindis;

    queue<node*>q;

    q.push(root);

    while(!q.empty())

    {

        node*temp=q.front();

        double dis=distance(key,temp);

        if(dis<mind)

        {

            mind=dis;

            rec=temp;

        }

        q.pop();

        if(temp->left!=0)

        q.push(temp->left);

        if(temp->right)

        q.push(temp->right);

    }

    return rec;

}

node*find_nearest(node*key,node*belong,node*root)

{

    node *nearest=belong;

     double mindis=distance(key,belong);

     //cout<<belong->x.first<<" belong "<<belong->x.second<<endl;

     //cout<<mindis<<" mindis"<<endl;

     while(true)

     {

     //cout<<belong->x.first<<" "<<belong->x.second<<endl;

     node*fat=belong->father;

     if(fat==0||fat==root->father)

     break;

     node*other=new node(); //相比第一个这里还更加对了,因为这里还考虑到了万一归属的叶节点不是左节点的情况

     if(fat->left==belong)

     {

         other=fat->right;

     }

     else

     other=fat->left;

     //cout<<fat->x.first<<" "<<" fat  "<<fat->x.second<<endl;

     int fdim=fat->dim;

     if(distance(fat,key)<mindis)

     {

         mindis=distance(fat,key);

         nearest=fat;

     }

     if(fdim==1) //判断圆是否与x1=fat->x.first相交 

     {

         int fx1=fat->x.first;

         int kx1=key->x.first;

         if(abs(fx1-kx1)<mindis)

         {

             node*tm=find_first_belong(key,other);

             node*res=find_nearest(key,tm,other); //传说中的递归搜索在这里,利用他之前的函数

             if(res!=0&&distance(res,key)<mindis)

             {

                 nearest=res;

                 mindis=distance(res,key);

             }

         }

         //cout<<fx1<<" xxxx   "<<kx1<<" "<<mindis<<endl;

     }

     else //反之 

     {

          int fx2=fat->x.second;

          int kx2=key->x.second;

          if(abs(fx2-kx2)<mindis)

          {

              node*tm=find_first_belong(key,other);

              //cout<<tm->x.first<<" **** "<<tm->x.second<<endl;

              //cout<<other->x.first<<" other "<<other->x.second<<endl;

             node*res=find_nearest(key,tm,other);

             if(res!=0&&distance(res,key)<mindis)

             {

                 nearest=res;

                 mindis=distance(res,key);

                 //cout<<mindis<<"  mindis"<<endl;

             }

          }

     }

     belong=fat;

     if(belong==0)

     break;

  }

  return nearest;

}

node*search(node*key,node*root)

{

    node* belong=find_first_belong(key,root);

    //cout<<belong->x.first<<" "<<belong->x.second<<endl;

    node* nearest=find_nearest(key,belong,root);

    return nearest;

}

int main()

{

    int n;

    cin>>n;

    for(int i=0;i<n;i++)

    {

        int x,y;

        cin>>x>>y;

        node* temp=new node(make_pair(x,y));

        vec.push_back(temp);

    }

    node*root=buildtree(vec,1);

    //traverse(root);

    int x,y;

    cin>>x>>y;

    node *key=new node(make_pair(x,y));

    node*near=search(key,root);

    cout<<"the nearest point is "<<near->x.first<<" "<<near->x.second<<endl;

} 然而还是没有解决(10,10)的情况,明天再说!!!!

本文系转载,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文系转载前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档