Nearest Neighbour Queries
Given a target T and a set of points S, find the nearest neighbour of T in S.
https://www.youtube.com/watch?v=Glp7THUpGow
https://www.youtube.com/watch?v=XG4zpiJAkD4
https://www.cs.cmu.edu/~ckingsf/bioinfo-lectures/kdtrees.pdf
package algorithmsinanutshell.spatialtree
import algorithmdesignmanualbook.print
import kotlin.math.pow
/**
 * Given a target T and a set of points S, find the nearest neighbour of T in S.
 *
 * https://www.youtube.com/watch?v=Glp7THUpGow
 *
 * https://www.youtube.com/watch?v=XG4zpiJAkD4
 *
 * https://www.cs.cmu.edu/~ckingsf/bioinfo-lectures/kdtrees.pdf
 */
class NearestNeighbourQueries(val array: Array<Array<Int>>) {
    private val tree: MultiDimNode = KDTree(array).tree
    private fun findNearest(subTree: MultiDimNode?, target: Array<Int>, depth: Int): MultiDimNode? {
        if (subTree == null) return null
        val dimensionIndex = depth % target.size
        val nextBranch: MultiDimNode?
        val otherBranch: MultiDimNode?
        if (target[dimensionIndex] < subTree.value[dimensionIndex]) {
            nextBranch = subTree.left
            otherBranch = subTree.right
        } else {
            nextBranch = subTree.right
            otherBranch = subTree.left
        }
        var nodeFromNextBranch = findNearest(nextBranch, target, depth + 1)
        var best = closest(MultiDimNode(target), nodeFromNextBranch, subTree)
        val euclideanDistance = best!!.distanceFrom(target)
        val perpendicularDistance = (target[dimensionIndex] - subTree.value[dimensionIndex]).toDouble().pow(2)
        // Traverse into the unvisited section i.e otherBranch if the best distance found so far is bigger than
        // the perpendicular distance to the unvisited branch
        if (euclideanDistance >= perpendicularDistance) {
            nodeFromNextBranch = findNearest(otherBranch, target, depth + 1)
            best = closest(MultiDimNode(target), nodeFromNextBranch, subTree)
            return best
        }
        return best
    }
    fun execute(target: Array<Int>): MultiDimNode? {
        return findNearest(tree, target, 0)
    }
    private fun closest(target: MultiDimNode, vararg p1: MultiDimNode?): MultiDimNode? {
        return p1.toList()
            .filterNotNull()
            .minByOrNull {
                it.distanceFrom(target.value)
            }
    }
}
fun main() {
    run {
        val array = arrayOf(
            arrayOf(3, 6), arrayOf(17, 15), arrayOf(13, 15),
            arrayOf(6, 12), arrayOf(9, 1), arrayOf(2, 7), arrayOf(10, 19)
        )
        NearestNeighbourQueries(array).execute(arrayOf(1, 2)).print()
    }
    run {
        val array =
            arrayOf(arrayOf(30, 40), arrayOf(5, 25), arrayOf(10, 12), arrayOf(70, 70), arrayOf(50, 30), arrayOf(35, 45))
        NearestNeighbourQueries(array).execute(arrayOf(52, 52)).print()
    }
}
Updated on 2021-08-22