Skip to content

Commit

Permalink
fix(search_family): Remove the output of extra fields in the FT.AGGRE…
Browse files Browse the repository at this point in the history
…GATE command (#4231)

* fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command

fixes #4230

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
  • Loading branch information
BagritsevichStepan authored Dec 11, 2024
1 parent 1e3d9de commit 76f79f0
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 63 deletions.
47 changes: 33 additions & 14 deletions src/server/search/aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ namespace dfly::aggregate {
namespace {

struct GroupStep {
PipelineResult operator()(std::vector<DocValues> values) {
PipelineResult operator()(PipelineResult result) {
// Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : values) {
for (auto& value : result.values) {
groups[Extract(value)].push_back(std::move(value));
}

Expand All @@ -28,7 +28,18 @@ struct GroupStep {
}
out.push_back(std::move(doc));
}
return out;

absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());

for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}

return {std::move(out), std::move(fields_to_print)};
}

absl::FixedArray<Value> Extract(const DocValues& dv) {
Expand Down Expand Up @@ -104,34 +115,42 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
}

PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;

// Order documents by `field`; documents lacking the field come before all documents that
// have it. std::sort requires a strict weak ordering, so two documents that both lack the
// field must compare as equivalent rather than each being "less" than the other.
// `field` is captured by reference: it is owned by the enclosing step lambda, which outlives
// this call, and std::sort is free to copy the comparator.
std::sort(values.begin(), values.end(), [&field](const DocValues& l, const DocValues& r) {
  const auto it1 = l.find(field);
  const auto it2 = r.find(field);
  if (it2 == r.end())
    return false;  // nothing is ordered before a document without the field
  return it1 == l.end() || it1->second < it2->second;
});
if (descending)

if (descending) {
std::reverse(values.begin(), values.end());
return values;
}

result.fields_to_print.insert(field);
return result;
};
}

// Builds a pipeline step that skips the first `offset` documents and then keeps at most `num`
// of the remaining ones. Both bounds are clamped to the number of documents available.
PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](std::vector<DocValues> values) -> PipelineResult {
return [offset, num](PipelineResult result) {
auto& values = result.values;
// Drop the leading `offset` documents (or all of them if there are fewer).
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
// Truncate what is left to at most `num` documents.
values.resize(std::min(num, values.size()));
return values;
return result;
};
}

PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
// Runs `steps` in order over `values`. `fields_to_print` seeds the set of field names carried
// in the result; each step receives that set along with the documents and returns the
// (possibly updated) pair.
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
// Initial pipeline state: the input documents plus a copy of the requested field names.
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
for (auto& step : steps) {
auto result = step(std::move(values));
if (!result.has_value())
return result;
values = std::move(result.value());
// Hand the current state to the step by move and adopt whatever it returns.
PipelineResult step_result = step(std::move(result));
result = std::move(step_result);
}
return values;
return result;
}

} // namespace dfly::aggregate
17 changes: 13 additions & 4 deletions src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/types/span.h>

#include <string>
Expand All @@ -19,10 +20,16 @@ namespace dfly::aggregate {
using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline

// TODO: Replace DocValues with compact linear search map instead of hash map
// State handed from one pipeline step to the next: the documents themselves and the names of
// the fields that should be printed from them.
struct PipelineResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
std::vector<DocValues> values;

using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>; // Group, Sort, etc.
// Fields from values to be printed
absl::flat_hash_set<std::string> fields_to_print;
};

using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.

// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility!
Expand Down Expand Up @@ -82,6 +89,8 @@ PipelineStep MakeSortStep(std::string_view field, bool descending = false);
PipelineStep MakeLimitStep(size_t offset, size_t num);

// Process values with given steps
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);

} // namespace dfly::aggregate
52 changes: 24 additions & 28 deletions src/server/search/aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ TEST(AggregatorTest, Sort) {
};
PipelineStep steps[] = {MakeSortStep("a", false)};

auto result = Process(values, steps);
auto result = Process(values, {"a"}, steps);

EXPECT_TRUE(result);
EXPECT_EQ(result->at(0)["a"], Value(0.5));
EXPECT_EQ(result->at(1)["a"], Value(1.0));
EXPECT_EQ(result->at(2)["a"], Value(1.5));
EXPECT_EQ(result.values[0]["a"], Value(0.5));
EXPECT_EQ(result.values[1]["a"], Value(1.0));
EXPECT_EQ(result.values[2]["a"], Value(1.5));
}

TEST(AggregatorTest, Limit) {
Expand All @@ -35,12 +34,11 @@ TEST(AggregatorTest, Limit) {
};
PipelineStep steps[] = {MakeLimitStep(1, 2)};

auto result = Process(values, steps);
auto result = Process(values, {"i"}, steps);

EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
EXPECT_EQ(result->at(0)["i"], Value(2.0));
EXPECT_EQ(result->at(1)["i"], Value(3.0));
EXPECT_EQ(result.values.size(), 2);
EXPECT_EQ(result.values[0]["i"], Value(2.0));
EXPECT_EQ(result.values[1]["i"], Value(3.0));
}

TEST(AggregatorTest, SimpleGroup) {
Expand All @@ -54,12 +52,11 @@ TEST(AggregatorTest, SimpleGroup) {
std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})};

auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
auto result = Process(values, {"i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);

EXPECT_EQ(result->at(0).size(), 1);
std::set<Value> groups{result->at(0)["tag"], result->at(1)["tag"]};
EXPECT_EQ(result.values[0].size(), 1);
std::set<Value> groups{result.values[0]["tag"], result.values[1]["tag"]};
std::set<Value> expected{"even", "odd"};
EXPECT_EQ(groups, expected);
}
Expand All @@ -83,25 +80,24 @@ TEST(AggregatorTest, GroupWithReduce) {
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};

auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);

// Reorder even first
if (result->at(0).at("tag") == Value("odd"))
std::swap(result->at(0), result->at(1));
if (result.values[0].at("tag") == Value("odd"))
std::swap(result.values[0], result.values[1]);

// Even
EXPECT_EQ(result->at(0).at("count"), Value{(double)5});
EXPECT_EQ(result->at(0).at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result->at(0).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(0).at("distinct-null"), Value{(double)1});
EXPECT_EQ(result.values[0].at("count"), Value{(double)5});
EXPECT_EQ(result.values[0].at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result.values[0].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[0].at("distinct-null"), Value{(double)1});

// Odd
EXPECT_EQ(result->at(1).at("count"), Value{(double)5});
EXPECT_EQ(result->at(1).at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result->at(1).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(1).at("distinct-null"), Value{(double)1});
EXPECT_EQ(result.values[1].at("count"), Value{(double)5});
EXPECT_EQ(result.values[1].at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result.values[1].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[1].at("distinct-null"), Value{(double)1});
}

} // namespace dfly::aggregate
30 changes: 21 additions & 9 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -981,22 +981,34 @@ void SearchFamily::FtAggregate(CmdArgList args, const CommandContext& cmd_cntx)
make_move_iterator(sub_results.end()));
}

auto agg_results = aggregate::Process(std::move(values), params->steps);
if (!agg_results.has_value())
return builder->SendError(agg_results.error());
std::vector<std::string_view> load_fields;
if (params->load_fields) {
load_fields.reserve(params->load_fields->size());
for (const auto& field : params->load_fields.value()) {
load_fields.push_back(field.GetShortName());
}
}

auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps);

size_t result_size = agg_results->size();
auto* rb = static_cast<RedisReplyBuilder*>(cmd_cntx.rb);
auto sortable_value_sender = SortableValueSender(rb);

const size_t result_size = agg_results.values.size();
rb->StartArray(result_size + 1);
rb->SendLong(result_size);

for (const auto& result : agg_results.value()) {
rb->StartArray(result.size() * 2);
for (const auto& [k, v] : result) {
rb->SendBulkString(k);
std::visit(sortable_value_sender, v);
const size_t field_count = agg_results.fields_to_print.size();
for (const auto& value : agg_results.values) {
rb->StartArray(field_count * 2);
for (const auto& field : agg_results.fields_to_print) {
rb->SendBulkString(field);

if (auto it = value.find(field); it != value.end()) {
std::visit(sortable_value_sender, it->second);
} else {
rb->SendNull();
}
}
}
}
Expand Down
41 changes: 33 additions & 8 deletions src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,15 +962,12 @@ TEST_F(SearchFamilyTest, AggregateGroupBy) {
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"),
IsMap("foo_total", "50", "word", "item1")));

/*
Temporary not supported
resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word",
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second
key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""),
IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"})));
*/
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""),
IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""),
IsMap("foo_total", "10", "word", "item1", "text", "\"first key\"")));
}

TEST_F(SearchFamilyTest, JsonAggregateGroupBy) {
Expand Down Expand Up @@ -1632,4 +1629,32 @@ TEST_F(SearchFamilyTest, SearchLoadReturnHash) {
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one")));
}

// Test that FT.AGGREGATE prints only needed fields: the expected rows below contain only the
// fields named by LOAD, SORTBY, GROUPBY or a reducer alias, never the full document.
TEST_F(SearchFamilyTest, AggregateResultFields) {
// Three JSON documents, each with string fields a, b and c.
Run({"JSON.SET", "j1", ".", R"({"a":"1","b":"2","c":"3"})"});
Run({"JSON.SET", "j2", ".", R"({"a":"4","b":"5","c":"6"})"});
Run({"JSON.SET", "j3", ".", R"({"a":"7","b":"8","c":"9"})"});

// Index all three fields as TEXT; only `a` is declared SORTABLE.
auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.a", "AS", "a", "TEXT",
"SORTABLE", "$.b", "AS", "b", "TEXT", "$.c", "AS", "c", "TEXT"});
EXPECT_EQ(resp, "OK");

// No LOAD and no pipeline steps: every returned row is expected to be an empty map.
resp = Run({"FT.AGGREGATE", "index", "*"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap(), IsMap(), IsMap()));

// SORTBY alone: only the sort field `a` is expected in each row.
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("a", "1"), IsMap("a", "4"), IsMap("a", "7")));

// LOAD @b plus SORTBY a: both the loaded field and the sort field are expected.
// Note that the expected `b` values carry JSON quotes ("\"2\"") while the `a` values do not.
resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"});
EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"), IsMap("b", "\"5\"", "a", "4"),
IsMap("b", "\"8\"", "a", "7")));

// GROUPBY @b @a with a COUNT reducer: the group-by fields and the reducer alias `count` are
// expected; each (b, a) pair is unique here, so every count is 1.
resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a", "GROUPBY", "2", "@b", "@a",
"REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"8\"", "a", "7", "count", "1"),
IsMap("b", "\"2\"", "a", "1", "count", "1"),
IsMap("b", "\"5\"", "a", "4", "count", "1")));
}

} // namespace dfly

0 comments on commit 76f79f0

Please sign in to comment.