Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(search): Rax TreeMap #3909

Merged
merged 3 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/core/search/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ add_library(query_parser base.cc ast_expr.cc query_driver.cc search.cc indices.c
sort_indices.cc vector_utils.cc compressed_sorted_set.cc block_list.cc
${gen_dir}/parser.cc ${gen_dir}/lexer.cc)

target_link_libraries(query_parser base absl::strings TRDP::reflex TRDP::uni-algo TRDP::hnswlib)
target_link_libraries(query_parser base absl::strings TRDP::reflex TRDP::uni-algo TRDP::hnswlib redis_lib)

cxx_test(compressed_sorted_set_test query_parser LABELS DFLY)
cxx_test(block_list_test query_parser LABELS DFLY)
cxx_test(rax_tree_test redis_test_lib LABELS DFLY)
cxx_test(search_parser_test query_parser LABELS DFLY)
cxx_test(search_test query_parser LABELS DFLY)
153 changes: 153 additions & 0 deletions src/core/search/rax_tree.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#pragma once

#include <absl/types/span.h>

#include <cstdio>
#include <optional>
#include <string_view>

#include "base/pmr/memory_resource.h"

extern "C" {
#include "redis/rax.h"
}

namespace dfly::search {

// absl::flat_hash_map/std::unordered_map compatible tree map based on rax tree.
// Allocates all objects on heap (with custom memory resource) as rax tree operates fully on
// pointers.
template <typename V> struct RaxTreeMap {
struct FindIterator;

// Simple seeking iterator
struct SeekIterator {
friend struct FindIterator;

SeekIterator() {
dranikpg marked this conversation as resolved.
Show resolved Hide resolved
raxStart(&it_, nullptr);
it_.node = nullptr;
}

~SeekIterator() {
raxStop(&it_);
}

SeekIterator(SeekIterator&&) = delete; // self-referential
SeekIterator(const SeekIterator&) = delete; // self-referential

SeekIterator(rax* tree, const char* op, std::string_view key) {
raxStart(&it_, tree);
raxSeek(&it_, op, to_key_ptr(key), key.size());
operator++();
}

explicit SeekIterator(rax* tree) : SeekIterator(tree, "^", std::string_view{nullptr, 0}) {
}

bool operator==(const SeekIterator& rhs) const {
return it_.node == rhs.it_.node;
}

bool operator!=(const SeekIterator& rhs) const {
return !operator==(rhs);
}

SeekIterator& operator++() {
if (!raxNext(&it_)) {
raxStop(&it_);
it_.node = nullptr;
}
return *this;
}

std::pair<std::string_view, V&> operator*() const {
return {std::string_view{reinterpret_cast<const char*>(it_.key), it_.key_len},
*reinterpret_cast<V*>(it_.data)};
}

private:
raxIterator it_;
};

// Result of find() call. Inherits from pair to mimic iterator interface, not incrementable.
struct FindIterator : public std::optional<std::pair<std::string_view, V&>> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a comment that we use optional here to get "->" operators for free.

bool operator==(const SeekIterator& rhs) const {
if (this->has_value() != !bool(rhs.it_.flags & RAX_ITER_EOF))
return false;
if (!this->has_value())
return true;
return (*this)->first ==
std::string_view{reinterpret_cast<const char*>(rhs.it_.key), rhs.it_.key_len};
}

bool operator!=(const SeekIterator& rhs) const {
return !operator==(rhs);
}
};

public:
explicit RaxTreeMap(PMR_NS::memory_resource* mr) : tree_(raxNew()), mr_(mr) {
}

// Number of keys currently stored, as reported by raxSize().
size_t size() const {
  const auto stored_keys = raxSize(tree_);
  return stored_keys;
}

// Iterator positioned on the first element of the tree (equal to end() if there is none).
auto begin() const {
  return SeekIterator(tree_);
}

// Detached end iterator; every exhausted SeekIterator compares equal to it.
auto end() const {
  return SeekIterator();
}

// Iterator seeked with the rax ">=" operator relative to `key`.
auto lower_bound(std::string_view key) const {
  return SeekIterator(tree_, ">=", key);
}

// Looks up `key`. Returns an empty FindIterator when raxFind() reports raxNotFound, otherwise
// a (key, value&) pair. The returned key view is the caller's `key`, not a copy owned by the
// tree, so it is only valid while the caller's buffer is.
FindIterator find(std::string_view key) const {
  void* const found = raxFind(tree_, to_key_ptr(key), key.size());
  if (found == raxNotFound)
    return FindIterator{std::nullopt};

  V& value = *reinterpret_cast<V*>(found);
  return FindIterator{std::pair<std::string_view, V&>(key, value)};
}

template <typename... Args>
std::pair<FindIterator, bool> try_emplace(std::string_view key, Args&&... args);

// Removes the entry `it` refers to, runs the value's destructor and returns its storage to
// mr_. Does nothing for an empty iterator or when the key is no longer present in the tree.
void erase(FindIterator it) {
  if (!it.has_value())
    return;

  // Pass the view itself: going through data() would rebuild the view with strlen(), which is
  // wrong for keys that are not NUL-terminated or contain embedded NUL bytes.
  const std::string_view key = it->first;

  void* removed = nullptr;
  if (!raxRemove(tree_, to_key_ptr(key), key.size(), &removed) || removed == nullptr)
    return;

  // The value was created with placement new in try_emplace(), so it has to be destroyed
  // explicitly before its memory is handed back.
  V* old = static_cast<V*>(removed);
  old->~V();
  mr_->deallocate(old, sizeof(V), alignof(V));
}

private:
// Adapts a string_view to the mutable `unsigned char*` key pointer the rax C functions take.
// Only the pointer type changes; the bytes are neither copied nor modified here.
static unsigned char* to_key_ptr(std::string_view key) {
  char* const mutable_chars = const_cast<char*>(key.data());
  return reinterpret_cast<unsigned char*>(mutable_chars);
}

rax* tree_;
PMR_NS::memory_resource* mr_;
};

