Select Git revision
update-users.py
-
Martin Mareš authoredMartin Mareš authored
splay_operation_test.py 4.55 KiB
#!/usr/bin/env python3
import itertools
import math
import sys
from splay_operation import Tree, Node
def flatten(tree):
    """Return the keys stored in ``tree`` as a list, in in-order sequence.

    The traversal is iterative.  It keeps one phase marker per node on the
    current root-to-node path and uses each node's ``parent`` link to climb
    back up once both subtrees have been handled.

    ``tree`` must expose ``root``; nodes must expose ``key``, ``left``,
    ``right`` and ``parent``.  An empty tree (``root is None``) yields ``[]``.
    """
    # Phase of the node on top of the stack:
    #   GO_LEFT - left subtree not entered yet
    #   VISIT   - left subtree done, key not emitted yet
    #   GO_UP   - key emitted and right subtree handled, return to the parent
    GO_LEFT, VISIT, GO_UP = range(3)

    current = tree.root
    phases = [GO_LEFT]
    keys = []
    while current is not None:
        phase = phases[-1]
        if phase == GO_LEFT:
            phases[-1] = VISIT
            child = current.left
            if child is not None:
                current = child
                phases.append(GO_LEFT)
        elif phase == VISIT:
            keys.append(current.key)
            phases[-1] = GO_UP
            child = current.right
            if child is not None:
                current = child
                phases.append(GO_LEFT)
        else:
            # Both subtrees are finished: drop this node's marker and climb.
            current = current.parent
            phases.pop()
    return keys
def test_splay():
    """Check ``Tree.splay`` against the cases listed in ``splay_tests.txt``.

    Every line of the file carries three whitespace-separated fields: the
    serialized input tree, the key of the node to splay, and the serialized
    tree expected afterwards.  A subtree is serialized as ``()`` when empty
    and as ``(key,<left>,<right>)`` otherwise.
    """

    def deserialize_tree(text):
        """Build a ``Tree`` from its parenthesized serialization ``text``."""

        def parse_subtree(pos):
            """Parse the subtree whose opening parenthesis is at ``text[pos]``.

            Returns ``(next_pos, node)``, where ``next_pos`` is the index just
            past the subtree's closing parenthesis and ``node`` is ``None``
            for an empty subtree.
            """
            assert text[pos] == "("
            pos += 1
            if text[pos] == ")":
                return pos + 1, None

            key_end = text.find(",", pos)
            after_left, left_child = parse_subtree(key_end + 1)
            # One character (the separator between the two children) is
            # skipped here without being validated.
            closing, right_child = parse_subtree(after_left + 1)
            assert text[closing] == ")"
            key = int(text[pos:key_end])
            return closing + 1, Node(key, left=left_child, right=right_child)

        end, root = parse_subtree(0)
        assert end == len(text)
        return Tree(root)

    def compare(system, gold):
        """Return a description of the first key/shape mismatch, else a falsy value.

        ``system`` is the subtree produced by the code under test, ``gold``
        the expected one.  Only keys and left/right structure are compared.
        """
        if system is None or gold is None:
            if gold is not None:
                return "expected node with key {}, found None".format(gold.key)
            if system is not None:
                return "expected None, found node with key {}".format(system.key)
            return None
        if system.key != gold.key:
            return "expected node with key {}, found {}".format(gold.key, system.key)
        return compare(system.left, gold.left) or compare(system.right, gold.right)

    with open("splay_tests.txt", "r") as cases:
        for case in cases:
            source_repr, key_repr, expected_repr = case.rstrip("\n").split()
            actual_tree = deserialize_tree(source_repr)
            expected_tree = deserialize_tree(expected_repr)
            wanted = int(key_repr)

            # Locate the node to splay with a plain binary-search descent.
            found = actual_tree.root
            while found is not None and found.key != wanted:
                found = found.left if wanted < found.key else found.right
            assert found is not None

            actual_tree.splay(found)
            mismatch = compare(actual_tree.root, expected_tree.root)
            assert not mismatch, "Error running splay on key {} of {}: {}".format(found.key, source_repr, mismatch)
