Skip to content
Snippets Groups Projects
Select Git revision
  • f4e510856180e9b4a9e09fda133b27a92eb64d4e
  • master default
  • zs-dobrichovice
3 results

update-users.py

Blame
  • splay_operation_test.py 4.55 KiB
    #!/usr/bin/env python3
    import itertools
    import math
    import sys
    
    from splay_operation import Tree, Node
    
    def flatten(tree):
        """Return the keys of `tree` as a list in ascending (in-order) order.

        The walk is iterative: it descends through `left`/`right` and climbs
        back through `parent`, keeping one phase marker per level of the
        current root-to-node path. It therefore depends on the nodes' parent
        pointers being consistent. An empty tree (`tree.root is None`) gives [].
        """
        GO_LEFT, EMIT, ASCEND = 0, 1, 2

        keys = []
        phases = [GO_LEFT]
        current = tree.root
        while current is not None:
            phase = phases[-1]
            if phase == GO_LEFT:
                # Left subtree comes first; emit this node once we return.
                phases[-1] = EMIT
                if current.left is not None:
                    current = current.left
                    phases.append(GO_LEFT)
            elif phase == EMIT:
                # Left side finished: record the key, then visit the right side.
                keys.append(current.key)
                phases[-1] = ASCEND
                if current.right is not None:
                    current = current.right
                    phases.append(GO_LEFT)
            else:
                # Both subtrees done: climb back up (the root's parent ends the loop).
                phases.pop()
                current = current.parent
        return keys
    
    def test_splay():
        """Check `Tree.splay` against the reference results in splay_tests.txt.

        Each non-blank line of the file holds three whitespace-separated fields:
        the original tree, the key to splay, and the tree expected afterwards.
        Trees are serialized as ``(key,left,right)``, with ``()`` for an empty
        subtree, e.g. ``(5,(3,(),()),())``. The file is opened relative to the
        current working directory.
        """
        def deserialize_tree(string):
            def deserialize_node(i):
                # Parse the node starting at string[i]; return (index after it, node).
                assert string[i] == "("
                i += 1
                if string[i] == ")":
                    return i + 1, None
                comma = string.find(",", i)
                # Left subtree starts right after the key's comma; the right subtree
                # starts one character (the separator) after the left one ends.
                after_left, left = deserialize_node(comma + 1)
                rparen, right = deserialize_node(after_left + 1)
                assert string[rparen] == ")"
                return rparen + 1, Node(int(string[i:comma]), left=left, right=right)

            index, root = deserialize_node(0)
            assert index == len(string)
            return Tree(root)

        def compare(system, gold):
            """Return a description of the first key/shape mismatch, or None if equal."""
            if system is None and gold is not None:
                return "expected node with key {}, found None".format(gold.key)
            if system is not None and gold is None:
                return "expected None, found node with key {}".format(system.key)
            if system is not None and gold is not None:
                if system.key != gold.key:
                    return "expected node with key {}, found {}".format(gold.key, system.key)
                return compare(system.left, gold.left) or compare(system.right, gold.right)
            return None

        with open("splay_tests.txt", "r") as splay_tests_file:
            for line in splay_tests_file:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing empty line) in the data file.
                    continue
                original_serialized, target_serialized, splayed_serialized = fields
                original = deserialize_tree(original_serialized)
                splayed = deserialize_tree(splayed_serialized)
                target = int(target_serialized)

                # Plain BST descent to find the node to splay.
                node = original.root
                while node is not None and node.key != target:
                    node = node.left if target < node.key else node.right
                assert node is not None, "Key {} not found in {}".format(target, original_serialized)

                original.splay(node)
                error = compare(original.root, splayed.root)
                assert not error, "Error running splay on key {} of {}: {}".format(node.key, original_serialized, error)
    
    def test_lookup():
        """Fill a tree with the even keys below 100000, then look up every odd
        key (which must be absent) and every even key (which must be present),
        ten times each."""
        tree = Tree()
        for key in range(0, 100000, 2):
            tree.insert(key)

        def probe_all(first, expect_present, message):
            # Walk keys first, first+2, ... below 100000, looking each up ten times.
            for key in range(first, 100000, 2):
                for _ in range(10):
                    assert (tree.lookup(key) is not None) == expect_present, message

        probe_all(1, False, "Non-existing element was found")
        probe_all(0, True, "Existing element was not found")
    
    def test_insert():
        """Check that inserts build a correctly ordered tree, then stress-test
        repeated inserts of the same key."""
        # Validity: keys 997**e mod 1999 arrive in a scrambled, non-monotonic order.
        keys = [pow(997, exponent, 1999) for exponent in range(1, 1999)]
        checked = Tree()
        for key in keys:
            checked.insert(key)
        assert flatten(checked) == sorted(keys), "Incorrect tree after a sequence of inserts"

        # Speed: each key in 0..199999 is inserted ten times in a row, so the
        # last nine inserts of every key repeat one already made.
        stressed = Tree()
        for key in range(200000):
            for _ in range(10):
                stressed.insert(key)
    
    def test_remove():
        """Check that removes leave a correctly ordered tree, then stress-test
        removing absent and present keys."""
        # Validity: fill with 2..3997, then remove the odd keys
        # 2 * (997**e mod 1999) + 1 so that only the even keys are expected to remain.
        tree = Tree()
        for key in range(2, 2 * 1999):
            tree.insert(key)

        evens = [2 * pow(997, exponent, 1999) for exponent in range(1, 1999)]
        for even in evens:
            tree.remove(even + 1)
        assert flatten(tree) == sorted(evens), "Incorrect tree after a sequence of removes"

        # Speed: start from a tree of the even keys below 100000.
        limit = 100000
        tree = Tree()
        for key in range(0, limit, 2):
            tree.insert(key)

        # Ten removes of each odd key, none of which is in the tree.
        for key in range(1, limit, 2):
            for _ in range(10):
                tree.remove(key)

        # Remove the present keys. NOTE(review): the range starts at 2, so key 0
        # is never removed and the tree never becomes empty - confirm intent.
        for key in range(2, limit, 2):
            tree.remove(key)
    
    # Registry of runnable tests as (command-line name, function) pairs, in run
    # order; each name is its function's name without the "test_" prefix.
    tests = [
        (function.__name__[len("test_"):], function)
        for function in (test_splay, test_lookup, test_insert, test_remove)
    ]
    
    if __name__ == "__main__":
        # Run the tests named on the command line, or every registered test
        # (in registry order) when no names are given.
        for required_test in sys.argv[1:] or [name for name, _ in tests]:
            for name, test in tests:
                if name == required_test:
                    print("Running test {}".format(name), file=sys.stderr)
                    test()
                    break
            else:
                # No registered test matched: report the name that was requested.
                # (The inner loop variable `name` would only hold the last
                # registered test here, or be unbound if `tests` were empty.)
                raise ValueError("Unknown test {}".format(required_test))