// Inserts a value constructed from `args` under `key` unless the key already exists.
// Returns the entry for `key` plus a flag telling whether an insertion took place. The value is
// placed in memory obtained from mr_ and the tree stores only the pointer to it.
// NOTE(review): the result of raxInsert() is not checked here - confirm that an allocation
// failure inside rax cannot occur or is handled elsewhere.
template <typename V>
template <typename... Args>
std::pair<typename RaxTreeMap<V>::FindIterator, bool> RaxTreeMap<V>::try_emplace(
    std::string_view key, Args&&... args) {
  FindIterator existing = find(key);
  if (existing.has_value())
    return {existing, false};

  void* storage = mr_->allocate(sizeof(V), alignof(V));
  V* value = new (storage) V(std::forward<Args>(args)...);
  assert(uint64_t(storage) == uint64_t(value));  // we free by the latter

  // The key was absent above, so rax must not report a replaced value.
  V* previous = nullptr;
  raxInsert(tree_, to_key_ptr(key), key.size(), value, reinterpret_cast<void**>(&previous));
  assert(previous == nullptr);

  return {FindIterator{std::pair<std::string_view, V&>(key, *value)}, true};
}

} // namespace dfly::search
107 changes: 107 additions & 0 deletions src/core/search/rax_tree_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2023, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//

#include "core/search/rax_tree.h"

#include <absl/container/btree_set.h>
#include <absl/strings/str_cat.h>
#include <gtest/gtest.h>
#include <mimalloc.h>

#include <algorithm>
#include <memory_resource>

#include "base/gtest.h"
#include "base/iterator.h"
#include "base/logging.h"

extern "C" {
#include "redis/zmalloc.h"
}

