Commit ed70b773 authored by Milan Straka's avatar Milan Straka

The range_tree assignment.

parent e0548e7f
test: range_tree_test
./$<
INCLUDE ?= .
CXXFLAGS=-std=c++11 -O2 -Wall -Wextra -g -Wno-sign-compare -I$(INCLUDE)
range_tree_test: range_tree_test.cpp range_tree.h test_main.cpp
$(CXX) $(CXXFLAGS) $(filter %.cpp,$^) -o $@
clean:
rm -f range_tree_test
.PHONY: clean test
#include <cstdint>
#include <limits>
// A node of the tree
class Node {
public:
int64_t key;
int64_t value;
Node* left;
Node* right;
Node* parent;
// Constructor
Node(int64_t key, int64_t value, Node* parent=nullptr, Node* left=nullptr, Node* right=nullptr) {
this->key = key;
this->value = value;
this->parent = parent;
this->left = left;
this->right = right;
if (left) left->parent = this;
if (right) right->parent = this;
}
};
// Splay tree
class Tree {
public:
// Pointer to root of the tree; nullptr if the tree is empty.
Node* root;
Tree(Node* root=nullptr) {
this->root = root;
}
// Rotate the given `node` up. Perform a single rotation of the edge
// between the node and its parent, choosing left or right rotation
// appropriately.
virtual void rotate(Node* node) {
if (node->parent) {
if (node->parent->left == node) {
if (node->right) node->right->parent = node->parent;
node->parent->left = node->right;
node->right = node->parent;
} else {
if (node->left) node->left->parent = node->parent;
node->parent->right = node->left;
node->left = node->parent;
}
if (node->parent->parent) {
if (node->parent->parent->left == node->parent)
node->parent->parent->left = node;
else
node->parent->parent->right = node;
} else {
root = node;
}
Node* original_parent = node->parent;
node->parent = node->parent->parent;
original_parent->parent = node;
}
}
// Splay the given node.
virtual void splay(Node* node) {
while (node->parent && node->parent->parent) {
if ((node->parent->right == node && node->parent->parent->right == node->parent) ||
(node->parent->left == node && node->parent->parent->left == node->parent)) {
rotate(node->parent);
rotate(node);
} else {
rotate(node);
rotate(node);
}
}
if (node->parent)
rotate(node);
}
// Look up the given key in the tree, returning the
// the node with the requested key or nullptr.
Node* lookup(int64_t key) {
Node* node = root;
Node* node_last = nullptr;
while (node) {
node_last = node;
if (node->key == key)
break;
if (key < node->key)
node = node->left;
else
node = node->right;
}
if (node_last)
splay(node_last);
return node;
}
// Insert a (key, value) into the tree.
// If the key is already present, nothing happens.
void insert(int64_t key, int64_t value) {
if (!root) {
root = new Node(key, value);
return;
}
Node* node = root;
while (node->key != key) {
if (key < node->key) {
if (!node->left)
node->left = new Node(key, value, node);
node = node->left;
} else {
if (!node->right)
node->right = new Node(key, value, node);
node = node->right;
}
}
splay(node);
}
// Delete given key from the tree.
// It the key is not present, do nothing.
//
// The implementation first splays the element to be removed to
// the root, and if it has both children, splays the largest element
// in the left subtree and links it to the original right subtree.
void remove(int64_t key) {
if (lookup(key)) {
Node* right = root->right;
root = root->left;
if (!root) {
root = right;
right = nullptr;
}
if (root)
root->parent = nullptr;
if (right) {
Node* node = root;
while (node->right)
node = node->right;
splay(node);
root->right = right;
right->parent = root;
}
}
}
// Return number of elements in range [left, right].
//
// Given a closed range [left, right], return the sum of values of elements
// in the range, i.e., sum(value | (key, value) in tree, left <= key <= right).
int64_t range_sum(int64_t left, int64_t right) {
// TODO: Implement
}
// Destructor to free all allocated memory.
~Tree() {
Node* node = root;
while (node) {
Node* next;
if (node->left) {
next = node->left;
node->left = nullptr;
} else if (node->right) {
next = node->right;
node->right = nullptr;
} else {
next = node->parent;
delete node;
}
node = next;
}
}
};
#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>
using namespace std;
// If the condition is not true, report an error and halt.
#define EXPECT(condition, message) do { if (!(condition)) expect_failed(message); } while (0)
void expect_failed(const string& message);
#include "range_tree.h"
void create_test_tree(int64_t size, bool ascending, Tree& tree) {
vector<int64_t> sequence = {7};
for (int64_t i = 2; i < size; i++)
sequence.push_back((sequence.back() * sequence.front()) % size);
if (ascending)
sort(sequence.begin(), sequence.end());
for (int64_t element : sequence)
tree.insert(element, element);
}
void test_missing(int64_t size, bool ascending) {
Tree tree;
create_test_tree(size, ascending, tree);
int64_t values = 0;
for (int64_t i = 0; i < size; i++)
values += tree.range_sum(-size, 0) + tree.range_sum(size, 2 * size);
EXPECT(values == 0, "Expected no values in an empty range");
}
void test_suffixes(int64_t size, bool ascending) {
Tree tree;
create_test_tree(size, ascending, tree);
for (int64_t left = 1; left < size; left++) {
int64_t values = tree.range_sum(left, size - 1);
int64_t expected = size * (size - 1) / 2 - left * (left - 1) / 2;
EXPECT(values == expected,
"Expected " + to_string(expected) + " for range [" + to_string(left) +
", " + to_string(size - 1) + "], got " + to_string(values));
}
}
void test_updates(int64_t size, bool ascending) {
Tree tree;
create_test_tree(size, ascending, tree);
for (int64_t left = 1; left < size; left++) {
tree.remove(left);
tree.insert(left + size - 1, left + size - 1);
int64_t values = tree.range_sum(left + 1, size + left);
int64_t expected = (size + left) * (size + left - 1) / 2 - (left + 1) * left / 2;
EXPECT(values == expected,
"Expected " + to_string(expected) + " for range [" + to_string(left + 1) +
", " + to_string(size + left) + "], got " + to_string(values));
}
}
vector<pair<string, function<void()>>> tests = {
{"random_missing", [] { test_missing(13, false); }},
{"random_suffixes", [] { test_suffixes(13, false); }},
{"random_updates", [] { test_updates(13, false); }},
{"path_missing", [] { test_missing(13, true); }},
{"path_suffixes", [] { test_suffixes(13, true); }},
{"path_updates", [] { test_updates(13, true); }},
{"random_missing_big", [] { test_missing(199999, false); }},
{"random_suffixes_big", [] { test_suffixes(199999, false); }},
{"random_updates_big", [] { test_updates(199999, false); }},
{"path_missing_big", [] { test_missing(199999, true); }},
{"path_suffixes_big", [] { test_suffixes(199999, true); }},
{"path_updates_big", [] { test_updates(199999, true); }},
};
#include <cstdlib>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
using namespace std;
extern vector<pair<string, function<void()>>> tests;
void expect_failed(const string& message) {
cerr << "Test error: " << message << endl;
exit(1);
}
int main(int argc, char* argv[]) {
vector<string> required_tests;
if (argc > 1) {
required_tests.assign(argv + 1, argv + argc);
} else {
for (const auto& test : tests)
required_tests.push_back(test.first);
}
for (const auto& required_test : required_tests) {
bool found = false;
for (const auto& test : tests)
if (required_test == test.first) {
cerr << "Running test " << required_test << endl;
test.second();
found = true;
break;
}
if (!found) {
cerr << "Unknown test " << required_test << endl;
return 1;
}
}
return 0;
}
#!/usr/bin/env python3
import math
class Node:
"""Node in a binary tree `Tree`"""
def __init__(self, key, value, left=None, right=None, parent=None):
self.key = key
self.value = value
self.parent = parent
self.left = left
self.right = right
if left is not None: left.parent = self
if right is not None: right.parent = self
class Tree:
"""A splay tree implementation"""
def __init__(self, root=None):
self.root = root
def rotate(self, node):
""" Rotate the given `node` up.
Performs a single rotation of the edge between the given node
and its parent, choosing left or right rotation appropriately.
"""
if node.parent is not None:
if node.parent.left == node:
if node.right is not None: node.right.parent = node.parent
node.parent.left = node.right
node.right = node.parent
else:
if node.left is not None: node.left.parent = node.parent
node.parent.right = node.left
node.left = node.parent
if node.parent.parent is not None:
if node.parent.parent.left == node.parent:
node.parent.parent.left = node
else:
node.parent.parent.right = node
else:
self.root = node
node.parent.parent, node.parent = node, node.parent.parent
def splay(self, node):
"""Splay the given node"""
while node.parent is not None and node.parent.parent is not None:
if (node.parent.right == node and node.parent.parent.right == node.parent) or \
(node.parent.left == node and node.parent.parent.left == node.parent):
self.rotate(node.parent)
self.rotate(node)
else:
self.rotate(node)
self.rotate(node)
if node.parent is not None:
self.rotate(node)
def lookup(self, key):
"""Look up the given key in the tree.
Returns the node with the requested key or `None`.
"""
node, node_last = self.root, None
while node is not None:
node_last = node
if node.key == key:
break
if key < node.key:
node = node.left
else:
node = node.right
if node_last is not None:
self.splay(node_last)
return node
def insert(self, key, value):
"""Insert (key, value) into the tree.
If the key is already present, do nothing.
"""
if self.root is None:
self.root = Node(key, value)
return
node = self.root
while node.key != key:
if key < node.key:
if node.left is None:
node.left = Node(key, value, parent=node)
node = node.left
else:
if node.right is None:
node.right = Node(key, value, parent=node)
node = node.right
self.splay(node)
def remove(self, key):
"""Remove given key from the tree.
It the key is not present, do nothing.
The implementation first splays the element to be removed to
the root, and if it has both children, splays the largest element
in the left subtree and links it to the original right subtree.
"""
if self.lookup(key) is not None:
right = self.root.right
self.root = self.root.left
if self.root is None:
self.root, right = right, None
if self.root is not None:
self.root.parent = None
if right is not None:
node = self.root
while node.right is not None:
node = node.right
self.splay(node)
self.root.right = right
right.parent = self.root
def range_sum(self, left, right):
"""Return number of elements in range [left, right]
Given a closed range [left, right], return the sum of values of elements
in the range, i.e., sum(value | (key, value) in tree, left <= key <= right).
"""
raise NotImplementedError()
#!/usr/bin/env python3
import sys
from range_tree import Tree
def test_tree(size, ascending):
sequence = [pow(7, i, size) for i in range(1, size)]
if ascending: sequence = sorted(sequence)
tree = Tree()
for element in sequence:
tree.insert(element, element)
return tree
def test_missing(size, ascending):
tree = test_tree(size, ascending)
values = 0
for _ in range(size):
values += tree.range_sum(-size, 0)
values += tree.range_sum(size, 2 * size)
assert values == 0, "Expected no values in an empty range"
def test_suffixes(size, ascending):
tree = test_tree(size, ascending)
for left in range(1, size):
values = tree.range_sum(left, size - 1)
expected = size * (size - 1) // 2 - left * (left - 1) // 2
assert values == expected, "Expected {} for range [{}, {}], got {}".format(expected, left, size - 1, values)
def test_updates(size, ascending):
tree = test_tree(size, ascending)
for left in range(1, size):
tree.remove(left)
tree.insert(left + size - 1, left + size - 1)
values = tree.range_sum(left + 1, size + left)
expected = (size + left) * (size + left - 1) // 2 - (left + 1) * left // 2
assert values == expected, "Expected {} for range [{}, {}], got {}".format(expected, left + 1, size + left, values)
tests = [
("random_missing", lambda: test_missing(13, False)),
("random_suffixes", lambda: test_suffixes(13, False)),
("random_updates", lambda: test_updates(13, False)),
("path_missing", lambda: test_missing(13, True)),
("path_suffixes", lambda: test_suffixes(13, True)),
("path_updates", lambda: test_updates(13, True)),
("random_missing_big", lambda: test_missing(19997, False)),
("random_suffixes_big", lambda: test_suffixes(19997, False)),
("random_updates_big", lambda: test_updates(19997, False)),
("path_missing_big", lambda: test_missing(19997, True)),
("path_suffixes_big", lambda: test_suffixes(19997, True)),
("path_updates_big", lambda: test_updates(19997, True)),
]
if __name__ == "__main__":
for required_test in sys.argv[1:] or [name for name, _ in tests]:
for name, test in tests:
if name == required_test:
print("Running test {}".format(name), file=sys.stderr)
test()
break
else:
raise ValueError("Unknown test {}".format(name))
You are given an implementation of a splay tree which associates every numeric
key with a numeric value. The splay tree provides `lookup`, `insert`, and `remove`
operations.
Your goal is to modify the splay tree to support range queries in amortized
logarithmic time. The operation you need to implement takes an range on the
input and should return the sum of values of the elements in the given range.
As usual, you should submit only the `range_tree.{h,py}` file.
## Optional: Range updates (for 5 points)
If you also implement an operation
```
range_update(left, right, delta)
```
which adds `delta` to the value of all elements with key in `[left, right]` range
and runs in amortized logarithmic time, you will get 5 points.
Currently there are no automated tests for this method; therefore, if you
implement it, submit the solution to ReCodEx and write an email to your
teaching assistant.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment