# Project: kdtree
# Filename: kdtree.py
# Created on 10.11.17 by francesco
# Last modified on 10.11.17 by francesco


def kdtree(points, depth=0):
    """
    Recursively build a kd-tree from a sequence of points.

    Each node is a dict with keys 'location' (the median point along the
    split axis at this depth), 'leftChild' and 'rightChild' (sub-trees built
    the same way, or None when empty).

    :param points: sequence of points (tuples/lists of coordinates); the
        dimension k is taken from points[0]. The sequence is not modified.
    :param depth: current depth in the tree; the split axis is depth % k
    :return: root node dict, or None if points is empty
    """
    # Termination condition (any empty sequence, not only the literal [])
    if not points:
        return None

    # Select axis by cycling through the k coordinates
    axis = depth % len(points[0])

    # Sort a copy by the axis coordinate. sorted() is stable, so the order is
    # the same as the insertion sort in sort(), but it runs in O(n log n) and
    # does not reorder the caller's sequence.
    ordered = sorted(points, key=lambda p: p[axis])
    # Select median (middle element of the sorted list); // keeps an int index
    median = len(ordered) // 2

    return {
        'location': ordered[median],
        'leftChild': kdtree(ordered[:median], depth + 1),
        'rightChild': kdtree(ordered[median + 1:], depth + 1),
    }


def sort(points, axis):
    """
    Stable, in-place insertion sort of points by their axis-th coordinate.

    Hand-written (instead of list.sort() or sorted()) because those need a
    key function such as itemgetter or a lambda, which the exercises had not
    introduced yet. Returns None; points is reordered in place.
    """
    for idx in range(1, len(points)):
        key_point = points[idx]
        key_value = key_point[axis]
        pos = idx
        # Shift strictly larger elements one slot right; equal ones stay put,
        # which keeps the sort stable
        while pos > 0 and points[pos - 1][axis] > key_value:
            points[pos] = points[pos - 1]
            pos -= 1
        points[pos] = key_point


def search(tree, point):
    """
    Find the nearest neighbour of point in a kd-tree built by kdtree().

    :param tree: root node dict, or None for an empty tree
    :param point: query point (sequence of coordinates)
    :return: dict with 'node' (location of the nearest point, or None when
        the tree is empty) and 'r' (its distance to point, inf when empty)
    """
    # Initializes current best to unknown and r to infinite
    best = {'node': None, 'r': float('inf')}

    # kdtree([]) returns None: there is nothing to search, so report "unknown"
    if tree is None:
        return best

    return nn(tree, point, best)


def nn(tree, point, current_best, depth=0):
    """
    Recursive nearest neighbour search.

    First descends into the child on the query point's side of the splitting
    hyperplane (or into the other child if that side is empty). It then visits
    the remaining child only if it exists and ball_radius_test() reports that
    the current best sphere may reach across the plane.

    :param tree: node dict as built by kdtree(), or None for an empty subtree
    :param point: query point (sequence of coordinates)
    :param current_best: dict with 'node' (best location so far) and 'r'
        (its distance to point); it is updated in place
    :param depth: depth of this node; selects the split axis as depth % k
    :return: current_best, updated with any closer location found in tree
    """
    # Termination condition: an empty subtree cannot improve the result
    if tree is None:
        return current_best

    # Select axis (mirrors kdtree(): depth % number of coordinates)
    axis = depth % (len(point))
    location = tree['location']

    # Check distance for current node and keep it if strictly closer
    tmp_r = distance(location, point)
    if tmp_r < current_best['r']:
        current_best['node'] = location
        current_best['r'] = tmp_r

    # Choose the branch on the query point's side of the split...
    if point[axis] > location[axis]:
        near, far = 'rightChild', 'leftChild'
    else:
        near, far = 'leftChild', 'rightChild'
    # ...but if that side is empty, descend into the other one first, so a
    # lone right (or left) child is never skipped
    if tree[near] is None:
        near, far = far, near

    current_best = nn(tree[near], point, current_best, depth + 1)

    # Visit the opposite branch only if it exists and could hold a closer point
    if tree[far] is not None and not ball_radius_test(
            point, location, current_best['r'], axis):
        current_best = nn(tree[far], point, current_best, depth + 1)

    return current_best


def ball_radius_test(point, pivot, r, axis):
    """
    Tests whether there could be a nearest neighbor in the other side of the tree.

    The splitting hyperplane passes through pivot and is orthogonal to axis.
    The sphere of radius r around point (r = distance to the current best)
    crosses that plane exactly when the distance from point to the plane,
    abs(point[axis] - pivot[axis]), is less than r.

    :param point: query point
    :param pivot: location of the node whose split is being tested
    :param r: current best distance
    :param axis: split axis index at the pivot's depth
    :return: True if there is no overlap (the other branch cannot contain a
        strictly closer point and may be skipped), False if it must be searched
    """
    # Distance from the query point to the splitting hyperplane
    plane_distance = abs(point[axis] - pivot[axis])

    if plane_distance < r:
        # Sphere of radius r overlaps the other side of the space
        return False
    # No overlap: every point across the plane is at least r away
    return True


def distance(point1, point2):
    """
    Euclidean distance between two points with the same number of coordinates.

    The coordinates of point1 are iterated, and each one is paired with the
    same index of point2.
    """
    squared = sum((point2[k] - point1[k]) ** 2 for k in range(len(point1)))
    return squared ** 0.5


# Demo: build a kd-tree over a small fixed set of 2-d points and query it
points = [
    (2, 3),
    (5, 4),
    (9, 6),
    (4, 7),
    (8, 1),
    (7, 2),
]

tree = kdtree(points)

nearest = search(tree, (2, 6))

print(
    'Nearest point:\n\t{}\nDistance:\n\t{}'.format(
        nearest['node'],
        nearest['r'],
    )
)

# Optional: dump the tree to a JSON file for inspection
# import json
# with open("tree.json", "w") as file:
#     json.dump(tree, file, indent=4)
