xiaobaoqiu Blog

Think More, Code Less

Closest Pair

Problem:

Given a set of points in a two dimensional space, you will have to find the distance between the closest two points.

示例:

input

1
2
3
4
5
0 2
6 67
43 71
39 107
189 140

output

1
36.2215

Solutions:

1. Brute force

计算全部n(n-1)/2对点对,找到最近的点对.

时间复杂度:$O(n^2)$

空间复杂度:O(1)

2. Divide and conquer

步骤如下: 1. sort 按照x坐标(y坐标也可以)从小到大排序,时间复杂度O(nlogn); 2. divide 找到中间点q,用q将点集P中的所有点划分为左右两个部分PL和PR; 3. conquer 递归计算PL和PR这两个点集的最近距离,二者中的较小值记为d; 4. combine 取PL和PR中x坐标与点q的x坐标距离小于d的所有点(记为PM,即一个宽度为2d的竖条带),将PM按照y坐标排序;遍历PM,在y坐标差值小于d的情况下计算距离;

其中1 2 3步骤很明显,步骤4技巧比较多:

  1. 最多6个点满足y坐标差值小于d,就是说我们的遍历PM在O(k)时间内搞定(k为PM点的数目),证明见 参考:http://www.cs.mcgill.ca/~cs251/ClosestPair/proofbox.html
  2. 计算距离,不用计算sqrt((x1-x2)*(x1-x2) + (y1-y2)*(y1-y2)),取sqrt内的值比较就可以;
  3. 按照y值排序PM,这是combine的性能瓶颈,我们可以在递归时候做merge排序,避免每次combine时候都重新排序的代价,因此combine的代价为O(n);

总的时间代价:

1
$T(n) = 2T(n/2) + O(n)$,即 $T(n) = O(n \log n)$

参考代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import com.google.common.primitives.Doubles;
import java.awt.*;
import java.awt.geom.Point2D;
import java.util.Arrays;
import java.util.Comparator;

/**
 * Finds a closest pair among a set of points in the plane using the classic
 * divide-and-conquer algorithm.
 *
 * <p>The points are sorted once by x-coordinate. The recursion then splits the
 * x-sorted range in half, solves both halves, and merges the halves by
 * y-coordinate on the way back up, so the strip around the dividing line can be
 * scanned in y-order without re-sorting at every level.
 *
 * <p>All work happens in the constructor; the result is read back through
 * {@link #either()}, {@link #other()} and {@link #distance()}. The input array
 * is copied and never reordered.
 */
public class ClosestPair {

    /**
     * Orders points by x-coordinate and breaks ties by y-coordinate.
     *
     * <p>The y tie-break matters for the coincident-point fast path in the
     * constructor: it compares neighbours in sorted order only, so points with
     * the same coordinates must end up next to each other even when other
     * points share their x-coordinate.
     */
    private static final Comparator<Point2D> X_THEN_Y_ORDER = new Comparator<Point2D>() {
        @Override
        public int compare(Point2D o1, Point2D o2) {
            int byX = Double.compare(o1.getX(), o2.getX());
            return (byX != 0) ? byX : Double.compare(o1.getY(), o2.getY());
        }
    };

    // The closest pair found so far and the distance between its two points.
    // best1/best2 stay null and bestDistance stays infinite for fewer than two points.
    private Point2D best1, best2;
    private double bestDistance = Double.POSITIVE_INFINITY;

    /**
     * Computes a closest pair of the given points.
     *
     * @param points the points to examine; the array itself is not modified
     * @throws NullPointerException if {@code points} is {@code null}
     */
    public ClosestPair(Point2D[] points) {
        if (points == null) {
            throw new NullPointerException("points must not be null");
        }
        int n = points.length;
        if (n <= 1) {
            return;
        }

        // Work on a copy sorted by x (ties broken by y).
        Point2D[] pointsByX = Arrays.copyOf(points, n);
        Arrays.sort(pointsByX, X_THEN_Y_ORDER);

        // Fast path: coincident points are neighbours after the sort above.
        for (int i = 0; i < n - 1; i++) {
            if (pointsByX[i].equals(pointsByX[i + 1])) {
                bestDistance = 0.0;
                best1 = pointsByX[i];
                best2 = pointsByX[i + 1];
                return;
            }
        }

        // Will be sorted by y incrementally by the merges inside closest().
        Point2D[] pointsByY = Arrays.copyOf(pointsByX, n);

        // Scratch space shared by merge() and the strip scan.
        Point2D[] aux = new Point2D[n];

        closest(pointsByX, pointsByY, aux, 0, n - 1);
    }

