//
//  main.swift
//  kdtree
//
//  Created by Francesco Servida on 12.11.17.
//  Copyright © 2017 Francesco Servida. All rights reserved.
//

import Foundation

/// Builds a k-d tree from `points` by recursive median splitting.
///
/// At each level the split axis cycles through the coordinates (`depth % k`, where `k`
/// is the number of coordinates of the first point). The points are sorted on that
/// axis, the element at index `count / 2` becomes the node, and the points before /
/// after it become the left / right subtrees.
///
/// - Parameters:
///   - points: The points to index. Only the first point's dimension is read, so all
///     points are presumably expected to have the same dimension.
///   - depth: Recursion depth, used to select the split axis. Callers normally omit it.
/// - Returns: `[:]` when `points` is empty. Otherwise a node dictionary with keys
///   `"location"` (`[Int]`), `"leftChild"` and `"rightChild"` (nested nodes, `[:]` when
///   that side is empty).
/// - Note: Traps if the first point has no coordinates (modulo by zero).
func kdtree(points: [[Int]], depth: Int = 0) -> [String:Any]  {
    // Termination condition: an empty subtree is represented by an empty dictionary.
    guard let first = points.first else {
        return [:]
    }

    // Cycle through the coordinate axes as the tree gets deeper.
    let axis = depth % first.count

    let sorted_points = sort(points: points, axis: axis)
    let median = sorted_points.count / 2

    // Half-open slices are valid for every count, so counts 1 and 2 need no special
    // handling. One point gives two empty sides; two points give one point on the left.
    let left_points = Array(sorted_points[..<median])
    let right_points = Array(sorted_points[(median + 1)...])

    let node: [String:Any] = [
        "location": sorted_points[median],
        "leftChild": kdtree(points: left_points, depth: depth + 1),
        "rightChild": kdtree(points: right_points, depth: depth + 1)
    ]

    return node
}

/// Returns `points` ordered by ascending value of coordinate `axis`.
///
/// The sort is stable: points with equal coordinates on `axis` keep their relative
/// input order, just like the insertion sort this replaces.
///
/// - Parameters:
///   - points: The points to order; not modified.
///   - axis: Index of the coordinate to sort by.
/// - Returns: A new, sorted array. Empty and single-element input is returned unchanged.
func sort(points: [[Int]], axis: Int) -> [[Int]] {
    // `sorted(by:)` is stable and handles empty input. The previous `1...(count-1)`
    // range trapped when `points` was empty.
    return points.sorted { lhs, rhs in lhs[axis] < rhs[axis] }
}

/// Finds the point stored in `tree` that is closest (Euclidean distance) to `point`.
///
/// - Parameters:
///   - tree: A tree as returned by `kdtree`; an empty dictionary means "no points".
///   - point: The query point.
/// - Returns: A dictionary with `"node"` (the nearest location, or an empty `[Int]`
///   when `tree` is empty) and `"r"` (its distance as `Float`, `Float.infinity` when
///   `tree` is empty).
func search(tree: [String:Any], point: [Int]) -> [String:Any]{
	// Start with no candidate and an infinite search radius.
	let best: [String: Any] = [
		"node": [Int](),
		"r": Float.infinity
	]

	// An empty tree (kdtree of no points) has nothing to search.
	guard !tree.isEmpty else {
		return best
	}

	return nn(tree: tree, point: point, current_best_let: best)
}

