diff --git a/.gitignore b/.gitignore index 1bd183227b3f..86f5383daa51 100644 --- a/.gitignore +++ b/.gitignore @@ -32,7 +32,7 @@ include/ccf/version.h python/version.py src/host/config_schema.h **/compatibility_report.json -.venv_ccf_sandbox/* +.venv_ccf_sandbox/ workspace/* ledger_[0-9]* snapshot_[0-9]* diff --git a/src/ds/rb_map.h b/src/ds/rb_map.h index 187a83cdeab8..383e79995ee8 100644 --- a/src/ds/rb_map.h +++ b/src/ds/rb_map.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace rb { @@ -139,9 +140,18 @@ namespace rb return Map(B, t.left(), t.rootKey(), t.rootValue(), t.right(), t.size()); } + // Return a red-black tree without this key present. + // + // Based on the Introduction to Algorithms CLRS implementation. Map remove(const K& key) const { - throw std::logic_error("rb::Map::remove(k): Not implemented!"); + auto res = _remove(key); + if (res.second && res.first.rootColor() == R) + { + auto r = res.first; + res.first = Map(B, r.left(), r.rootKey(), r.rootValue(), r.right()); + } + return res.first; } template @@ -175,7 +185,15 @@ namespace rb Color rootColor() const { - return _root->_c; + if (empty()) + { + // empty nodes are black + return B; + } + else + { + return _root->_c; + } } const K& rootKey() const @@ -198,6 +216,7 @@ namespace rb return Map(_root->_rgt); } + // Insert a new key and value pair. Map insert(const K& x, const V& v) const { if (empty()) @@ -284,6 +303,361 @@ namespace rb { return Map(c, left(), rootKey(), rootValue(), right()); } + + Map rotateRight() const + { + auto x = left(); + auto y = Map(rootColor(), x.right(), rootKey(), rootValue(), right()); + return Map(x.rootColor(), x.left(), x.rootKey(), x.rootValue(), y); + } + + Map rotateLeft() const + { + auto y = right(); + auto x = Map(rootColor(), left(), rootKey(), rootValue(), y.left()); + return Map(y.rootColor(), x, y.rootKey(), y.rootValue(), y.right()); + } + + // Fix a double black for this node, generated from removal. 
+ // The double black node is the left child of this one. + // Returns the fixed tree and whether the double black must propagate up. + std::pair fixDoubleBlackLeft() const + { + auto sibling = right(); + + auto root = Map(*this); + + if (sibling.rootColor() == R) + { + // recolor root and sibling + root = Map( + R, root.left(), root.rootKey(), root.rootValue(), sibling.paint(B)); + // rotate root left to make the sibling the root + root = root.rotateLeft(); + // We've moved the double black node during the rotation so now we need + // to fix it recursively + auto fixedLeft = root.left().fixDoubleBlackLeft(); + root = Map( + root.rootColor(), + fixedLeft.first, + root.rootKey(), + root.rootValue(), + root.right()); + if (!fixedLeft.second) + { + // nothing left to fix + return std::make_pair(root, false); + } + // in fixing that we may have moved the double black to a child of this + // root, so fix that now + sibling = root.right(); + } + + auto doubleBlack = false; + if (sibling.left().rootColor() == B && sibling.right().rootColor() == B) + { + // current node is being made black, sibling's children are both black so + // we can safely convert the sibling to red and propagate the double + // black + sibling = sibling.paint(R); + // we might still have to propagate the double black up + doubleBlack = root.rootColor() == B; + root = Map(B, root.left(), root.rootKey(), root.rootValue(), sibling); + } + else + { + if (sibling.right().rootColor() == B) + { + // sibling's right is black here, so sibling's left must be red: + // rotate the sibling right so a red node ends up as its right child + auto siblingLeft = sibling.left().paint(B); + sibling = Map( + R, + siblingLeft, + sibling.rootKey(), + sibling.rootValue(), + sibling.right()); + sibling = sibling.rotateRight(); + root = Map( + root.rootColor(), + root.left(), + root.rootKey(), + root.rootValue(), + sibling); + } + + auto recoloredSibling = Map( + root.rootColor(), + sibling.left(), + sibling.rootKey(), + sibling.rootValue(), + sibling.right().paint(B)); + 
root = Map( + B, root.left(), root.rootKey(), root.rootValue(), recoloredSibling); + root = root.rotateLeft(); + doubleBlack = false; + } + return std::make_pair(root, doubleBlack); + } + + std::pair fixDoubleBlackRight() const + { + auto sibling = left(); + + auto root = Map(*this); + + if (sibling.rootColor() == R) + { + // recolor root and sibling + root = Map( + R, sibling.paint(B), root.rootKey(), root.rootValue(), root.right()); + // rotate root right to make the sibling the root + root = root.rotateRight(); + // We've moved the double black node during the rotation so now we need + // to fix it recursively + auto fixedRight = root.right().fixDoubleBlackRight(); + root = Map( + root.rootColor(), + root.left(), + root.rootKey(), + root.rootValue(), + fixedRight.first); + if (!fixedRight.second) + { + // nothing left to fix + return std::make_pair(root, false); + } + // in fixing that we may have moved the double black to a child of this + // root, so fix that now + sibling = root.left(); + } + + auto doubleBlack = false; + if (sibling.left().rootColor() == B && sibling.right().rootColor() == B) + { + // current node is being made black, sibling's children are both black so + // we can safely convert the sibling to red and propagate the double + // black + sibling = sibling.paint(R); + // we might still have to propagate the double black up + doubleBlack = root.rootColor() == B; + root = Map(B, sibling, root.rootKey(), root.rootValue(), root.right()); + } + else + { + if (sibling.left().rootColor() == B) + { + // sibling's left is black here, so sibling's right must be red: + // rotate the sibling left so a red node ends up as its left child + auto siblingRight = sibling.right().paint(B); + sibling = Map( + R, + sibling.left(), + sibling.rootKey(), + sibling.rootValue(), + siblingRight); + sibling = sibling.rotateLeft(); + root = Map( + root.rootColor(), + sibling, + root.rootKey(), + root.rootValue(), + root.right()); + } + + auto recoloredSibling = Map( + root.rootColor(), + 
sibling.left().paint(B), + sibling.rootKey(), + sibling.rootValue(), + sibling.right()); + root = Map( + B, recoloredSibling, root.rootKey(), root.rootValue(), root.right()); + root = root.rotateRight(); + doubleBlack = false; + } + return std::make_pair(root, doubleBlack); + } + + // Remove the node with the given key. + // returns a new map along with a bool indicating whether we need to handle + // a double black. + std::pair _remove(const K& key) const + { + if (empty()) + { + // key not present in the tree, can't remove it so just return an empty + // map. + return std::make_pair(Map(), false); + } + + const K& rootk = rootKey(); + + if (key < rootk) + { + // remove key from the left subtree + auto left_without = left()._remove(key); + // copy the left into a new map to return + auto newMap = + Map(rootColor(), left_without.first, rootKey(), rootValue(), right()); + if (left_without.second) + { + // there is a double black node in the left subtree so fix it up + return newMap.fixDoubleBlackLeft(); + } + // no double blacks are present + return std::make_pair(newMap, false); + } + else if (rootk < key) + { + // mirror of the above case + auto right_without = right()._remove(key); + auto newMap = + Map(rootColor(), left(), rootKey(), rootValue(), right_without.first); + if (right_without.second) + { + return newMap.fixDoubleBlackRight(); + } + return std::make_pair(newMap, false); + } + else if (key == rootk) + { + // delete key from this node + if (left().empty() && right().empty()) + { + // leaf node, a simple case + auto doubleBlack = rootColor() == B; + return std::make_pair(Map(), doubleBlack); + } + else if (left().empty()) + { + // nothing on the left so we can replace this node with the right + // child + auto r = right(); + // Exactly one of the node being removed and the right node are black: + // - the left is empty so has a black height of 1 + // - the right must also have a black height of 1 to maintain the + // height but it is not empty or we 
would have hit the above if + // statement + // - therefore the right node is red. + assert(r.left().empty()); + assert(r.right().empty()); + assert(r.rootColor() == R); + return std::make_pair(r.paint(B), false); + } + else if (right().empty()) + { + // mirror of the above case + auto l = left(); + assert(l.left().empty()); + assert(l.right().empty()); + assert(l.rootColor() == R); + return std::make_pair(l.paint(B), false); + } + else + { + // both children are non-empty, swap this node's key and value with + // the successor and then delete the successor + auto successor = right().minimum(); + auto right_without = right()._remove(successor.first); + auto newMap = Map( + rootColor(), + left(), + successor.first, + successor.second, + right_without.first); + if (right_without.second) + { + return newMap.fixDoubleBlackRight(); + } + return std::make_pair(newMap, false); + } + } + else + { + // key not found in the tree + return std::make_pair(Map(*this), false); + } + } + + // Return the minimum key in this map along with its value. + std::pair minimum() const + { + assert(!empty()); + + if (left().empty()) + { + return std::make_pair(rootKey(), rootValue()); + } + else + { + return left().minimum(); + } + } + +#ifndef NDEBUG + // Check properties of the tree + void check() const + { + size_t totalBlackCount = 0; + _check(0, totalBlackCount); + } +#endif + +#ifndef NDEBUG + // Print an s-expression style representation of the tree's colors + std::string to_str() const + { + auto ss = std::stringstream(); + if (empty()) + { + ss << "B"; + return ss.str(); + } + auto color = rootColor() == B ? 
"B" : "R"; + ss << "(" << color << " " << left().to_str() << right().to_str() << ")"; + return ss.str(); + } +#endif + +#ifndef NDEBUG + void _check(size_t blackCount, size_t& totalBlackCount) const + { + if (empty()) + { + totalBlackCount = blackCount + 1; + return; + } + + if (rootColor() == R) + { + if (!left().empty() && left().rootColor() == R) + { + throw std::logic_error("rb::Map::check(): Double red node found"); + } + if (!right().empty() && right().rootColor() == R) + { + throw std::logic_error("rb::Map::check(): Double red node found"); + } + } + size_t leftBlackCount = 0; + size_t rightBlackCount = 0; + left()._check(blackCount + (rootColor() == B ? 1 : 0), leftBlackCount); + right()._check(blackCount + (rootColor() == B ? 1 : 0), rightBlackCount); + + if (leftBlackCount != rightBlackCount) + { + std::cout << to_str() << std::endl; + throw std::logic_error(fmt::format( + "rb::Map::check(): black counts didn't match between left and right " + "{} {}", + leftBlackCount, + rightBlackCount)); + } + + totalBlackCount = leftBlackCount + (rootColor() == B ? 
1 : 0); + } +#endif }; template diff --git a/src/ds/test/map_bench.cpp b/src/ds/test/map_bench.cpp index 8cccdb42d6a5..497a9e657c6e 100644 --- a/src/ds/test/map_bench.cpp +++ b/src/ds/test/map_bench.cpp @@ -126,6 +126,31 @@ static void benchmark_getp(picobench::state& s) s.stop_timer(); } +template +static void benchmark_remove(picobench::state& s) +{ + size_t size = s.iterations(); + auto map = gen_map(size); + s.start_timer(); + for (auto _ : s) + { + (void)_; + if constexpr ( + std::is_same_v> || std::is_same_v>) + { + auto res = map.remove(0); + do_not_optimize(res); + } + else + { + auto res = map.erase(0); + do_not_optimize(res); + } + clobber_memory(); + } + s.stop_timer(); +} + template static void benchmark_foreach(picobench::state& s) { @@ -189,16 +214,26 @@ PICOBENCH(bench_rb_map_getp).iterations(sizes).samples(10).baseline(); auto bench_champ_map_getp = benchmark_getp>; PICOBENCH(bench_champ_map_getp).iterations(sizes).samples(10); -const std::vector for_sizes = {32 << 4, 32 << 5, 32 << 6}; - PICOBENCH_SUITE("foreach"); auto bench_rb_map_foreach = benchmark_foreach>; -PICOBENCH(bench_rb_map_foreach).iterations(for_sizes).samples(10).baseline(); +PICOBENCH(bench_rb_map_foreach).iterations(sizes).samples(10).baseline(); auto bench_champ_map_foreach = benchmark_foreach>; -PICOBENCH(bench_champ_map_foreach).iterations(for_sizes).samples(10); +PICOBENCH(bench_champ_map_foreach).iterations(sizes).samples(10); // std auto bench_std_map_foreach = benchmark_foreach>; -PICOBENCH(bench_std_map_foreach).iterations(for_sizes).samples(10); +PICOBENCH(bench_std_map_foreach).iterations(sizes).samples(10); auto bench_unord_map_foreach = benchmark_foreach>; -PICOBENCH(bench_unord_map_foreach).iterations(for_sizes).samples(10); +PICOBENCH(bench_unord_map_foreach).iterations(sizes).samples(10); + +PICOBENCH_SUITE("remove"); +auto bench_rb_map_remove = benchmark_remove>; +PICOBENCH(bench_rb_map_remove).iterations(sizes).samples(10).baseline(); +auto bench_champ_map_remove = 
benchmark_remove>; +PICOBENCH(bench_champ_map_remove).iterations(sizes).samples(10); + +// std +auto bench_std_map_remove = benchmark_remove>; +PICOBENCH(bench_std_map_remove).iterations(sizes).samples(10); +auto bench_unord_map_remove = benchmark_remove>; +PICOBENCH(bench_unord_map_remove).iterations(sizes).samples(10); diff --git a/src/ds/test/map_test.cpp b/src/ds/test/map_test.cpp index 24dc92ce375f..c6cb70bede4d 100644 --- a/src/ds/test/map_test.cpp +++ b/src/ds/test/map_test.cpp @@ -133,24 +133,6 @@ struct Remove : public Op } }; -template -struct NoOp : public Op -{ - NoOp() = default; - - std::pair apply(const Model& a, const M& b) - { - return std::make_pair(a, b); - } - - std::string str() - { - auto ss = std::stringstream(); - ss << "NoOp (Remove not implemented!)"; - return ss.str(); - } -}; - template std::vector>> gen_ops(size_t n) { @@ -188,19 +170,11 @@ std::vector>> gen_ops(size_t n) } case 3: // remove { - // Remove operation is not yet implemented for RBMap - if constexpr (std::is_same_v) - { - std::uniform_int_distribution<> gen_idx(0, keys.size() - 1); - auto i = gen_idx(gen); - auto k = keys[i]; - keys.erase(keys.begin() + i); - op = std::make_unique>(k); - } - else - { - op = std::make_unique>(); - } + std::uniform_int_distribution<> gen_idx(0, keys.size() - 1); + auto i = gen_idx(gen); + auto k = keys[i]; + keys.erase(keys.begin() + i); + op = std::make_unique>(k); break; } default: @@ -357,8 +331,6 @@ TEST_CASE_TEMPLATE("Snapshot is immutable", M, ChampMap, RBMap) INFO("Meanwhile, modify map"); { - // Remove operation is not yet implemented for RBMap - if constexpr (std::is_same_v) { auto all_entries = get_all_entries(map); auto& key_to_remove = all_entries.begin()->first;