From 3f24bc38dd16d2d8a00270e3ad2297f353e11b7b Mon Sep 17 00:00:00 2001
From: Jirka Fink <fink@ktiml.mff.cuni.cz>
Date: Thu, 19 Oct 2023 09:38:17 +0200
Subject: [PATCH] Splay operation: One test is improved

---
 .../cpp/splay_operation_more_tests.cpp         | 18 ++++++++++--------
 .../python/splay_operation_more_tests.py       | 15 ++++++++-------
 2 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/02-splay_operation/cpp/splay_operation_more_tests.cpp b/02-splay_operation/cpp/splay_operation_more_tests.cpp
index fe0c46c..aeaa6f2 100644
--- a/02-splay_operation/cpp/splay_operation_more_tests.cpp
+++ b/02-splay_operation/cpp/splay_operation_more_tests.cpp
@@ -107,7 +107,7 @@ class TestSplay {
     }
 };
 
-const int elements = 5000000;
+const int elements = 5000000; // Must be even!
 
 void test_lookup() {
     // Insert even numbers
@@ -200,14 +200,16 @@ void test_remove() {
             tree.remove(0);
     }
     {
-        Node *node = nullptr;
-        for (int i = 1; i < elements; i++)
-            node = new Node(i, nullptr, node, nullptr);
-        node = new Node(0, nullptr, nullptr, node);
-        Tree tree(node);
+        Node *left_subtree = nullptr, *right_subtree = nullptr;
+        for (int i = elements/2-1; i >= 0; i--) {
+            left_subtree = new Node(i, nullptr, nullptr, left_subtree);
+            right_subtree = new Node(elements-i, nullptr, right_subtree, nullptr);
+        }
+        Node *root = new Node(elements/2, nullptr, left_subtree, right_subtree);
+        Tree tree(root);
 
-        for (int i = 1; i < elements; i++)
-            tree.remove(i);
+        while(tree.root)
+            tree.remove(tree.root->key);
     }
 }
 
diff --git a/02-splay_operation/python/splay_operation_more_tests.py b/02-splay_operation/python/splay_operation_more_tests.py
index 64ea6f8..8e8b4d7 100644
--- a/02-splay_operation/python/splay_operation_more_tests.py
+++ b/02-splay_operation/python/splay_operation_more_tests.py
@@ -97,7 +97,7 @@ def test_insert():
     assert flatten(tree) == sorted(sequence), "Incorrect tree after a sequence of inserts"
 
     # Test speed
-    elements = 200000
+    elements = 200000 # Must be even!
     tree = Tree()
     for elem in range(elements):
         for _ in range(10):
@@ -141,13 +141,14 @@ def test_remove():
     for elem in range(elements):
         tree.remove(0)
 
-    node = None
-    for i in range(1, elements):
-        node = Node(i, None, node, None)
-    node = Node(0, None, None, node)
+    left_subtree = right_subtree = None
+    for i in reversed(range(0, elements//2)):
+        left_subtree = Node(i, right=left_subtree)
+        right_subtree = Node(elements-i, left=right_subtree)
+    node = Node(elements//2, left=left_subtree, right=right_subtree)
     tree = Tree(node)
-    for i in range(1, elements):
-        tree.remove(i)
+    while tree.root is not None:
+        tree.remove(tree.root.key)
 
 tests = [
     ("splay", test_splay),
-- 
GitLab