namespace dfly::search {

using namespace std;

// Fixture shared by all rax tree tests.
struct RaxTreeTest : public ::testing::Test {
  // Hands mimalloc's backing heap to zmalloc's thread-local state once per suite.
  // Presumably required before rax allocates through zmalloc - confirm.
  static void SetUpTestSuite() {
    init_zmalloc_threadlocal(mi_heap_get_backing());
  }
};

// Inserts 90 key/value pairs and verifies that try_emplace() reports them correctly and that
// iteration yields all of them in key order.
TEST_F(RaxTreeTest, EmplaceAndIterate) {
  RaxTreeMap<std::string> map(pmr::get_default_resource());

  // Two-digit suffixes keep all keys the same length, so insertion order is also sorted order.
  vector<pair<string, string>> elements(90);
  for (int i = 10; i < 100; i++)
    elements[i - 10] = make_pair(absl::StrCat("key-", i), absl::StrCat("value-", i));

  for (auto& [key, value] : elements) {
    auto [it, inserted] = map.try_emplace(key, value);
    EXPECT_TRUE(inserted);
    EXPECT_EQ(it->first, key);
    EXPECT_EQ(it->second, value);
  }
  EXPECT_EQ(map.size(), elements.size());

  size_t i = 0;
  for (auto [key, value] : map) {
    ASSERT_LT(i, elements.size());  // never index past the reference data
    EXPECT_EQ(elements[i].first, key);
    EXPECT_EQ(elements[i].second, value);
    i++;
  }

  // Without this check the test would pass even if iteration produced nothing.
  EXPECT_EQ(i, elements.size());
}

// Compares RaxTreeMap::lower_bound() against std::lower_bound() over the same sorted keys.
// Review note from the PR discussion: redis code sometimes has problems with zero-length
// slices that carry a nullptr pointer, hence the explicit empty string_view case at the end.
TEST_F(RaxTreeTest, LowerBound) {
  RaxTreeMap<int> map(pmr::get_default_resource());
  vector<string> keys;

  // Keys are generated in ascending order, so `keys` doubles as the sorted reference.
  for (unsigned i = 0; i < 5; i++) {
    for (unsigned j = 0; j < 5; j++) {
      keys.emplace_back(absl::StrCat("key-", string(1, 'a' + i), "-", j));
      map.try_emplace(keys.back(), 0);
    }
  }

  auto it1 = map.lower_bound("key-c-3");
  auto it2 = lower_bound(keys.begin(), keys.end(), "key-c-3");

  while (it1 != map.end()) {
    ASSERT_TRUE(it2 != keys.end());  // do not dereference past the reference sequence
    EXPECT_EQ((*it1).first, *it2);
    ++it1;
    ++it2;
  }

  EXPECT_TRUE(it1 == map.end());
  EXPECT_TRUE(it2 == keys.end());

  // Test lower bound empty string
  vector<string> keys2;
  for (auto it = map.lower_bound(string_view{}); it != map.end(); ++it)
    keys2.emplace_back((*it).first);
  EXPECT_EQ(keys, keys2);
}

// Inserts the even numbers in [100, 999) and checks that find() hits exactly those keys, in
// step with a seek iterator walking the map, and misses the odd ones.
TEST_F(RaxTreeTest, Find) {
  RaxTreeMap<int> map(pmr::get_default_resource());
  for (unsigned i = 100; i < 999; i += 2)
    map.try_emplace(absl::StrCat("value-", i), i);

  auto it = map.begin();
  for (unsigned i = 100; i < 999; i++) {
    auto fit = map.find(absl::StrCat("value-", i));
    if (i % 2 == 0) {
      // Fatal on mismatch: `fit` may be empty and must not be dereferenced below.
      ASSERT_TRUE(fit == it);
      EXPECT_EQ(static_cast<unsigned>(fit->second), i);
      ++it;
    } else {
      EXPECT_TRUE(fit == map.end());
    }
  }

  // Every inserted key has been visited, so the seek iterator must be exhausted.
  EXPECT_TRUE(it == map.end());

  // Test find with empty string
  EXPECT_TRUE(map.find(string_view{}) == map.end());
}

} // namespace dfly::search
Loading