04月29, 2018

最近点对问题的分治算法的具体实现

最近点对问题是一个经典问题,目标是在n个点中找到距离最近的一对点。

做法可以有很多,直接上KD-Tree都是可以的。但最常用的还是分治法。

其主要思想很简单,就是把问题分成两个子问题,再进行合并。不过在这之前,我们要先对点按x排序。求解时先把一堆点平均分成两堆点,分别递归求解两堆点的最近点对,记为 \delta_1 \delta_2 ,然后取 \delta=\min(\delta_1,\delta_2) 。之后我们把横坐标距离中间点在 [-\delta,\delta] 范围之内的点取出来,找它们之间的最近点对,与 \delta \min 即可。

不过这里面有个问题,最坏情况下,合并操作复杂度可能高达 O(n^2) ,这显然不可接受。实际上,这里面有个很好的性质,即,若 [-\delta,\delta] 之内的点是按y有序的,那么对于任何一个点,与它的距离可能小于 \delta 的点不会超过8个(它们一定在中间点为中心线的 \delta\times 2\delta 的矩形区域内,所以不超过8个,具体证明详见算法导论,不过实际上还可以证明是6个)。

因此,我们的目的就是要在合并时,除了一开始按x有序的数组外,还要拥有一个按y有序的数组。一个很容易想到的做法是,我们在把 [-\delta,\delta] 范围之内的点取出来之后,对它们做排序,但是这样时间复杂度在最坏情况下就变成了 O(n\log^2 n) 具体实现见代码1。 而且网上能找到的几乎所有的代码都是这么写的,最关键是他们居然都声称自己程序的复杂度是 O(n\log n) 的。

但实际上这个问题是有 O(n\log n) 的解法的,具体的做法在算法导论上有很详细的说明。其核心思想是,我们不仅对x做预排序,还要维护一个对y预排序的数组。这样我们在合并的时候就可以直接用了,而在开始递归调用处理子问题之前,我们可以在 O(n) 的时间内,对它进行按x划分,以中心点为界,划分成两部分,这样的话保证了在处理子问题时y数组也是正确划分的。不过这里面有个细节要注意,就是我们用这个数组是在合并的时候,但递归调用是在这之前,我们需要在递归调用之前已经划分完成才行,因此我们必须将划分前的数组存下来供后续使用。具体实现见代码2。

问题到这好像已经结束了,但如果你真的实现了上面的算法的话,你一定会对这个实现感到恶心,实在是太丑陋了,而且没有什么特别优美的写法。于是我们再仔细一想,这个划分实际上不就是合并的逆过程吗?从大问题划分,跟从小问题合并,岂不是一样的过程?而且这个划分是从有序数组划分成两个有序数组,为什么我们不从两个有序数组合并成一个有序数组呢?这好像就是归并排序的过程呀!事实上确实是这样,我们根本不需要进行划分,而像分治求最近点对一样,一步步合并出我们需要的按y有序的数组即可。具体实现见代码3。

现在我们编写代码测试正确性和效率,以HDU 1007为例。

首先是头文件和主函数。

#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<ctime>
#include<cassert>
#include<climits>
#include<iostream>
#include<algorithm>
#include<string>
#include<vector>
#include<deque>
#include<list>
#include<set>
#include<map>
#include<stack>
#include<queue>
#include<numeric>
#include<iomanip>
#include<bitset>
#include<sstream>
#include<fstream>
#define debug puts("-----")
#define pi (acos(-1.0))
#define eps (1e-8)
#define inf (1<<30)
#define INF (1ll<<62)
using namespace std;
int main()
{
    int n;
    while(cin>>n,n)
    {
        init(n); //input and presort
        printf("%.2f\n",ClosestPair(0,n-1)/2);
    }
    return 0;
}

代码1(cp_sort.cpp):

const int NV=300005;
struct point
{
    double x, y;
    void in()
    {
        scanf("%lf%lf",&x,&y);
    }
} p[NV];
int pyidx[NV];
bool cmpX(const point &p1,const point &p2)
{
    return p1.x<p2.x;
}
bool cmpY(const int a,const int b)
{
    return p[a].y<p[b].y;
}
double dis(const point &p1,const point &p2)
{
    return sqrt((p1.x-p2.x)*(p1.x-p2.x)+(p1.y-p2.y)*(p1.y-p2.y));
}
double ClosestPair(int l,int r) //[l, r]
{
    if (r-l+1==2) return dis(p[l],p[r]);
    if (r-l+1==3) return min(dis(p[l],p[r]),min(dis(p[l],p[l+1]),dis(p[l+1],p[r])));
    int m=l+r>>1;
    double delta=min(ClosestPair(l,m),ClosestPair(m+1,r)); //\delta=\min(\delta_1,\delta_2)
    int cnt=0;
    for (int i=l; i<=r; i++) //把所有x坐标在p[m].x的[-delta, delta]的点取出来
    {
        if(p[m].x-delta<=p[i].x&&p[i].x<=p[m].x+delta)
            pyidx[cnt++]=i;
    }
    sort(pyidx,pyidx+cnt,cmpY);
    for(int i=0; i<cnt; i++)
        for(int j=i+1; j<cnt; j++)
        {
            if(p[pyidx[j]].y-p[pyidx[i]].y>=delta) break; //最多寻找6个点后一定会break
            delta=min(delta,dis(p[pyidx[i]],p[pyidx[j]]));
        }
    return delta;
}
void init(int n)
{
    for (int i=0; i<n; i++) p[i].in();
    sort(p,p+n,cmpX);
}

代码2(cp_partition.cpp):

