Skip to content

Commit

Permalink
fix(search_family): Fix crash in FT.PROFILE command for invalid queries (#4043)
Browse files Browse the repository at this point in the history

* refactor(search_family): Remove unnecessary std::move in FT.SEARCH

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* fix(search_family): Fix crash in FT.PROFILE command for invalid queries

fixes #3983

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor(search_family_test): address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
  • Loading branch information
BagritsevichStepan authored Nov 4, 2024
1 parent 9c2fc3f commit 7ac8535
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 50 deletions.
147 changes: 100 additions & 47 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,49 +226,49 @@ search::QueryParams ParseQueryParams(CmdArgParser* parser) {
return params;
}

optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser parser, SinkReplyBuilder* builder) {
optional<SearchParams> ParseSearchParamsOrReply(CmdArgParser* parser, SinkReplyBuilder* builder) {
SearchParams params;

while (parser.HasNext()) {
while (parser->HasNext()) {
// [LIMIT offset total]
if (parser.Check("LIMIT")) {
params.limit_offset = parser.Next<size_t>();
params.limit_total = parser.Next<size_t>();
} else if (parser.Check("LOAD")) {
if (parser->Check("LIMIT")) {
params.limit_offset = parser->Next<size_t>();
params.limit_total = parser->Next<size_t>();
} else if (parser->Check("LOAD")) {
if (params.return_fields) {
builder->SendError("LOAD cannot be applied after RETURN");
return std::nullopt;
}

ParseLoadFields(&parser, &params.load_fields);
} else if (parser.Check("RETURN")) {
ParseLoadFields(parser, &params.load_fields);
} else if (parser->Check("RETURN")) {
if (params.load_fields) {
builder->SendError("RETURN cannot be applied after LOAD");
return std::nullopt;
}

// RETURN {num} [{ident} AS {name}...]
size_t num_fields = parser.Next<size_t>();
size_t num_fields = parser->Next<size_t>();
params.return_fields.emplace();
while (params.return_fields->size() < num_fields) {
string_view ident = parser.Next();
string_view alias = parser.Check("AS") ? parser.Next() : ident;
string_view ident = parser->Next();
string_view alias = parser->Check("AS") ? parser->Next() : ident;
params.return_fields->emplace_back(ident, alias);
}
} else if (parser.Check("NOCONTENT")) { // NOCONTENT
} else if (parser->Check("NOCONTENT")) { // NOCONTENT
params.load_fields.emplace();
params.return_fields.emplace();
} else if (parser.Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector]
params.query_params = ParseQueryParams(&parser);
} else if (parser.Check("SORTBY")) {
params.sort_option = search::SortOption{string{parser.Next()}, bool(parser.Check("DESC"))};
} else if (parser->Check("PARAMS")) { // [PARAMS num(ignored) name(ignored) knn_vector]
params.query_params = ParseQueryParams(parser);
} else if (parser->Check("SORTBY")) {
params.sort_option = search::SortOption{string{parser->Next()}, bool(parser->Check("DESC"))};
} else {
// Unsupported parameters are ignored for now
parser.Skip(1);
parser->Skip(1);
}
}

if (auto err = parser.Error(); err) {
if (auto err = parser->Error(); err) {
builder->SendError(err->MakeReply());
return nullopt;
}
Expand Down Expand Up @@ -716,10 +716,11 @@ void SearchFamily::FtList(CmdArgList args, Transaction* tx, SinkReplyBuilder* bu
}

void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
string_view index_name = ArgS(args, 0);
string_view query_str = ArgS(args, 1);
CmdArgParser parser{args};
string_view index_name = parser.Next();
string_view query_str = parser.Next();

auto params = ParseSearchParamsOrReply(args.subspan(2), builder);
auto params = ParseSearchParamsOrReply(&parser, builder);
if (!params.has_value())
return;

Expand Down Expand Up @@ -749,77 +750,129 @@ void SearchFamily::FtSearch(CmdArgList args, Transaction* tx, SinkReplyBuilder*
}

if (auto agg = search_algo.HasAggregation(); agg)
ReplySorted(std::move(*agg), *params, absl::MakeSpan(docs), builder);
ReplySorted(*agg, *params, absl::MakeSpan(docs), builder);
else
ReplyWithResults(*params, absl::MakeSpan(docs), builder);
}

void SearchFamily::FtProfile(CmdArgList args, Transaction* tx, SinkReplyBuilder* builder) {
string_view index_name = ArgS(args, 0);
string_view query_str = ArgS(args, 3);
CmdArgParser parser{args};

string_view index_name = parser.Next();

optional<SearchParams> params = ParseSearchParamsOrReply(args.subspan(4), builder);
if (!parser.Check("SEARCH") && !parser.Check("AGGREGATE")) {
return builder->SendError("no `SEARCH` or `AGGREGATE` provided");
}

parser.Check("LIMITED"); // TODO: Implement limited profiling
parser.ExpectTag("QUERY");

string_view query_str = parser.Next();

optional<SearchParams> params = ParseSearchParamsOrReply(&parser, builder);
if (!params.has_value())
return;

search::SearchAlgorithm search_algo;
search::SortOption* sort_opt = params->sort_option.has_value() ? &*params->sort_option : nullptr;
if (!search_algo.Init(query_str, &params->query_params, sort_opt))
return builder->SendError("Query syntax error");
return builder->SendError("query syntax error");

search_algo.EnableProfiling();

absl::Time start = absl::Now();
atomic_uint total_docs = 0;
atomic_uint total_serialized = 0;
const size_t shards_count = shard_set->size();

vector<pair<search::AlgorithmProfile, absl::Duration>> results(shard_set->size());
// Because our coordinator thread may not have a shard, we can't check ahead if the index exists.
std::atomic<bool> index_not_found{false};
std::vector<SearchResult> search_results(shards_count);
std::vector<absl::Duration> profile_results(shards_count);

tx->ScheduleSingleHop([&](Transaction* t, EngineShard* es) {
auto* index = es->search_indices()->GetIndex(index_name);
if (!index)
if (!index) {
index_not_found.store(true, memory_order_relaxed);
return OpStatus::OK;
}

auto shard_start = absl::Now();
auto res = index->Search(t->GetOpArgs(es), *params, &search_algo);

total_docs.fetch_add(res.total_hits);
total_serialized.fetch_add(res.docs.size());
const ShardId shard_id = es->shard_id();

DCHECK(res.profile);
results[es->shard_id()] = {std::move(*res.profile), absl::Now() - shard_start};
auto shard_start = absl::Now();
search_results[shard_id] = index->Search(t->GetOpArgs(es), *params, &search_algo);
profile_results[shard_id] = {absl::Now() - shard_start};

return OpStatus::OK;
});

if (index_not_found.load())
return builder->SendError(std::string{index_name} + ": no such index");

auto took = absl::Now() - start;

bool result_is_empty = false;
size_t total_docs = 0;
size_t total_serialized = 0;
for (const auto& result : search_results) {
if (!result.error) {
total_docs += result.total_hits;
total_serialized += result.docs.size();
} else {
result_is_empty = true;
}
}

auto* rb = static_cast<RedisReplyBuilder*>(builder);
rb->StartArray(results.size() + 1);
// First element -> Result of the search command
// Second element -> Profile information
rb->StartArray(2);

// Result of the search command
if (!result_is_empty) {
auto agg = search_algo.HasAggregation();
if (agg) {
ReplySorted(*agg, *params, absl::MakeSpan(search_results), builder);
} else {
ReplyWithResults(*params, absl::MakeSpan(search_results), builder);
}
} else {
rb->StartArray(1);
rb->SendLong(0);
}

// Profile information
rb->StartArray(shards_count + 1);

// General stats
rb->StartCollection(3, RedisReplyBuilder::MAP);
rb->SendBulkString("took");
rb->SendLong(absl::ToInt64Microseconds(took));
rb->SendBulkString("hits");
rb->SendLong(total_docs);
rb->SendLong(static_cast<long>(total_docs));
rb->SendBulkString("serialized");
rb->SendLong(total_serialized);
rb->SendLong(static_cast<long>(total_serialized));

// Per-shard stats
for (const auto& [profile, shard_took] : results) {
for (size_t shard_id = 0; shard_id < shards_count; shard_id++) {
rb->StartCollection(2, RedisReplyBuilder::MAP);
rb->SendBulkString("took");
rb->SendLong(absl::ToInt64Microseconds(shard_took));
rb->SendLong(absl::ToInt64Microseconds(profile_results[shard_id]));
rb->SendBulkString("tree");

for (size_t i = 0; i < profile.events.size(); i++) {
const auto& event = profile.events[i];
const auto& search_result = search_results[shard_id];
if (search_result.error || !search_result.profile || search_result.profile->events.empty()) {
rb->SendEmptyArray();
continue;
}

const auto& events = search_result.profile->events;
for (size_t i = 0; i < events.size(); i++) {
const auto& event = events[i];

size_t children = 0;
for (size_t j = i + 1; j < profile.events.size(); j++) {
if (profile.events[j].depth == event.depth)
for (size_t j = i + 1; j < events.size(); j++) {
if (events[j].depth == event.depth)
break;
if (profile.events[j].depth == event.depth + 1)
if (events[j].depth == event.depth + 1)
children++;
}

Expand Down
62 changes: 59 additions & 3 deletions src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,27 @@ class SearchFamilyTest : public BaseFamilyTest {

const auto kNoResults = IntArg(0); // tests auto destruct single element arrays

/* Asserts that response is an array of exactly two arrays. Used to test FT.PROFILE response.
   Returns an AssertionFailure describing the first violated expectation, or AssertionSuccess
   when the response has the expected shape. */
::testing::AssertionResult AssertArrayOfTwoArrays(const RespExpr& resp) {
  // GetVec() is only meaningful for array replies. Reject anything else (e.g. an error reply)
  // up front so the test reports a readable failure instead of misreading the payload.
  if (resp.type != RespExpr::ARRAY) {
    return ::testing::AssertionFailure()
           << "Expected response to be an array, but was " << resp.type;
  }

  const auto& vec = resp.GetVec();
  if (vec.size() != 2) {
    return ::testing::AssertionFailure()
           << "Expected response array length to be 2, but was " << vec.size();
  }

  // Both elements must themselves be arrays; report the first one that is not.
  for (size_t i = 0; i < vec.size(); ++i) {
    if (vec[i].type != RespExpr::ARRAY) {
      return ::testing::AssertionFailure()
             << "Expected resp[" << i << "] to be an array, but was " << vec[i].type;
    }
  }
  return ::testing::AssertionSuccess();
}

// Fatal-assertion wrapper around AssertArrayOfTwoArrays: on failure the predicate's message is
// reported and the current test function is aborted (ASSERT_* semantics).
#define ASSERT_ARRAY_OF_TWO_ARRAYS(resp) ASSERT_PRED1(AssertArrayOfTwoArrays, resp)

MATCHER_P2(DocIds, total, arg_ids, "") {
if (arg_ids.empty()) {
if (auto res = arg.GetInt(); !res || *res != 0) {
Expand Down Expand Up @@ -790,20 +811,55 @@ TEST_F(SearchFamilyTest, FtProfile) {
Run({"ft.create", "i1", "schema", "name", "text"});

auto resp = Run({"ft.profile", "i1", "search", "query", "(a | b) c d"});
ASSERT_ARRAY_OF_TWO_ARRAYS(resp);

const auto& top_level = resp.GetVec();
EXPECT_EQ(top_level.size(), shard_set->size() + 1);
EXPECT_THAT(top_level[0], IsMapWithSize());

const auto& profile_result = top_level[1].GetVec();
EXPECT_EQ(profile_result.size(), shard_set->size() + 1);

EXPECT_THAT(top_level[0].GetVec(), ElementsAre("took", _, "hits", _, "serialized", _));
EXPECT_THAT(profile_result[0].GetVec(), ElementsAre("took", _, "hits", _, "serialized", _));

for (size_t sid = 0; sid < shard_set->size(); sid++) {
const auto& shard_resp = top_level[sid + 1].GetVec();
const auto& shard_resp = profile_result[sid + 1].GetVec();
EXPECT_THAT(shard_resp, ElementsAre("took", _, "tree", _));

const auto& tree = shard_resp[3].GetVec();
EXPECT_THAT(tree[0].GetString(), HasSubstr("Logical{n=3,o=and}"sv));
EXPECT_EQ(tree[1].GetVec().size(), 3);
}

// Test LIMITED throws no errors
resp = Run({"ft.profile", "i1", "search", "limited", "query", "(a | b) c d"});
ASSERT_ARRAY_OF_TWO_ARRAYS(resp);
}

TEST_F(SearchFamilyTest, FtProfileInvalidQuery) {
  // Fixture: one JSON document and a JSON index exposing `$.id` as the tag field `id`.
  Run({"json.set", "j1", ".", R"({"id":"1"})"});
  Run({"ft.create", "i1", "on", "json", "schema", "$.id", "as", "id", "tag"});

  {
    // A range-style query against the tag field must still yield the two-part profile reply.
    // NOTE(review): presumably this is the query shape behind the crash this test guards
    // against -- confirm against the linked issue.
    const auto range_reply = Run({"ft.profile", "i1", "search", "query", "@id:[1 1]"});
    ASSERT_ARRAY_OF_TWO_ARRAYS(range_reply);

    const auto& search_part = range_reply.GetVec().front();
    EXPECT_THAT(search_part, IsMapWithSize());
  }

  {
    // A query that cannot be parsed at all is rejected with an error reply.
    const auto malformed_reply = Run({"ft.profile", "i1", "search", "query", "@{invalid13289}"});
    EXPECT_THAT(malformed_reply, ErrArg("query syntax error"));
  }
}

// Verifies the error replies FT.PROFILE produces for malformed invocations.
TEST_F(SearchFamilyTest, FtProfileErrorReply) {
  Run({"ft.create", "i1", "schema", "name", "text"});

  // The token after the index name is neither SEARCH nor AGGREGATE.
  auto resp = Run({"ft.profile", "i1", "not_search", "query", "(a | b) c d"});
  EXPECT_THAT(resp, ErrArg("no `SEARCH` or `AGGREGATE` provided"));

  // The QUERY tag is replaced by an unexpected token.
  resp = Run({"ft.profile", "i1", "search", "not_query", "(a | b) c d"});
  EXPECT_THAT(resp, ErrArg("syntax error"));

  // The referenced index was never created.
  resp = Run({"ft.profile", "non_existent_key", "search", "query", "(a | b) c d"});
  EXPECT_THAT(resp, ErrArg("non_existent_key: no such index"));
}

TEST_F(SearchFamilyTest, SimpleExpiry) {
Expand Down

0 comments on commit 7ac8535

Please sign in to comment.