#!/usr/bin/env python3
import itertools
import math
import sys

from splay_operation import Tree, Node

def flatten(tree):
    """Return the keys of ``tree`` listed by an in-order traversal.

    For a valid binary search tree this is ascending key order. The walk is
    iterative and moves upward through ``node.parent``, so it does not recurse
    and works on arbitrarily deep (degenerate) trees. It relies on the parent
    pointers being consistent with the ``left``/``right`` links.

    Returns an empty list when ``tree.root`` is None.
    """
    # Per-level phase: descend left next, emit the key next, or climb to the parent.
    GO_LEFT, EMIT, GO_UP = 0, 1, 2

    keys = []
    current = tree.root
    phases = [GO_LEFT]
    while current is not None:
        phase = phases[-1]
        if phase == GO_LEFT:
            phases[-1] = EMIT
            if current.left is not None:
                phases.append(GO_LEFT)
                current = current.left
        elif phase == EMIT:
            keys.append(current.key)
            phases[-1] = GO_UP
            if current.right is not None:
                phases.append(GO_LEFT)
                current = current.right
        else:
            # Both subtrees are done: drop this level and return to the parent.
            # Once the root's parent (None) is reached, the loop ends.
            phases.pop()
            current = current.parent

    return keys

def test_splay():
    """Check ``Tree.splay`` against the reference results in ``splay_tests.txt``.

    Each line of the file holds three whitespace-separated fields:
    - the serialized starting tree;
    - the key to splay;
    - the serialized tree expected after splaying.

    A serialized tree is ``()`` for an empty subtree, or ``(KEY,LEFT,RIGHT)``
    where LEFT and RIGHT are serialized subtrees. The single character between
    LEFT and RIGHT is skipped without being checked.

    Only keys and shape are compared; parent pointers are not inspected.
    """

    def parse_tree(text):
        """Build a ``Tree`` from ``text``, asserting the whole string is consumed."""

        def parse_subtree(pos):
            # Returns (index just past this subtree, Node or None).
            assert text[pos] == "("
            pos += 1
            if text[pos] == ")":
                return pos + 1, None
            key_end = text.find(",", pos)
            after_left, left_child = parse_subtree(key_end + 1)
            # "+ 1" steps over the separator between the two subtrees.
            after_right, right_child = parse_subtree(after_left + 1)
            assert text[after_right] == ")"
            return after_right + 1, Node(int(text[pos:key_end]), left=left_child, right=right_child)

        end, root = parse_subtree(0)
        assert end == len(text)
        return Tree(root)

    def first_mismatch(actual, expected):
        """Return a message for the first pre-order difference in keys/shape, or None."""
        if actual is None:
            if expected is None:
                return None
            return "expected node with key {}, found None".format(expected.key)
        if expected is None:
            return "expected None, found node with key {}".format(actual.key)
        if actual.key != expected.key:
            return "expected node with key {}, found {}".format(expected.key, actual.key)
        return first_mismatch(actual.left, expected.left) or first_mismatch(actual.right, expected.right)

    with open("splay_tests.txt", "r") as cases:
        for case in cases:
            start_text, key_text, expected_text = case.rstrip("\n").split()
            tree = parse_tree(start_text)
            expected = parse_tree(expected_text)
            key = int(key_text)

            # Locate the node to splay with an ordinary BST descent.
            node = tree.root
            while node is not None and node.key != key:
                node = node.left if key < node.key else node.right
            assert node is not None

            tree.splay(node)
            mismatch = first_mismatch(tree.root, expected.root)
            assert not mismatch, "Error running splay on key {} of {}: {}".format(node.key, start_text, mismatch)

def test_lookup():
    """Check ``Tree.lookup`` on a tree holding the even keys 0..99998.

    Every odd key below 100000 must be reported missing, and every even key
    must be found. Each lookup is repeated ten times in a row.
    """
    evens = range(0, 100000, 2)
    odds = range(1, 100000, 2)

    tree = Tree()
    for key in evens:
        tree.insert(key)

    # Misses: none of the odd keys were inserted.
    for key in odds:
        for _repeat in range(10):
            assert tree.lookup(key) is None, "Non-existing element was found"

    # Hits: every even key was inserted above.
    for key in evens:
        for _repeat in range(10):
            assert tree.lookup(key) is not None, "Existing element was not found"

def test_insert():
    """Check ``Tree.insert`` for correctness first, then for speed.

    Correctness phase: insert 1998 keys generated as successive powers of 997
    modulo 1999, which gives a scrambled insertion order. The in-order
    traversal must then equal the sorted keys.

    Speed phase: insert 0..199999 in increasing order, each key ten times
    back to back. Nine of every ten calls therefore re-insert a key that was
    just inserted.
    """
    keys = [pow(997, exponent, 1999) for exponent in range(1, 1999)]
    checked = Tree()
    for key in keys:
        checked.insert(key)
    assert flatten(checked) == sorted(keys), "Incorrect tree after a sequence of inserts"

    timed = Tree()
    for key in range(200000):
        for _repeat in range(10):
            timed.insert(key)

def test_remove():
    """Check ``Tree.remove`` for correctness first, then under load.

    Correctness phase: start from the keys 2..3997, then remove ``k + 1`` for
    every ``k`` in a scrambled sequence of even keys. The remaining in-order
    traversal must equal the sorted sequence.

    Load phase: on a tree holding the even keys 0..99998:
    - remove every odd key (none of which is present) ten times;
    - then remove every even key except 0.
    The tree contents are verified after each step. A ``remove`` that corrupts
    a large tree, or that alters the tree on a miss, is caught instead of
    merely being timed.
    """
    # Test validity first
    tree = Tree()
    for elem in range(2, 1999 * 2):
        tree.insert(elem)

    sequence = [2 * pow(997, i, 1999) for i in range(1, 1999)]
    for elem in sequence:
        tree.remove(elem + 1)
    assert flatten(tree) == sorted(sequence), "Incorrect tree after a sequence of removes"

    # Test speed (and verify the resulting trees)
    existing = range(0, 100000, 2)
    tree = Tree()
    for elem in existing:
        tree.insert(elem)

    # Non-existing elements: removing them must leave the tree unchanged.
    for elem in range(1, 100000, 2):
        for _ in range(10):
            tree.remove(elem)
    assert flatten(tree) == list(existing), "Incorrect tree after removing non-existing elements"

    # Existing elements: everything except 0, which is never removed.
    for elem in range(2, 100000, 2):
        tree.remove(elem)
    assert flatten(tree) == [0], "Incorrect tree after removing existing elements"

# Registry of (name, test function) pairs. The names are what the command
# line accepts; with no arguments, every test runs in this order.
tests = [
    ("splay", test_splay),
    ("lookup", test_lookup),
    ("insert", test_insert),
    ("remove", test_remove),
]

if __name__ == "__main__":
    # Run the tests named on the command line, in the given order, or all
    # registered tests when no names are given. Unknown names raise ValueError.
    for required_test in sys.argv[1:] or [name for name, _ in tests]:
        for name, test in tests:
            if name == required_test:
                print("Running test {}".format(name), file=sys.stderr)
                test()
                break
        else:
            # Report the requested name. After the inner loop is exhausted,
            # `name` only holds the last registry entry.
            raise ValueError("Unknown test {}".format(required_test))