diff --git a/02-splay_operation/cpp/splay_operation_test.cpp b/02-splay_operation/cpp/splay_operation_test.cpp
index 23c28b676a76c395ffbfd8f618a47ac8941e9812..aeaa6f2ebda92805aae936c8360b13c3c81281eb 100644
--- a/02-splay_operation/cpp/splay_operation_test.cpp
+++ b/02-splay_operation/cpp/splay_operation_test.cpp
@@ -107,19 +107,21 @@ class TestSplay {
     }
 };
 
+const int elements = 5000000; // Must be even!
+
 void test_lookup() {
     // Insert even numbers
     Tree tree;
-    for (int i = 0; i < 5000000; i += 2)
+    for (int i = 0; i < elements; i += 2)
         tree.insert(i);
 
     // Find non-existing
-    for (int i = 1; i < 5000000; i += 2)
+    for (int i = 1; i < elements; i += 2)
         for (int j = 0; j < 10; j++)
             EXPECT(!tree.lookup(i), "Non-existing element was found");
 
     // Find existing
-    for (int i = 0; i < 5000000; i += 2)
+    for (int i = 0; i < elements; i += 2)
         for (int j = 0; j < 10; j++)
             EXPECT(tree.lookup(i), "Existing element was not found");
 }
@@ -142,15 +144,17 @@ void test_insert() {
     // Test speed
     {
         Tree tree;
-        for (int i = 0; i < 5000000; i++)
+        for (int i = 0; i < elements; i++)
             for (int j = 0; j < 10; j++)
                 tree.insert(i);
     }
 
     {
         Tree tree;
-        for (int i = 5000000; i >= 0; i--)
+        for (int i = elements; i >= 0; i--)
             tree.insert(i);
+        for (int i = 0; i < elements; i++)
+            tree.insert(elements);
     }
 }
 
@@ -175,19 +179,38 @@ void test_remove() {
     // Test speed
     {
         Tree tree;
-        for (int i = 0; i < 5000000; i++)
+        for (int i = 0; i < elements; i++)
             tree.insert(i);
 
         // Non-existing elements
-        for (int i = 1; i < 5000000; i += 2)
+        for (int i = 1; i < elements; i += 2)
             for (int j = 0; j < 10; j++)
                 tree.remove(i);
 
         // Existing elements
-        for (int i = 2; i < 5000000; i += 2)
+        for (int i = 2; i < elements; i += 2)
             for (int j = 0; j < 10; j++)
                 tree.remove(i);
     }
+    {
+        Tree tree;
+        for (int i = 1; i < elements; i++)
+            tree.insert(i);
+        for (int i = 1; i < elements; i++)
+            tree.remove(0);
+    }
+    {
+        Node *left_subtree = nullptr, *right_subtree = nullptr;
+        for (int i = elements/2-1; i >= 0; i--) {
+            left_subtree = new Node(i, nullptr, nullptr, left_subtree);
+            right_subtree = new Node(elements-i, nullptr, right_subtree, nullptr);
+        }
+        Node *root = new Node(elements/2, nullptr, left_subtree, right_subtree);
+        Tree tree(root);
+
+        while(tree.root)
+            tree.remove(tree.root->key);
+    }
 }
 
 vector<pair<string, function<void()>>> tests = {
diff --git a/02-splay_operation/python/splay_operation_test.py b/02-splay_operation/python/splay_operation_test.py
index a2ff13dbfb132f57c5278bfba347e8a8f3d38e80..8e8b4d781212f2a86a495a95b4a799f41a5625ff 100644
--- a/02-splay_operation/python/splay_operation_test.py
+++ b/02-splay_operation/python/splay_operation_test.py
@@ -97,14 +97,17 @@ def test_insert():
     assert flatten(tree) == sorted(sequence), "Incorrect tree after a sequence of inserts"
 
     # Test speed
+    elements = 200000 # Must be even!
     tree = Tree()
-    for elem in range(200000):
+    for elem in range(elements):
         for _ in range(10):
             tree.insert(elem)
 
     tree = Tree()
-    for elem in reversed(range(200000)):
+    for elem in reversed(range(elements)):
         tree.insert(elem)
+    for elem in range(elements):
+        tree.insert(elements)
 
 def test_remove():
     # Test validity first
@@ -118,19 +121,35 @@ def test_remove():
     assert flatten(tree) == sorted(sequence), "Incorrect tree after a sequence of removes"
 
     # Test speed
+    elements = 200000
     tree = Tree()
-    for elem in range(0, 100000, 2):
+    for elem in range(0, elements, 2):
         tree.insert(elem)
 
     # Non-existing elements
-    for elem in range(1, 100000, 2):
+    for elem in range(1, elements, 2):
         for _ in range(10):
             tree.remove(elem)
 
     # Existing elements
-    for elem in range(2, 100000, 2):
+    for elem in range(2, elements, 2):
         tree.remove(elem)
 
+    tree = Tree()
+    for elem in range(1, elements):
+        tree.insert(elem)
+    for elem in range(elements):
+        tree.remove(0)
+
+    left_subtree = right_subtree = None
+    for i in reversed(range(0, elements//2)):
+        left_subtree = Node(i, right=left_subtree)
+        right_subtree = Node(elements-i, left=right_subtree)
+    node = Node(elements//2, left=left_subtree, right=right_subtree)
+    tree = Tree(node)
+    while tree.root is not None:
+        tree.remove(tree.root.key)
+
 tests = [
     ("splay", test_splay),
     ("lookup", test_lookup),