Closest pair of points algorithm


Fast Algorithm using a KD-Tree
This algorithm builds a kd-tree and then queries it for the nearest neighbour of each point. Building the tree as written is O(n log² n), because every level of the recursion re-sorts its points, and finding the nearest neighbour of one point is O(log n) on average. Credit must go to Wikipedia, whose article on kd-trees explains both how to build them and how to use them to find the nearest neighbour.
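
As a rough sketch of where the build cost comes from, assuming the constructor sorts each sublist from scratch (as the code here does) and splits at the median, the running time T(n) satisfies:

```latex
T(n) = 2\,T\!\left(\tfrac{n}{2}\right) + O(n \log n)
\quad\Longrightarrow\quad
T(n) = O(n \log^2 n)
```

The tree has about log n levels, and the sorts on each level add up to at most O(n log n) work, which gives the extra log factor.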

import java.util.*;

public class Program
{
    public static void main(String[] args)
    {
        List<Point> points = generatePoints();
        Point[] closest = new Point[points.size()];

        KDTree tree = new KDTree(points, 0); // WILL MODIFY 'points'

        for (int i = 0; i < points.size(); i++)
        {
            closest[i] = tree.findClosest(points.get(i));
        }

        for (int i = 0; i < points.size(); i++)
        {
            System.out.println(points.get(i) + " is closest to " + closest[i]);
        }
    }

    private static List<Point> generatePoints()
    {
        ArrayList<Point> points = new ArrayList<Point>();
        Random r = new Random();

        for (int i = 0; i < 1000; i++)
        {
            points.add(new Point(r.nextInt() % 1000, r.nextInt() % 1000));
        }

        return points;
    }
}

class Point
{
    public static final Point INFINITY
        = new Point(Double.POSITIVE_INFINITY,
                    Double.POSITIVE_INFINITY);

    public double[] coord; // coord[0] = x, coord[1] = y

    public Point(double x, double y)
    {
        coord = new double[] { x, y };
    }

    public double getX() { return coord[0]; }
    public double getY() { return coord[1]; }

    public double distance(Point p)
    {
        double dX = getX() - p.getX();
        double dY = getY() - p.getY();
        return Math.sqrt(dX * dX + dY * dY);
    }

    public boolean equals(Point p)
    {
        return (getX() == p.getX()) && (getY() == p.getY());
    }

    public String toString()
    {
        return "(" + getX() + ", " + getY() + ")";
    }

    public static class PointComp implements Comparator<Point>
    {
        int d; // the dimension to compare in (0 => x, 1 => y)

        public PointComp(int dimension)
        {
            d = dimension;
        }

        public int compare(Point a, Point b)
        {
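            // Double.compare avoids the (int) cast, which would treat
            // any difference smaller than 1 as "equal".
            return Double.compare(a.coord[d], b.coord[d]);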
        }
    }
}

class KDTree
{
    // 2D k-d tree
    private KDTree childA, childB;
    private Point point; // defines the boundary
    private int d; // dimension: 0 => left/right split, 1 => up/down split

    public KDTree(List<Point> points, int depth)
    {
        childA = null;
        childB = null;
        d = depth % 2;

        // find median by sorting in dimension 'd' (either x or y)
        Comparator<Point> comp = new Point.PointComp(d);
        Collections.sort(points, comp);

        int median = (points.size() - 1) / 2;
        point = points.get(median);

        // Create childA and childB recursively.
        // WARNING: subList() does not create a true copy,
        // so the original will get modified.
        if (median > 0)
        {
            childA = new KDTree(
                points.subList(0, median),
                depth + 1);
        }
        if (median + 1 < points.size())
        {
            childB = new KDTree(
                points.subList(median + 1, points.size()),
                depth + 1);
        }
    }

    public Point findClosest(Point target)
    {
        Point closest = point.equals(target) ? Point.INFINITY : point;
        double bestDist = closest.distance(target);
        double spacing = target.coord[d] - point.coord[d];
        KDTree rightSide = (spacing < 0) ? childA : childB;
        KDTree otherSide = (spacing < 0) ? childB : childA;

        /*
         * The 'rightSide' is the side on which 'target' lies
         * and the 'otherSide' is the other one. It is possible
         * that 'otherSide' will not have to be searched.
         */

        if (rightSide != null)
        {
            Point candidate = rightSide.findClosest(target);
            if (candidate.distance(target) < bestDist)
            {
                closest = candidate;
                bestDist = closest.distance(target);
            }
        }

        if (otherSide != null && (Math.abs(spacing) < bestDist))
        {
            Point candidate = otherSide.findClosest(target);
            if (candidate.distance(target) < bestDist)
            {
                closest = candidate;
                bestDist = closest.distance(target);
            }
        }

        return closest;
    }
}
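
One way to gain confidence in the tree search is to compare its answers against the quadratic algorithm on a small input. The following is only a sketch of such a reference: the class name `BruteForceCheck`, the method `nearestNeighbours` and the use of `double[][]` rows as points are all assumptions made for illustration, not part of the code above.

```java
class BruteForceCheck
{
    // For each point i, returns the index of the nearest other point,
    // or -1 if there is no other point. Each row is { x, y }.
    public static int[] nearestNeighbours(double[][] pts)
    {
        int[] nearest = new int[pts.length];

        for (int i = 0; i < pts.length; i++)
        {
            nearest[i] = -1;
            double best = Double.POSITIVE_INFINITY;

            for (int j = 0; j < pts.length; j++)
            {
                if (j == i) continue; // a point is not its own neighbour

                double d = Math.hypot(pts[i][0] - pts[j][0],
                                      pts[i][1] - pts[j][1]);
                if (d < best)
                {
                    best = d;
                    nearest[i] = j;
                }
            }
        }

        return nearest;
    }

    public static void main(String[] args)
    {
        double[][] pts = { {0, 0}, {1, 0}, {5, 5}, {5, 7} };
        int[] nearest = nearestNeighbours(pts);

        for (int i = 0; i < pts.length; i++)
        {
            System.out.println(i + " -> " + nearest[i]);
        }
    }
}
```

Note that this sketch skips a point by index, whereas the kd-tree code skips any point with the same coordinates as the target, so the two can disagree when the input contains duplicate points.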


Fix to the code in the question
If you are not worried about the worst-case complexity, the only problem with your code is that it looks forward but never backward. Just duplicate the inner loop and make j run from (i - 1) down to 0:

Point[] points = sort(input()); // sorted by x coordinate
int[] closest = new int[points.length];

for (int i = 0; i < points.length; i++)
{
    double bestdist = Double.POSITIVE_INFINITY;

    for (int j = i + 1; (j < points.length) && ((points[j].x - points[i].x) < bestdist); j++ )
    {
        double currdist = dist(points[i], points[j]);

        if (currdist < bestdist)
        {
            closest[i] = j;
            bestdist = currdist;
        }
    }
    for (int j = i - 1; (j >= 0) && ((points[i].x - points[j].x) < bestdist); j-- )
    {
        double currdist = dist(points[i], points[j]);

        if (currdist < bestdist)
        {
            closest[i] = j;
            bestdist = currdist;
        }
    }
}
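
The snippet above leaves `sort`, `input` and `dist` undefined. A self-contained version might look like the sketch below. It assumes points are stored as `double[]` rows of `{ x, y }` and that `dist` is the Euclidean distance; the names `SweepNearest` and `closestIndices` are made up for this example.

```java
import java.util.Arrays;
import java.util.Comparator;

class SweepNearest
{
    // Sorts 'pts' by x in place, then returns for each sorted index i
    // the sorted index of its nearest neighbour (-1 if there is none).
    public static int[] closestIndices(double[][] pts)
    {
        Arrays.sort(pts, Comparator.comparingDouble((double[] p) -> p[0]));

        int[] closest = new int[pts.length];
        Arrays.fill(closest, -1);

        for (int i = 0; i < pts.length; i++)
        {
            double best = Double.POSITIVE_INFINITY; // reset for every point

            // Look forward until the gap in x alone is already too large.
            for (int j = i + 1; j < pts.length && pts[j][0] - pts[i][0] < best; j++)
            {
                double d = dist(pts[i], pts[j]);
                if (d < best) { best = d; closest[i] = j; }
            }

            // Look backward with the same stopping rule.
            for (int j = i - 1; j >= 0 && pts[i][0] - pts[j][0] < best; j--)
            {
                double d = dist(pts[i], pts[j]);
                if (d < best) { best = d; closest[i] = j; }
            }
        }

        return closest;
    }

    static double dist(double[] a, double[] b)
    {
        return Math.hypot(a[0] - b[0], a[1] - b[1]);
    }

    public static void main(String[] args)
    {
        double[][] pts = { {3, 0}, {0, 0}, {10, 1}, {1, 0.5} };
        int[] closest = closestIndices(pts);

        for (int i = 0; i < pts.length; i++)
        {
            System.out.println(Arrays.toString(pts[i]) + " is closest to "
                + Arrays.toString(pts[closest[i]]));
        }
    }
}
```

The indices in the result refer to the array after sorting, so the caller should read the points back from the same (now sorted) array. The worst case is still quadratic, for example when many points share nearly the same x coordinate.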
Author: Paul

Updated on July 08, 2022

Comments

  • Paul
    Paul almost 2 years

    I am trying to implement a simpler version of this algorithm that still performs better than the quadratic algorithm. My idea is to sort the points by x coordinate only and try to solve it from there. Once the array of points is sorted by x coordinate, I want to iterate over it and skip any points whose distance is greater than that of the first two points I looked at.

    For example, my currentminDist = x;

    If the pair of points I am looking at is more than x apart (measured by x coordinate alone), I ignore the point and move past it in the array.

    I have the idea down, but I am stuck on how to actually implement it (especially the condition part). I have a function that returns the distance between two points based on their x coordinate.

    I am confused about how to write the conditions for my loop: I want to ignore a point if it is too far away, yet still fill in my array, which will hold the closest point for each i (i being the point I am currently looking at).

    Any tips or directions would be greatly appreciated. I am not very knowledgeable in coding algorithms, so it's quite frustrating.

    Here is part of my code:

    for (i = 0; i < numofmypoints; i++)
            {
                for (int j = i + 1; (j < numpofmypoints) && ((inputpoints[j].x - inputpoints[i].x) < currbest); j++ )
                {
                    currdist = Auxilary.distbyX(inputpoints[i],inputpoints[j]);
    
                    if (currdist < bestdist) 
                    {
                     closest[i] = j;
                     bestdist = currdist;
    
                    }
                }
            }
    

    distbyX is my function that just returns the distance between two points.

    Thanks!

  • Paul
    Paul about 12 years
    I am not worried about the worst case. I am assuming all the x values are distinct. That's why I want to try to solve it the way I laid it out. Your way, using a data structure, makes sense, but I was wondering whether it could be solved the way I described. The problem I ran into is that it does not calculate the nearest point for all the points: it only calculates it for a few of them, and the rest are all just the same point repeated over and over. That's why I was trying to see whether I was going wrong somewhere.
  • tom
    tom about 12 years
    The classic 'Closest pair of points' problem is to find the pair of points which are closest to each other. Only now I realise that your problem is a different one - find the closest neighbour for each point. I will update my answer as soon as I can think of an algorithm.
  • tom
    tom about 12 years
    @Paul: I couldn't figure out a way to improve your sweepline to O(good), so I did it using a kd-tree.
  • Paul
    Paul about 12 years
    Ah yes, I looked at the kd-tree implementation. I used a quad tree when I solved it with a data structure. But I wanted to implement a simpler sweep line algorithm, and I could not figure out why my array was not being filled out. Looking at it now, that makes a lot of sense.