fix(search_family): Fix crash in FT.PROFILE command for invalid queries #4043

Merged
147 changes: 100 additions & 47 deletions src/server/search/search_family.cc
@@ -226,49 +226,49 @@ search::QueryParams ParseQueryParams(CmdArgParser* parser) {
   return params;
 }
 
-optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, SinkReplyBuilder* builder) {
+optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyBuilder* builder) {
   SearchParams params;
 
-  while (parser.HasNext()) {
+  while (parser->HasNext()) {
     // [LIMIT offset total]
-    if (parser.Check("LIMIT")) {
-      params.limit_offset = parser.Next<size_t>();
-      params.limit_total = parser.Next<size_t>();
-    } else if (parser.Check("LOAD")) {
+    if (parser->Check("LIMIT")) {
+      params.limit_offset = parser->Next<size_t>();
+      params.limit_total = parser->Next<size_t>();
+    } else if (parser->Check("LOAD")) {
       if (params.return_fields) {
         builder->SendError("LOAD cannot be applied after RETURN");
         return std::nullopt;
       }
 
-      ParseLoadFields(&parser, &params.load_fields);
-    } else if (parser.Check("RETURN")) {
+      ParseLoadFields(parser, &params.load_fields);
+    } else if (parser->Check("RETURN")) {
       if (params.load_fields) {
         builder->SendError("RETURN cannot be applied after LOAD");
         return std::nullopt;
       }
 
       // RETURN {num} [{ident} AS {name}...]
-      size_t num_fields = parser.Next<size_t>();
+      size_t num_fields = parser->Next<size_t>();
       params.return_fields.emplace();
       while (params.return_fields->size() < num_fields) {
-        string_view ident = parser.Next();
-        string_view alias = parser.Check("AS") ? parser.Next() : ident;
+        string_view ident = parser->Next();
+        string_view alias = parser->Check("AS") ? parser->Next() : ident;
         params.return_fields->emplace_back(ident, alias);
       }
-    } else if (parser.Check("NOCONTENT")) {  // NOCONTENT
+    } else if (parser->Check("NOCONTENT")) {  // NOCONTENT
       params.load_fields.emplace();
       params.return_fields.emplace();
-    } else if (parser.Check("PARAMS")) {  // [PARAMS num(ignored) name(ignored) knn_vector]
-      params.query_params = ParseQueryParams(&parser);
-    } else if (parser.Check("SORTBY")) {
-      params.sort_option = search::SortOption{string{parser.Next()}, bool(parser.Check("DESC"))};
+    } else if (parser->Check("PARAMS")) {  // [PARAMS num(ignored) name(ignored) knn_vector]
+      params.query_params = ParseQueryParams(parser);
+    } else if (parser->Check("SORTBY")) {
+      params.sort_option = search::SortOption{string{parser->Next()}, bool(parser->Check("DESC"))};
     } else {
       // Unsupported parameters are ignored for now
-      parser.Skip(1);
+      parser->Skip(1);
     }
   }
 
-  if (auto err = parser.Error(); err) {
+  if (auto err = parser->Error(); err) {
     builder->SendError(err->MakeReply());
     return nullopt;
   }
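The signature change in the hunk above (from a by-value CmdArgParser to CmdArgParser*) lets the helper advance the caller's cursor and share its error state, so FtSearch and FtProfile below can consume their leading tokens themselves and then delegate the remaining arguments. A minimal sketch of that pattern, using a hypothetical MiniParser in place of Dragonfly's real CmdArgParser:

#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for CmdArgParser: a cursor over argv plus sticky error state.
struct MiniParser {
  std::vector<std::string> args;
  size_t pos = 0;
  std::optional<std::string> error;

  bool HasNext() const { return pos < args.size(); }

  std::string Next() {
    if (!HasNext()) {
      error = "out of range";
      return {};
    }
    return args[pos++];
  }
};

// Because the parser arrives by pointer, every token consumed here is also
// consumed from the caller's point of view, and any recorded error stays
// visible to the caller.
void ParseOptions(MiniParser* parser) {
  while (parser->HasNext())
    parser->Next();  // the real code dispatches on LIMIT/LOAD/RETURN/... here
}

bool HandleCommand(std::vector<std::string> argv) {
  MiniParser parser{std::move(argv)};
  std::string index_name = parser.Next();  // caller consumes the leading args,
  (void)index_name;                        // (unused in this sketch)
  ParseOptions(&parser);                   // then hands the tail to the helper
  return !parser.error.has_value();        // a single error check covers both
}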
@@ -716,10 +716,11 @@ void SearchFamily::FtList(CmdArgList args, Transaction* tx, SinkReplyBuilder* bu
 }
 
 void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
-  string_view index_name = ArgS(args, 0);
-  string_view query_str = ArgS(args, 1);
+  CmdArgParser parser{args};
+  string_view index_name = parser.Next();
+  string_view query_str = parser.Next();
 
-  auto params = ParseSearchParamsOrReply(args.subspan(2), builder);
+  auto params = ParseSearchParamsOrReply(&parser, builder);
   if (!params.has_value())
     return;
 
@@ -749,77 +750,129 @@ void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder*
   }
 
   if (auto agg = search_algo.HasAggregation(); agg)
-    ReplySorted(std::move(*agg), *params, absl::MakeSpan(docs), builder);
+    ReplySorted(*agg, *params, absl::MakeSpan(docs), builder);
   else
     ReplyWithResults(*params, absl::MakeSpan(docs), builder);
 }
 
 void SearchFamily::FtProfile(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
-  string_view index_name = ArgS(args, 0);
-  string_view query_str = ArgS(args, 3);
+  CmdArgParser parser{args};
+
+  string_view index_name = parser.Next();
 
-  optional<SearchParams> params = ParseSearchParamsOrReply(args.subspan(4), builder);
+  if (!parser.Check("SEARCH") && !parser.Check("AGGREGATE")) {
+    return builder->SendError("no `SEARCH` or `AGGREGATE` provided");
+  }
+
+  parser.Check("LIMITED");  // TODO: Implement limited profiling
+  parser.ExpectTag("QUERY");
+
+  string_view query_str = parser.Next();
+
+  optional<SearchParams> params = ParseSearchParamsOrReply(&parser, builder);
   if (!params.has_value())
     return;
 
   search::SearchAlgorithm search_algo;
   search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr;
   if (!search_algo.Init(query_str, &params->query_params, sort_opt))
-    return builder->SendError("Query syntax error");
+    return builder->SendError("query syntax error");
 
   search_algo.EnableProfiling();
 
   absl::Time start = absl::Now();
-  atomic_uint total_docs = 0;
-  atomic_uint total_serialized = 0;
+  const size_t shards_count = shard_set->size();
 
-  vector<pair<search::AlgorithmProfile, absl::Duration>> results(shard_set->size());
+  // Because our coordinator thread may not have a shard, we can't check ahead if the index exists.
+  std::atomic<bool> index_not_found{false};
+  std::vector<SearchResult> search_results(shards_count);
+  std::vector<absl::Duration> profile_results(shards_count);
 
   tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
     auto* index = es->search_indices()->GetIndex(index_name);
-    if (!index)
+    if (!index) {
+      index_not_found.store(true, memory_order_relaxed);
       return OpStatus::OK;
+    }
 
-    auto shard_start = absl::Now();
-    auto res = index->Search(t->GetOpArgs(es), *params, &search_algo);
-
-    total_docs.fetch_add(res.total_hits);
-    total_serialized.fetch_add(res.docs.size());
+    const ShardId shard_id = es->shard_id();
 
-    DCHECK(res.profile);
-    results[es->shard_id()] = {std::move(*res.profile), absl::Now() - shard_start};
+    auto shard_start = absl::Now();
+    search_results[shard_id] = index->Search(t->GetOpArgs(es), *params, &search_algo);
+    profile_results[shard_id] = {absl::Now() - shard_start};
 
     return OpStatus::OK;
   });
 
+  if (index_not_found.load())
+    return builder->SendError(std::string{index_name} + ": no such index");
+
   auto took = absl::Now() - start;
 
+  bool result_is_empty = false;
+  size_t total_docs = 0;
+  size_t total_serialized = 0;
+  for (const auto& result : search_results) {
+    if (!result.error) {
+      total_docs += result.total_hits;
+      total_serialized += result.docs.size();
+    } else {
+      result_is_empty = true;
+    }
+  }
+
   auto* rb = static_cast<RedisReplyBuilder*>(builder);
-  rb->StartArray(results.size() + 1);
+  // First element -> Result of the search command
+  // Second element -> Profile information
+  rb->StartArray(2);
+
+  // Result of the search command
+  if (!result_is_empty) {
+    auto agg = search_algo.HasAggregation();
+    if (agg) {
+      ReplySorted(*agg, *params, absl::MakeSpan(search_results), builder);
+    } else {
+      ReplyWithResults(*params, absl::MakeSpan(search_results), builder);
+    }
+  } else {
+    rb->StartArray(1);
+    rb->SendLong(0);
+  }
+
+  // Profile information
+  rb->StartArray(shards_count + 1);
 
   // General stats
   rb->StartCollection(3, RedisReplyBuilder::MAP);
   rb->SendBulkString("took");
   rb->SendLong(absl::ToInt64Microseconds(took));
   rb->SendBulkString("hits");
-  rb->SendLong(total_docs);
+  rb->SendLong(static_cast<long>(total_docs));
   rb->SendBulkString("serialized");
-  rb->SendLong(total_serialized);
+  rb->SendLong(static_cast<long>(total_serialized));
 
   // Per-shard stats
-  for (const auto& [profile, shard_took] : results) {
+  for (size_t shard_id = 0; shard_id < shards_count; shard_id++) {
     rb->StartCollection(2, RedisReplyBuilder::MAP);
     rb->SendBulkString("took");
-    rb->SendLong(absl::ToInt64Microseconds(shard_took));
+    rb->SendLong(absl::ToInt64Microseconds(profile_results[shard_id]));
     rb->SendBulkString("tree");
 
-    for (size_t i = 0; i < profile.events.size(); i++) {
-      const auto& event = profile.events[i];
+    const auto& search_result = search_results[shard_id];
+    if (search_result.error || !search_result.profile || search_result.profile->events.empty()) {
+      rb->SendEmptyArray();
+      continue;
+    }
+
+    const auto& events = search_result.profile->events;
+    for (size_t i = 0; i < events.size(); i++) {
+      const auto& event = events[i];
 
       size_t children = 0;
-      for (size_t j = i + 1; j < profile.events.size(); j++) {
-        if (profile.events[j].depth == event.depth)
+      for (size_t j = i + 1; j < events.size(); j++) {
+        if (events[j].depth == event.depth)
           break;
-        if (profile.events[j].depth == event.depth + 1)
+        if (events[j].depth == event.depth + 1)
           children++;
       }
 
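The children-counting loop at the end of the hunk above treats the profile events as a preorder-flattened tree: the direct children of events[i] are the later entries at depth + 1, and the first later entry back at the same depth closes the subtree and stops the scan. A standalone sketch of that traversal, with a hypothetical Event struct standing in for the real profile event type:

#include <cstddef>
#include <vector>

// Hypothetical event type; the real profile event carries more fields.
struct Event {
  size_t depth;  // preorder flattening: a node precedes its whole subtree
};

// Count the direct children of events[i]: subsequent entries at depth + 1,
// scanning until an entry at the same depth ends the subtree.
size_t CountChildren(const std::vector<Event>& events, size_t i) {
  size_t children = 0;
  for (size_t j = i + 1; j < events.size(); j++) {
    if (events[j].depth == events[i].depth)
      break;
    if (events[j].depth == events[i].depth + 1)
      children++;
  }
  return children;
}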
62 changes: 59 additions & 3 deletions src/server/search/search_family_test.cc
@@ -22,6 +22,27 @@ class SearchFamilyTest : public BaseFamilyTest {

 const auto kNoResults = IntArg(0);  // tests auto destruct single element arrays
 
+/* Asserts that the response is an array of two arrays. Used to test the FT.PROFILE response. */
+::testing::AssertionResult AssertArrayOfTwoArrays(const RespExpr& resp) {
+  if (resp.GetVec().size() != 2) {
+    return ::testing::AssertionFailure()
+           << "Expected response array length to be 2, but was " << resp.GetVec().size();
+  }
+
+  const auto& vec = resp.GetVec();
+  if (vec[0].type != RespExpr::ARRAY) {
+    return ::testing::AssertionFailure()
+           << "Expected resp[0] to be an array, but was " << vec[0].type;
+  }
+  if (vec[1].type != RespExpr::ARRAY) {
+    return ::testing::AssertionFailure()
+           << "Expected resp[1] to be an array, but was " << vec[1].type;
+  }
+  return ::testing::AssertionSuccess();
+}
+
+#define ASSERT_ARRAY_OF_TWO_ARRAYS(resp) ASSERT_PRED1(AssertArrayOfTwoArrays, resp)
+
 MATCHER_P2(DocIds, total, arg_ids, "") {
   if (arg_ids.empty()) {
     if (auto res = arg.GetInt(); !res || *res != 0) {
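The helper added above follows GoogleTest's predicate-assertion pattern: a function returning ::testing::AssertionResult is plugged into ASSERT_PRED1, so a failed shape check prints the helper's message instead of a bare boolean. A hypothetical usage sketch (the test name and schema below are illustrative, not part of this PR):

// Hypothetical example test; ASSERT_ARRAY_OF_TWO_ARRAYS expands to
// ASSERT_PRED1(AssertArrayOfTwoArrays, resp), so GoogleTest reports the
// AssertionFailure message built by the helper on mismatch.
TEST_F(SearchFamilyTest, ProfileReplyShapeExample) {
  Run({"ft.create", "idx", "schema", "name", "text"});
  auto resp = Run({"ft.profile", "idx", "search", "query", "*"});
  ASSERT_ARRAY_OF_TWO_ARRAYS(resp);  // resp[0]: search result, resp[1]: profile info
}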
@@ -790,20 +811,55 @@ TEST_F(SearchFamilyTest, FtProfile) {
Run({"ft.create", "i1", "schema", "name", "text"});

auto resp = Run({"ft.profile", "i1", "search", "query", "(a | b) c d"});
ASSERT_ARRAY_OF_TWO_ARRAYS(resp);

const auto& top_level = resp.GetVec();
EXPECT_EQ(top_level.size(), shard_set->size() + 1);
EXPECT_THAT(top_level[0], IsMapWithSize());

const auto& profile_result = top_level[1].GetVec();
EXPECT_EQ(profile_result.size(), shard_set->size() + 1);

EXPECT_THAT(top_level[0].GetVec(), ElementsAre("took", _, "hits", _, "serialized", _));
EXPECT_THAT(profile_result[0].GetVec(), ElementsAre("took", _, "hits", _, "serialized", _));

for (size_t sid = 0; sid < shard_set->size(); sid++) {
const auto& shard_resp = top_level[sid + 1].GetVec();
const auto& shard_resp = profile_result[sid + 1].GetVec();
EXPECT_THAT(shard_resp, ElementsAre("took", _, "tree", _));

const auto& tree = shard_resp[3].GetVec();
EXPECT_THAT(tree[0].GetString(), HasSubstr("Logical{n=3,o=and}"sv));
EXPECT_EQ(tree[1].GetVec().size(), 3);
}

// Test LIMITED throws no errors
resp = Run({"ft.profile", "i1", "search", "limited", "query", "(a | b) c d"});
ASSERT_ARRAY_OF_TWO_ARRAYS(resp);
}

TEST_F(SearchFamilyTest, FtProfileInvalidQuery) {
Run({"json.set", "j1", ".", R"({"id":"1"})"});
Run({"ft.create", "i1", "on", "json", "schema", "$.id", "as", "id", "tag"});

auto resp = Run({"ft.profile", "i1", "search", "query", "@id:[1 1]"});
ASSERT_ARRAY_OF_TWO_ARRAYS(resp);

EXPECT_THAT(resp.GetVec()[0], IsMapWithSize());

resp = Run({"ft.profile", "i1", "search", "query", "@{invalid13289}"});
EXPECT_THAT(resp, ErrArg("query syntax error"));
}

TEST_F(SearchFamilyTest, FtProfileErrorReply) {
Run({"ft.create", "i1", "schema", "name", "text"});
;

auto resp = Run({"ft.profile", "i1", "not_search", "query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("no `SEARCH` or `AGGREGATE` provided"));

resp = Run({"ft.profile", "i1", "search", "not_query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("syntax error"));

resp = Run({"ft.profile", "non_existent_key", "search", "query", "(a | b) c d"});
EXPECT_THAT(resp, ErrArg("non_existent_key: no such index"));
}

TEST_F(SearchFamilyTest, SimpleExpiry) {