diff --git a/02-splay_operation/cpp/splay_operation_more_tests.cpp b/02-splay_operation/cpp/splay_operation_more_tests.cpp
index f70db3ff999d3cfbd6a76478b81dcabc0e467697..fe0c46c602ea0b3dbad51936b1272511247d0ed6 100644
--- a/02-splay_operation/cpp/splay_operation_more_tests.cpp
+++ b/02-splay_operation/cpp/splay_operation_more_tests.cpp
@@ -199,6 +199,16 @@ void test_remove() {
         for (int i = 1; i < elements; i++)
             tree.remove(0);
     }
+    {
+        Node *node = nullptr;
+        for (int i = 1; i < elements; i++)
+            node = new Node(i, nullptr, node, nullptr);
+        node = new Node(0, nullptr, nullptr, node);
+        Tree tree(node);
+
+        for (int i = 1; i < elements; i++)
+            tree.remove(i);
+    }
 }
 
 vector<pair<string, function<void()>>> tests = {
diff --git a/02-splay_operation/python/splay_operation_more_tests.py b/02-splay_operation/python/splay_operation_more_tests.py
index 1ebb8aa7e07cd64ce2b1cf6bd74307591f2f2cec..64ea6f8209398620b28b8d940c2ef4f6ff83f768 100644
--- a/02-splay_operation/python/splay_operation_more_tests.py
+++ b/02-splay_operation/python/splay_operation_more_tests.py
@@ -141,6 +141,14 @@ def test_remove():
     for elem in range(elements):
         tree.remove(0)
 
+    node = None
+    for i in range(1, elements):
+        node = Node(i, None, node, None)
+    node = Node(0, None, None, node)
+    tree = Tree(node)
+    for i in range(1, elements):
+        tree.remove(i)
+
 tests = [
     ("splay", test_splay),
     ("lookup", test_lookup),