from collections import namedtuple
from operator import itemgetter
from pprint import pformat


class Node(namedtuple('Node', 'location left_child right_child')):
    """Immutable kd-tree node: a split point plus its two subtrees.

    ``left_child`` and ``right_child`` are either ``Node`` instances or
    ``None`` (see ``kdtree``).
    """

    def __repr__(self):
        # Render the node (and, recursively, its subtrees) as plain nested
        # tuples so the whole tree pretty-prints readably.
        as_plain_tuple = tuple(self)
        return pformat(as_plain_tuple)

def distance(a, b):
    """Return the squared Euclidean distance between points ``a`` and ``b``.

    The square root is deliberately omitted: squared distances order points
    the same way, which is all a nearest-neighbour comparison needs. Works
    for points of any dimension; both points should have the same length.
    """
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_neighbour(p, root=None):
    """Return the stored point closest to ``p`` in a kd-tree.

    Performs a standard kd-tree nearest-neighbour search. It descends toward
    the region containing ``p`` first, then on the way back visits the
    opposite subtree only when the splitting plane is closer than the best
    match found so far.

    Args:
        p: query point; assumed to have the same dimension as the points the
            tree was built from.
        root: root ``Node`` of a tree built by ``kdtree``. When omitted (or
            ``None``) the module-level ``tree`` is searched, so the original
            one-argument call keeps working.

    Returns:
        The ``location`` of the nearest stored point (by ``distance``), or
        ``None`` if the tree is empty. On ties, the first point encountered
        during the search wins.
    """
    if root is None:
        root = tree

    # [best location, best squared distance]; a list so the nested helper can
    # update it without ``nonlocal`` (keeps the code Python 2 compatible).
    best = [None, float('inf')]
    dims = len(p)

    def _visit(node, depth):
        # Recursively search the subtree rooted at ``node``.
        if node is None:
            return
        location, left, right = node
        d = distance(p, location)
        if d < best[1]:
            best[0], best[1] = location, d

        # Same axis rule kdtree used when it split this node.
        axis = depth % dims
        gap = p[axis] - location[axis]
        near, far = (left, right) if gap < 0 else (right, left)
        _visit(near, depth + 1)
        # Every point across the splitting plane is at least gap**2 away, so
        # that side can only improve on ``best`` if the plane is closer.
        if gap * gap < best[1]:
            _visit(far, depth + 1)

    _visit(root, 0)
    return best[0]



def kdtree(point_list, depth=0):
    """Build a kd-tree from ``point_list`` and return its root ``Node``.

    The splitting axis cycles with depth (``depth % k``, where ``k`` is the
    dimension of the first point). At each level the points are sorted along
    that axis and the median becomes the node. Points before the median form
    the left subtree and points after it form the right subtree.

    Args:
        point_list: sequence (or any iterable) of equal-length points. It is
            not modified.
        depth: depth of this node in the tree; callers normally leave it at 0.

    Returns:
        The root ``Node``, or ``None`` when ``point_list`` is empty.
    """
    # Work on a copy so building the tree never reorders the caller's list.
    points = list(point_list)
    if not points:
        return None

    # assumes all points have the same dimension
    k = len(points[0])

    # Select axis based on depth so that axis cycles through
    # all valid values
    axis = depth % k

    # Sort along the axis and choose the median as the pivot element
    points.sort(key=itemgetter(axis))
    median = len(points) // 2

    # Create node and construct subtrees
    return Node(
        location=points[median],
        left_child=kdtree(points[:median], depth + 1),
        right_child=kdtree(points[median + 1:], depth + 1),
    )


# Demo: build a tree from a small 2-D sample and query it.
point_list = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = kdtree(point_list)

# Parenthesised print is valid (and prints the same) on Python 2 and 3;
# the bare ``print x`` statement is a SyntaxError on Python 3.
print(nearest_neighbour([5, 2]))