const int NV=300005;
struct point
{
    double x, y;
    void in()
    {
        scanf("%lf%lf",&x,&y);
    }
} p[NV],py[NV],pytmp[NV];
int pyidx[NV];
bool cmpX(const point &p1,const point &p2)
{
    return p1.x<p2.x;
}
bool cmpY(const point &p1,const point &p2)
{
    return p1.y<p2.y;
}
double dis(const point &p1,const point &p2)
{
    return sqrt((p1.x-p2.x)*(p1.x-p2.x)+(p1.y-p2.y)*(p1.y-p2.y));
}
double midline;
bool pred(const point &p)
{
    return p.x<=midline;
}
double ClosestPair(int l,int r) //[l, r]
{
    if (r-l+1==2) return dis(p[l],p[r]);
    if (r-l+1==3) return min(dis(p[l],p[r]),min(dis(p[l],p[l+1]),dis(p[l+1],p[r])));
    int m=l+r>>1;
    midline=p[m].x;
    vector<point> pytmp(py+l,py+r+1); //在进行划分操作导致改变y的有序性之前先记录一下
    stable_partition(py+l,py+r+1,pred); //对按y排序的数组用中间线划分
    double delta=min(ClosestPair(l,m),ClosestPair(m+1,r)); //\delta=\min(\delta_1,\delta_2)
    int cnt=0;
    for (int i=0; i<=r-l; i++) //把所有x坐标在p[m].x的[-delta, delta]的点取出来
    {
        if(p[m].x-delta<=pytmp[i].x&&pytmp[i].x<=p[m].x+delta)
            pyidx[cnt++]=i;
    }
    for(int i=0; i<cnt; i++)
        for(int j=i+1; j<cnt; j++)
        {
            if(pytmp[pyidx[j]].y-pytmp[pyidx[i]].y>=delta) break; //最多寻找6个点后一定会break
            delta=min(delta,dis(pytmp[pyidx[i]],pytmp[pyidx[j]]));
        }
    return delta;
}
void init(int n)
{
    for (int i=0; i<n; i++) p[i].in();
    sort(p,p+n,cmpX);
    for (int i=0; i<n; i++) py[i]=p[i];
    sort(py,py+n,cmpY);
}

代码3(cp_merge.cpp):

const int NV=300005;
struct point
{
    double x, y;
    void in()
    {
        scanf("%lf%lf",&x,&y);
    }
} p[NV],py[NV],pytmp[NV];
int pyidx[NV];
bool cmpX(const point &p1,const point &p2)
{
    return p1.x<p2.x;
}
bool cmpY(const point &p1,const point &p2)
{
    return p1.y<p2.y;
}
double dis(const point &p1,const point &p2)
{
    return sqrt((p1.x-p2.x)*(p1.x-p2.x)+(p1.y-p2.y)*(p1.y-p2.y));
}
double ClosestPair(int l,int r) //[l, r]
{
    if (r-l+1<=3) sort(py+l,py+r+1,cmpY);
    if (r-l+1==2) return dis(p[l],p[r]);
    if (r-l+1==3) return min(dis(p[l],p[r]),min(dis(p[l],p[l+1]),dis(p[l+1],p[r])));
    int m=l+r>>1;
    double delta=min(ClosestPair(l,m),ClosestPair(m+1,r)); //\delta=\min(\delta_1,\delta_2)
    merge(py+l,py+m+1,py+m+1,py+r+1,pytmp+l,cmpY); //利用归并排序的思想合并得到按y有序的序列
    copy(pytmp+l,pytmp+r+1,py+l);
    int cnt=0;
    for (int i=l; i<=r; i++) //把所有x坐标在p[m].x的[-delta, delta]的点取出来
    {
        if(p[m].x-delta<=py[i].x&&py[i].x<=p[m].x+delta)
            pyidx[cnt++]=i;
    }
    for(int i=0; i<cnt; i++)
        for(int j=i+1; j<cnt; j++)
        {
            if(py[pyidx[j]].y-py[pyidx[i]].y>=delta) break; //最多寻找6个点后一定会break
            delta=min(delta,dis(py[pyidx[i]],py[pyidx[j]]));
        }
    return delta;
}
void init(int n)
{
    for (int i=0; i<n; i++) p[i].in();
    sort(p,p+n,cmpX);
    for (int i=0; i<n; i++) py[i]=p[i];
}

测试结果(从上到下依次是代码3、2、1):

alt

还是可以很明显看出代码2、3比1要快,但是快得不多,主要是因为采用排序的方法可以只对 [-\delta,\delta] 内的点排序,而其他两种方法都要对区间内所有的元素都做线性操作,这样随机数据下排序的方法可能反而会快一些或者效率并没有显著差别。2、3效率基本一样,但是代码3明显看起来优美一些,而且花的内存也少一些。代码1、3的空间复杂度都是 O(n) ,2的空间复杂度是 O(n\log n) (因为要先把划分结果存起来)。

PS: 《算法导论》真的是一本好书,然而之前的几年加起来都没我最近半年翻得多……

PS2: 动态最近点对问题其实也很简单,只要我们维护有序序列时使用一个平衡的二叉搜索树即可,最终的复杂度会再乘上一个 \log n 。我记得有年网络赛就是这个题,唉,好多问题都好简单,只怪自己没认真学。

UPDATE: 可以使用inplace_merge函数合并,更省事。pred也可以用lambda表达式替换(需C++11支持)。

本文链接:https://debug.fanzheng.org/post/the-implementation-of-closest-pair-problem-using-divide-and-conquer-method.html

-- EOF --

Comments

评论加载中...

注:如果长时间无法加载,请针对 disq.us | disquscdn.com | disqus.com 启用代理。