/// Recursive nearest-neighbour search over a tree built by `kdtree`.
///
/// - Parameters:
///   - tree: A node dictionary with keys `"location"` (`[Int]`), `"leftChild"` and
///     `"rightChild"` (nested nodes, `[:]` when absent). A dictionary without a
///     `"location"` is treated as an empty subtree.
///   - point: The query point.
///   - current_best_let: The best candidate so far. `"node"` holds the closest location
///     and `"r"` holds its distance as `Float` (`Float.infinity` before any point is seen).
///   - depth: Depth of `tree`. It should equal the depth used when the tree was built,
///     so that `depth % point.count` selects the same split axis. This assumes the query
///     has the same dimension as the indexed points.
/// - Returns: `current_best_let`, updated if a strictly closer point exists in `tree`.
func nn<T>(tree:[String:T], point:[Int], current_best_let: [String:Any], depth: Int=0) -> [String:Any]{
	var current_best = current_best_let

	// An empty subtree contributes nothing (kdtree returns [:] for an empty point set).
	guard let current_node = tree["location"] as? [Int] else {
		return current_best
	}
	let left_child: [String: Any] = tree["leftChild"] as? [String: Any] ?? [:]
	let right_child: [String: Any] = tree["rightChild"] as? [String: Any] ?? [:]

	let axis = depth % point.count

	// Check the node itself. Only a strictly smaller distance replaces the current best.
	let best_r = current_best["r"] as? Float ?? Float.infinity
	let node_r = distance(point1: current_node, point2: point)
	if node_r < best_r {
		current_best["node"] = current_node
		current_best["r"] = node_r
	}

	// First descend into the side of the splitting plane that contains the query.
	let query_on_right = point[axis] > current_node[axis]
	let near_child = query_on_right ? right_child : left_child
	let far_child = query_on_right ? left_child : right_child

	if !near_child.isEmpty {
		current_best = nn(tree: near_child, point: point, current_best_let: current_best, depth: depth + 1)
	}

	// Then visit the other side only if the sphere around the query, using the radius
	// *after* the updates above, crosses the splitting plane. In a tree built by kdtree
	// (sorted, then split at the median), every point on the far side is at least
	// `plane_dist` away, so it cannot be strictly closer otherwise.
	let updated_r = current_best["r"] as? Float ?? Float.infinity
	let plane_dist = Float(abs(point[axis] - current_node[axis]))
	if !far_child.isEmpty && plane_dist < updated_r {
		current_best = nn(tree: far_child, point: point, current_best_let: current_best, depth: depth + 1)
	}

	return current_best
}

/// Decides whether the branch on the far side of a splitting plane can be skipped.
///
/// The plane passes through `pivot` and is perpendicular to coordinate `axis`. Every
/// point beyond it is at least `|point[axis] - pivot[axis]|` away from `point`.
///
/// - Parameters:
///   - point: The query point.
///   - pivot: The node whose coordinate on `axis` defines the splitting plane.
///   - r: The current best distance (radius of the search sphere around `point`).
///   - axis: Index of the split coordinate.
/// - Returns: `true` when the sphere of radius `r` does not reach past the plane, so no
///   strictly closer neighbour can lie there. `false` when the sphere overlaps the other
///   side and that branch must be searched.
func ball_radius_test(point: [Int], pivot: [Int], r: Float, axis: Int) -> Bool {
	// Distance from the query to the splitting plane, measured along the split axis only.
	let dist_point_plane = Float(abs(point[axis] - pivot[axis]))

	// At exactly `r` a point could only tie the current best, and a tie never replaces
	// it, so `>=` (not `>`) is enough to skip.
	return dist_point_plane >= r
}

/// Euclidean distance between two points.
///
/// - Parameters:
///   - point1: First point.
///   - point2: Second point. Presumably the same length as `point1`; coordinates beyond
///     the shorter of the two are ignored.
/// - Returns: `sqrt(sum((point2[i] - point1[i])^2))`, or `0` for two zero-dimensional points.
func distance(point1: [Int], point2: [Int]) -> Float {
	// zip replaces the old `0...count-1` range, which trapped when the points were empty.
	let d_square = zip(point1, point2).reduce(Float(0)) { sum, pair in
		let diff = Float(pair.1 - pair.0)
		return sum + diff * diff
	}
	return d_square.squareRoot()
}

// Example: build a 2-D tree over six points and look up the point nearest to (2, 6).
// None of these bindings are reassigned, so they are constants.
let points = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]

let tree = kdtree(points: points)
let nearest = search(tree: tree, point: [2,6])

print("Nearest point:\n\t\(nearest["node"]!)\nDistance:\n\t\(nearest["r"]!)")

