Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(search): Prefix search for tags #3972

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/core/search/ast_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ AstFieldNode::AstFieldNode(string field, AstNode&& node)
: field{field.substr(1)}, node{make_unique<AstNode>(std::move(node))} {
}

AstTagsNode::AstTagsNode(std::string tag) {
AstTagsNode::AstTagsNode(TagValue tag) {
tags = {std::move(tag)};
}

AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
AstTagsNode::AstTagsNode(AstExpr&& l, TagValue tag) {
DCHECK(holds_alternative<AstTagsNode>(l));
auto& tags_node = get<AstTagsNode>(l);

Expand Down Expand Up @@ -82,4 +82,8 @@ namespace std {
// Stream-insertion overload for optional<size_t>. It deliberately writes nothing:
// the stream is returned untouched, so the value never appears in the output.
// NOTE(review): presumably exists only so generated parser code that streams
// semantic values compiles — confirm against the bison-generated sources.
ostream& operator<<(ostream& os, optional<size_t> o) {
return os;
}

// Stream-insertion overload for AstTagsNode::TagValueProxy. Like the
// optional<size_t> overload, it is a no-op: nothing is written and the stream
// is returned as-is.
// NOTE(review): presumably needed because TagValueProxy is used as a parser
// semantic type (see %nterm <AstTagsNode::TagValueProxy>) — confirm.
ostream& operator<<(ostream& os, dfly::search::AstTagsNode::TagValueProxy o) {
return os;
}
} // namespace std
23 changes: 18 additions & 5 deletions src/core/search/ast_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,22 @@ struct AstFieldNode {

// Stores a list of tags for a tag query
struct AstTagsNode {
AstTagsNode(std::string tag);
AstTagsNode(AstNode&& l, std::string tag);

std::vector<std::string> tags;
using TagValue = std::variant<AstTermNode, AstPrefixNode>;

// Thin wrapper around TagValue (a variant of AstTermNode / AstPrefixNode) that
// adds a default constructor. It holds no state of its own; every constructor
// just initializes the TagValue base.
struct TagValueProxy
: public AstTagsNode::TagValue { // bison needs it to be default constructible
// Default state: an AstTermNode built from the empty string.
TagValueProxy() : AstTagsNode::TagValue(AstTermNode("")) {
}
// Implicit conversion from a prefix node; the argument is moved into the base.
TagValueProxy(AstPrefixNode tv) : AstTagsNode::TagValue(std::move(tv)) {
}
// Implicit conversion from a term node; the argument is moved into the base.
TagValueProxy(AstTermNode tv) : AstTagsNode::TagValue(std::move(tv)) {
}
};

AstTagsNode(TagValue);
AstTagsNode(AstNode&& l, TagValue);

std::vector<TagValue> tags;
};

// Applies nearest neighbor search to the final result set
Expand Down Expand Up @@ -125,4 +137,5 @@ using AstExpr = AstNode;

namespace std {
ostream& operator<<(ostream& os, optional<size_t> o);
}
ostream& operator<<(ostream& os, dfly::search::AstTagsNode::TagValueProxy o);
} // namespace std
11 changes: 6 additions & 5 deletions src/core/search/parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ double toDouble(string_view src);
%nterm <bool> opt_lparen
%nterm <AstExpr> final_query filter search_expr search_unary_expr search_or_expr search_and_expr numeric_filter_expr
%nterm <AstExpr> field_cond field_cond_expr field_unary_expr field_or_expr field_and_expr tag_list
%nterm <std::string> tag_list_element
%nterm <AstTagsNode::TagValueProxy> tag_list_element

%nterm <AstKnnNode> knn_query
%nterm <std::string> opt_knn_alias
Expand Down Expand Up @@ -179,10 +179,11 @@ tag_list:
| tag_list OR_OP tag_list_element { $$ = AstTagsNode(std::move($1), std::move($3)); }

tag_list_element:
TERM { $$ = std::move($1); }
| UINT32 { $$ = std::move($1); }
| DOUBLE { $$ = std::move($1); }
| TAG_VAL { $$ = std::move($1); }
TERM { $$ = AstTermNode(std::move($1)); }
| PREFIX { $$ = AstPrefixNode(std::move($1)); }
| UINT32 { $$ = AstTermNode(std::move($1)); }
| DOUBLE { $$ = AstTermNode(std::move($1)); }
| TAG_VAL { $$ = AstTermNode(std::move($1)); }


%%
Expand Down
47 changes: 35 additions & 12 deletions src/core/search/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ struct IndexResult {

struct ProfileBuilder {
string GetNodeInfo(const AstNode& node) {
// Formatter functor handed to absl::StrJoin when rendering the tags of an
// AstTagsNode. Each overload appends the textual payload of one tag to `out`.
struct NodeFormatter {
// Prefix tag: append the stored prefix string.
void operator()(std::string* out, const AstPrefixNode& node) const {
out->append(node.prefix);
}
// Exact-term tag: append the stored term string.
void operator()(std::string* out, const AstTermNode& node) const {
out->append(node.term);
}
// Variant entry point: visit the held alternative and forward to the
// matching overload above.
void operator()(std::string* out, const AstTagsNode::TagValue& value) const {
visit([this, out](const auto& n) { this->operator()(out, n); }, value);
}
};
Overloaded node_info{
[](monostate) -> string { return ""s; },
[](const AstTermNode& n) { return absl::StrCat("Term{", n.term, "}"); },
Expand All @@ -125,7 +136,9 @@ struct ProfileBuilder {
auto op = n.op == AstLogicalNode::AND ? "and" : "or";
return absl::StrCat("Logical{n=", n.nodes.size(), ",o=", op, "}");
},
[](const AstTagsNode& n) { return absl::StrCat("Tags{", absl::StrJoin(n.tags, ","), "}"); },
[](const AstTagsNode& n) {
return absl::StrCat("Tags{", absl::StrJoin(n.tags, ",", NodeFormatter()), "}");
},
[](const AstFieldNode& n) { return absl::StrCat("Field{", n.field, "}"); },
[](const AstKnnNode& n) { return absl::StrCat("KNN{l=", n.limit, "}"); },
[](const AstNegateNode& n) { return absl::StrCat("Negate{}"); },
Expand Down Expand Up @@ -248,6 +261,14 @@ struct BasicSearch {
return out;
}

// Gathers everything `index` reports for `prefix` into a single IndexResult.
// Starts from an empty result and OR-merges each container handed to the
// callback by MatchingPrefix, then returns the accumulated result.
// NOTE(review): presumably MatchingPrefix invokes the callback once per index
// entry whose key starts with `prefix` — confirm in BaseStringIndex.
template <typename C>
IndexResult CollectPrefixMatches(BaseStringIndex<C>* index, std::string_view prefix) {
IndexResult result{};
// `result` is captured by reference; the lambda is only used during this call.
index->MatchingPrefix(
prefix, [&result, this](const auto* c) { Merge(IndexResult{c}, &result, LogicOp::OR); });
return result;
}

// Search overload for the monostate alternative: yields an empty list of
// document ids. The field-name argument is ignored.
IndexResult Search(monostate, string_view) {
return vector<DocId>{};
}
Expand Down Expand Up @@ -283,13 +304,8 @@ struct BasicSearch {
}

auto mapping = [&node, this](TextIndex* index) {
IndexResult result{};
index->MatchingPrefix(node.prefix, [&result, this](const auto* c) {
Merge(IndexResult{c}, &result, LogicOp::OR);
});
return result;
return CollectPrefixMatches(index, node.prefix);
};

return UnifyResults(GetSubResults(indices, mapping), LogicOp::OR);
}

Expand Down Expand Up @@ -330,11 +346,18 @@ struct BasicSearch {

// {tags | ...}: Unify results for all tags
IndexResult Search(const AstTagsNode& node, string_view active_field) {
if (auto* tag_index = GetIndex<TagIndex>(active_field); tag_index) {
auto mapping = [tag_index](string_view tag) { return tag_index->Matching(tag); };
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
}
return IndexResult{};
auto* tag_index = GetIndex<TagIndex>(active_field);
if (!tag_index)
return IndexResult{};

Overloaded ov{[tag_index](const AstTermNode& term) -> IndexResult {
return tag_index->Matching(term.term);
},
[tag_index, this](const AstPrefixNode& prefix) {
return CollectPrefixMatches(tag_index, prefix.prefix);
}};
auto mapping = [ov](const auto& tag) { return visit(ov, tag); };
return UnifyResults(GetSubResults(node.tags, mapping), LogicOp::OR);
}

// SORTBY field [DESC]: Sort by field. Part of params and not "core query".
Expand Down
13 changes: 13 additions & 0 deletions src/core/search/search_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,19 @@ TEST_F(SearchTest, CheckTag) {
EXPECT_TRUE(Check()) << GetError();
}

// Exercises tag queries that mix prefix patterns (`green*`, `yellow*`) with an
// exact tag (`orange`) on a TAG field named "color".
TEST_F(SearchTest, CheckTagPrefix) {
PrepareSchema({{"color", SchemaField::TAG}});
PrepareQuery("@color:{green* | orange | yellow*}");

// Going by the helper name, these documents should match: values that start
// with "green" or "yellow", plus the exact value "orange".
ExpectAll(Map{{"color", "green"}}, Map{{"color", "yellow"}}, Map{{"color", "greenish"}},
Map{{"color", "yellowish"}}, Map{{"color", "green-forestish"}},
Map{{"color", "yellowsunish"}}, Map{{"color", "orange"}});
// Going by the helper name, these should not match: unrelated colors,
// "orangeish" (orange is not a prefix pattern), and values that merely
// contain "green"/"yellow" without starting with them.
ExpectNone(Map{{"color", "red"}}, Map{{"color", "blue"}}, Map{{"color", "orangeish"}},
Map{{"color", "darkgreen"}}, Map{{"color", "light-yellow"}});

EXPECT_TRUE(Check()) << GetError();
}

TEST_F(SearchTest, IntegerTerms) {
PrepareSchema({{"status", SchemaField::TAG}, {"title", SchemaField::TEXT}});

Expand Down
Loading