    /**
     * Finds the closest pair within {@code pointsByX[lo..hi]} and records it in
     * the instance fields if it beats the best pair seen so far.
     *
     * <p>On return, {@code pointsByY[lo..hi]} holds the same points ordered by
     * y-coordinate.
     *
     * @param pointsByX points sorted by x-coordinate
     * @param pointsByY the same points; the range is y-sorted as a side effect
     * @param aux       scratch array of at least the same length
     * @param lo        lowest index of the range (inclusive)
     * @param hi        highest index of the range (inclusive)
     * @return the smallest distance between two points of the range, or
     *         {@link Double#POSITIVE_INFINITY} if the range has fewer than two points
     */
    private double closest(Point2D[] pointsByX, Point2D[] pointsByY, Point2D[] aux, int lo, int hi) {
        if (hi <= lo) {
            return Double.POSITIVE_INFINITY;
        }

        // Unsigned shift keeps the midpoint correct even if lo + hi overflows int.
        int mid = (lo + hi) >>> 1;
        Point2D median = pointsByX[mid];

        // Solve both halves; each call leaves its half of pointsByY sorted by y.
        double leftBest = closest(pointsByX, pointsByY, aux, lo, mid);
        double rightBest = closest(pointsByX, pointsByY, aux, mid + 1, hi);
        double d = Math.min(leftBest, rightBest);

        // Merge the two y-sorted halves so pointsByY[lo..hi] is sorted by y.
        merge(pointsByY, aux, lo, mid, hi);

        // Collect, in y-order, the points closer than d (in x) to the dividing line:
        // this is the vertical strip of width 2d around the median.
        int stripSize = 0;
        for (int i = lo; i <= hi; i++) {
            if (Math.abs(pointsByY[i].getX() - median.getX()) < d) {
                aux[stripSize++] = pointsByY[i];
            }
        }

        // Compare each strip point only with the following points whose y-gap is below d.
        // A geometric packing argument bounds the inner loop by a small constant.
        for (int i = 0; i < stripSize; i++) {
            for (int j = i + 1; (j < stripSize) && (aux[j].getY() - aux[i].getY() < d); j++) {
                double distance = aux[i].distance(aux[j]);
                if (distance < d) {
                    d = distance;
                    if (distance < bestDistance) {
                        bestDistance = d;
                        best1 = aux[i];
                        best2 = aux[j];
                    }
                }
            }
        }
        return d;
    }

    /**
     * Merges the y-sorted runs {@code a[lo..mid]} and {@code a[mid+1..hi]} into a
     * single y-sorted run {@code a[lo..hi]}, using {@code aux[lo..hi]} as scratch.
     * The merge is stable: on equal y the element from the left run comes first.
     *
     * @param a   array holding the two sorted runs; receives the merged result
     * @param aux scratch array of at least the same length
     * @param lo  first index of the left run
     * @param mid last index of the left run
     * @param hi  last index of the right run
     */
    private static void merge(Point2D[] a, Point2D[] aux, int lo, int mid, int hi) {
        // Snapshot the range so a[] can be overwritten while merging.
        System.arraycopy(a, lo, aux, lo, hi - lo + 1);

        int i = lo;
        int j = mid + 1;
        for (int k = lo; k <= hi; k++) {
            if (i > mid) {
                a[k] = aux[j++];            // left run exhausted
            } else if (j > hi) {
                a[k] = aux[i++];            // right run exhausted
            } else if (aux[j].getY() < aux[i].getY()) {
                a[k] = aux[j++];
            } else {
                a[k] = aux[i++];            // ties go to the left run (stability)
            }
        }
    }

    /**
     * Returns one point of the closest pair.
     *
     * @return one point of the pair, or {@code null} if fewer than two points were given
     */
    public Point2D either() {
        return best1;
    }

    /**
     * Returns the other point of the closest pair.
     *
     * @return the other point of the pair, or {@code null} if fewer than two points were given
     */
    public Point2D other() {
        return best2;
    }

    /**
     * Returns the distance between the two points of the closest pair.
     *
     * @return the closest-pair distance, or {@link Double#POSITIVE_INFINITY} if
     *         fewer than two points were given
     */
    public double distance() {
        return bestDistance;
    }

    /**
     * Runs the algorithm on a small fixed sample and prints the result.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Point2D[] points = new Point2D[]{
                new Point(0, 2),
                new Point(6, 67),
                new Point(43, 71),
                new Point(39, 107),
                new Point(189, 140)
        };
        ClosestPair closest = new ClosestPair(points);
        System.out.println(closest.distance() + " from " + closest.either() + " to " + closest.other());
    }
}