def test_lookup():
    """Exercise ``Tree.lookup`` on absent and present keys.

    The tree is filled with the even numbers below 100000.  Every odd number
    in that range must then look up as ``None`` and every even number as
    something other than ``None``; each query is issued ten times in a row.
    """
    limit, repeats = 100000, 10

    splay_tree = Tree()
    for key in range(0, limit, 2):
        splay_tree.insert(key)

    # Odd keys were never inserted.
    for missing in range(1, limit, 2):
        for _attempt in range(repeats):
            assert splay_tree.lookup(missing) is None, "Non-existing element was found"

    # Even keys must all be found.
    for present in range(0, limit, 2):
        for _attempt in range(repeats):
            assert splay_tree.lookup(present) is not None, "Existing element was not found"
def test_insert():
    """Exercise ``Tree.insert``: first for correctness, then for speed.

    Correctness: insert the powers ``997**i mod 1999`` for ``i`` in
    ``1..1998`` and require the flattened tree to equal the sorted inputs.
    Speed: insert each of the keys ``0..199999`` ten times in a row into a
    fresh tree; nothing is asserted in this phase.
    """
    # --- correctness ---
    modulus = 1999
    tree = Tree()
    keys = [pow(997, exponent, modulus) for exponent in range(1, modulus)]
    for key in keys:
        tree.insert(key)
    assert flatten(tree) == sorted(keys), "Incorrect tree after a sequence of inserts"

    # --- speed ---
    repeats = 10
    tree = Tree()
    for key in range(200000):
        for _attempt in range(repeats):
            tree.insert(key)
def test_remove():
    """Exercise ``Tree.remove``: first for correctness, then for speed.

    Correctness: fill the tree with ``2..3997``, remove ``k + 1`` for every
    ``k = 2 * (997**i mod 1999)`` with ``i`` in ``1..1998``, and require the
    flattened tree to equal the sorted list of those ``k`` values.
    Speed: on a tree holding the even numbers below 100000, remove every odd
    number ten times (these keys are absent), then remove every even number
    from 2 upwards once; nothing is asserted in this phase.
    """
    # --- correctness ---
    modulus = 1999
    tree = Tree()
    for key in range(2, modulus * 2):
        tree.insert(key)

    survivors = [2 * pow(997, exponent, modulus) for exponent in range(1, modulus)]
    for kept in survivors:
        # Remove the key right after each value that should stay in the tree.
        tree.remove(kept + 1)
    assert flatten(tree) == sorted(survivors), "Incorrect tree after a sequence of removes"

    # --- speed ---
    limit, repeats = 100000, 10
    tree = Tree()
    for key in range(0, limit, 2):
        tree.insert(key)

    # Keys that are not in the tree.
    for absent in range(1, limit, 2):
        for _attempt in range(repeats):
            tree.remove(absent)

    # Keys that are in the tree.
    for present in range(2, limit, 2):
        tree.remove(present)
# Registry of the available tests as (name, function) pairs.  The order here
# is the order in which they run when no test names are given on the command
# line.
tests = [
    ("splay", test_splay),
    ("lookup", test_lookup),
    ("insert", test_insert),
    ("remove", test_remove),
]

if __name__ == "__main__":
    # Run the tests named in the command-line arguments, or every registered
    # test when there are no arguments.
    for required_test in sys.argv[1:] or [name for name, _ in tests]:
        for name, test in tests:
            if name == required_test:
                print("Running test {}".format(name), file=sys.stderr)
                test()
                break
        else:
            # The inner loop finished without a match.  Report the name that
            # was requested: the loop variable `name` would only hold the
            # last registered test at this point.
            raise ValueError("Unknown test {}".format(required_test))