Skip to content

Commit

Permalink
This is an automated cherry-pick of #3822
Browse files Browse the repository at this point in the history
Signed-off-by: ti-chi-bot <ti-community-prow-bot@tidb.io>
  • Loading branch information
XuHuaiyu authored and ti-chi-bot committed Feb 14, 2022
1 parent 8242cf2 commit 3fdd446
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 17 deletions.
98 changes: 86 additions & 12 deletions dbms/src/Debug/dbgFuncCoprocessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ std::unordered_map<String, tipb::ScalarFuncSig> func_name_to_sig({

});

std::unordered_map<String, tipb::ExprType> agg_func_name_to_sig({
{"min", tipb::ExprType::Min},
{"max", tipb::ExprType::Max},
{"count", tipb::ExprType::Count},
{"sum", tipb::ExprType::Sum},
{"first_row", tipb::ExprType::First},
{"uniqRawRes", tipb::ExprType::ApproxCountDistinct},
{"group_concat", tipb::ExprType::GroupConcat},
});

std::pair<String, String> splitQualifiedName(String s)
{
std::pair<String, String> ret;
Expand Down Expand Up @@ -337,12 +347,20 @@ BlockInputStreamPtr executeQuery(Context & context, RegionID region_id, const DA
throw Exception("Meet error while dispatch mpp task: " + call.getResp()->error().msg());
}
tipb::ExchangeReceiver tipb_exchange_receiver;
<<<<<<< HEAD
for (size_t i = 0; i < root_task_ids.size(); i++)
=======
for (const auto root_task_id : root_task_ids)
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
{
mpp::TaskMeta tm;
tm.set_start_ts(properties.start_ts);
tm.set_address(LOCAL_HOST);
<<<<<<< HEAD
tm.set_task_id(root_task_ids[i]);
=======
tm.set_task_id(root_task_id);
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
tm.set_partition_id(-1);
auto * tm_string = tipb_exchange_receiver.add_encoded_task_meta();
tm.AppendToString(tm_string);
Expand Down Expand Up @@ -861,12 +879,18 @@ void collectUsedColumnsFromExpr(const DAGSchema & input, ASTPtr ast, std::unorde
struct MPPCtx
{
Timestamp start_ts;
Int64 partition_num;
Int64 next_task_id;
std::vector<Int64> sender_target_task_ids;
<<<<<<< HEAD
std::vector<Int64> current_task_ids;
std::vector<Int64> partition_keys;
MPPCtx(Timestamp start_ts_, size_t partition_num_) : start_ts(start_ts_), partition_num(partition_num_), next_task_id(1) {}
=======
explicit MPPCtx(Timestamp start_ts_)
: start_ts(start_ts_)
, next_task_id(1)
{}
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
};

using MPPCtxPtr = std::shared_ptr<MPPCtx>;
Expand Down Expand Up @@ -1170,16 +1194,26 @@ struct Aggregation : public Executor
tipb::Expr * arg_expr = agg_func->add_children();
astToPB(input_schema, arg, arg_expr, collator_id, context);
}
auto agg_sig_it = agg_func_name_to_sig.find(func->name);
if (agg_sig_it == agg_func_name_to_sig.end())
throw Exception("Unsupported agg function " + func->name, ErrorCodes::LOGICAL_ERROR);
auto agg_sig = agg_sig_it->second;
agg_func->set_tp(agg_sig);

if (func->name == "count")
if (agg_sig == tipb::ExprType::Count || agg_sig == tipb::ExprType::Sum)
{
<<<<<<< HEAD
agg_func->set_tp(tipb::Count);
auto ft = agg_func->mutable_field_type();
=======
auto * ft = agg_func->mutable_field_type();
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull);
}
else if (func->name == "sum")
else if (agg_sig == tipb::ExprType::Min || agg_sig == tipb::ExprType::Max || agg_sig == tipb::ExprType::First)
{
<<<<<<< HEAD
agg_func->set_tp(tipb::Sum);
auto ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeLongLong);
Expand All @@ -1192,10 +1226,19 @@ struct Aggregation : public Executor
throw Exception("udaf max only accept 1 argument");
auto ft = agg_func->mutable_field_type();
ft->set_tp(agg_func->children(0).field_type().tp());
=======
if (agg_func->children_size() != 1)
throw Exception("udaf " + func->name + " only accept 1 argument");
auto * ft = agg_func->mutable_field_type();
ft->set_tp(agg_func->children(0).field_type().tp());
ft->set_decimal(agg_func->children(0).field_type().decimal());
ft->set_flag(agg_func->children(0).field_type().flag() & (~TiDB::ColumnFlagNotNull));
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
ft->set_collate(collator_id);
}
else if (func->name == "min")
else if (agg_sig == tipb::ExprType::ApproxCountDistinct)
{
<<<<<<< HEAD
agg_func->set_tp(tipb::Min);
if (agg_func->children_size() != 1)
throw Exception("udaf min only accept 1 argument");
Expand All @@ -1211,10 +1254,21 @@ struct Aggregation : public Executor
ft->set_flag(1);
}
// TODO: Other agg func.
else
=======
auto * ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeString);
ft->set_flag(1);
}
else if (agg_sig == tipb::ExprType::GroupConcat)
{
throw Exception("Unsupported agg function " + func->name, ErrorCodes::LOGICAL_ERROR);
auto * ft = agg_func->mutable_field_type();
ft->set_tp(TiDB::TypeString);
}
if (is_final_mode)
agg_func->set_aggfuncmode(tipb::AggFunctionMode::FinalMode);
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
else
agg_func->set_aggfuncmode(tipb::AggFunctionMode::Partial1Mode);
}

for (const auto & child : gby_exprs)
Expand Down Expand Up @@ -1265,8 +1319,11 @@ struct Aggregation : public Executor
// todo support avg
if (has_uniq_raw_res)
throw Exception("uniq raw res not supported in mpp query");
<<<<<<< HEAD
if (gby_exprs.size() == 0)
throw Exception("agg without group by columns not supported in mpp query");
=======
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
std::shared_ptr<Aggregation> partial_agg = std::make_shared<Aggregation>(
executor_index, output_schema_for_partial_agg, has_uniq_raw_res, false, std::move(agg_exprs), std::move(gby_exprs), false);
partial_agg->children.push_back(children[0]);
Expand All @@ -1277,7 +1334,7 @@ struct Aggregation : public Executor
partition_keys.push_back(i + agg_func_num);
}
std::shared_ptr<ExchangeSender> exchange_sender
= std::make_shared<ExchangeSender>(executor_index, output_schema_for_partial_agg, tipb::Hash, partition_keys);
= std::make_shared<ExchangeSender>(executor_index, output_schema_for_partial_agg, partition_keys.empty() ? tipb::PassThrough : tipb::Hash, partition_keys);
exchange_sender->children.push_back(partial_agg);

std::shared_ptr<ExchangeReceiver> exchange_receiver
Expand Down Expand Up @@ -1826,9 +1883,10 @@ ExecutorPtr compileAggregation(ExecutorPtr input, size_t & executor_index, ASTPt
ci.tp = TiDB::TypeLongLong;
ci.flag = TiDB::ColumnFlagUnsigned | TiDB::ColumnFlagNotNull;
}
else if (func->name == "max" || func->name == "min")
else if (func->name == "max" || func->name == "min" || func->name == "first_row")
{
ci = children_ci[0];
ci.flag &= ~TiDB::ColumnFlagNotNull;
}
else if (func->name == UniqRawResName)
{
Expand Down Expand Up @@ -2045,17 +2103,30 @@ QueryFragments mppQueryToQueryFragments(
root_executor->toMPPSubPlan(executor_index, properties, exchange_map);
TableID table_id = findTableIdForQueryFragment(root_executor, exchange_map.empty());
std::vector<Int64> sender_target_task_ids = mpp_ctx->sender_target_task_ids;
std::vector<Int64> current_task_ids = mpp_ctx->current_task_ids;
std::unordered_map<String, std::vector<Int64>> receiver_source_task_ids_map;
size_t current_task_num = properties.mpp_partition_num;
for (auto & exchange : exchange_map)
{
if (exchange.second.second->type == tipb::ExchangeType::PassThrough)
{
current_task_num = 1;
break;
}
}
std::vector<Int64> current_task_ids;
for (size_t i = 0; i < current_task_num; i++)
current_task_ids.push_back(mpp_ctx->next_task_id++);
for (auto & exchange : exchange_map)
{
<<<<<<< HEAD
std::vector<Int64> task_ids;
for (size_t i = 0; i < (size_t)mpp_ctx->partition_num; i++)
task_ids.push_back(mpp_ctx->next_task_id++);
=======
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
mpp_ctx->sender_target_task_ids = current_task_ids;
mpp_ctx->current_task_ids = task_ids;
receiver_source_task_ids_map[exchange.first] = task_ids;
auto sub_fragments = mppQueryToQueryFragments(exchange.second.second, executor_index, properties, false, mpp_ctx);
receiver_source_task_ids_map[exchange.first] = sub_fragments.cbegin()->task_ids;
fragments.insert(fragments.end(), sub_fragments.begin(), sub_fragments.end());
}
fragments.emplace_back(root_executor, table_id, for_root_fragment, std::move(sender_target_task_ids),
Expand All @@ -2071,10 +2142,13 @@ QueryFragments queryPlanToQueryFragments(const DAGProperties & properties, Execu
= std::make_shared<mock::ExchangeSender>(executor_index, root_executor->output_schema, tipb::PassThrough);
root_exchange_sender->children.push_back(root_executor);
root_executor = root_exchange_sender;
MPPCtxPtr mpp_ctx = std::make_shared<MPPCtx>(properties.start_ts, properties.mpp_partition_num);
MPPCtxPtr mpp_ctx = std::make_shared<MPPCtx>(properties.start_ts);
mpp_ctx->sender_target_task_ids.emplace_back(-1);
<<<<<<< HEAD
for (size_t i = 0; i < (size_t)properties.mpp_partition_num; i++)
mpp_ctx->current_task_ids.push_back(mpp_ctx->next_task_id++);
=======
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
return mppQueryToQueryFragments(root_executor, executor_index, properties, true, mpp_ctx);
}
else
Expand Down
50 changes: 45 additions & 5 deletions dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ struct AnalysisResult
Names aggregation_keys;
TiDB::TiDBCollators aggregation_collators;
AggregateDescriptions aggregate_descriptions;
bool is_final_agg;
};

// add timezone cast for timestamp type, this is used to support session level timezone
Expand All @@ -122,6 +123,15 @@ bool addTimeZoneCastAfterTS(
return analyzer.appendTimeZoneCastsAfterTS(chain, is_ts_column);
}

bool isFinalAgg(const tipb::Expr & expr)
{
if (!expr.has_aggfuncmode())
/// set default value to true to make it compatible with old version of TiDB since before this
/// change, all the aggregation in TiFlash is treated as final aggregation
return true;
return expr.aggfuncmode() == tipb::AggFunctionMode::FinalMode || expr.aggfuncmode() == tipb::AggFunctionMode::CompleteMode;
}

AnalysisResult analyzeExpressions(
Context & context,
DAGExpressionAnalyzer & analyzer,
Expand Down Expand Up @@ -152,12 +162,22 @@ AnalysisResult analyzeExpressions(
// There will be either Agg...
if (query_block.aggregation)
{
/// set default value to true to make it compatible with old version of TiDB since before this
/// change, all the aggregation in TiFlash is treated as final aggregation
res.is_final_agg = true;
const auto & aggregation = query_block.aggregation->aggregation();
if (aggregation.agg_func_size() > 0 && !isFinalAgg(aggregation.agg_func(0)))
res.is_final_agg = false;
for (int i = 1; i < aggregation.agg_func_size(); i++)
{
if (res.is_final_agg != isFinalAgg(aggregation.agg_func(i)))
throw TiFlashException("Different aggregation mode detected", Errors::Coprocessor::BadRequest);
}
// todo now we can tell if the aggregation is final stage or partial stage, maybe we can do collation insensitive
// aggregation if the stage is partial
bool group_by_collation_sensitive =
/// collation sensitive group by is slower then normal group by, use normal group by by default
context.getSettingsRef().group_by_collation_sensitive ||
/// in mpp task, here is no way to tell whether this aggregation is first stage aggregation or
/// final stage aggregation, to make sure the result is right, always do collation sensitive aggregation
context.getDAGContext()->isMPPTask();
/// collation sensitive group by is slower than normal group by, use normal group by by default
context.getSettingsRef().group_by_collation_sensitive || context.getDAGContext()->isMPPTask();

analyzer.appendAggregation(
chain,
Expand Down Expand Up @@ -567,8 +587,18 @@ void DAGQueryBlockInterpreter::executeWhere(DAGPipeline & pipeline, const Expres
pipeline.transform([&](auto & stream) { stream = std::make_shared<FilterBlockInputStream>(stream, expr, filter_column); });
}

<<<<<<< HEAD
void DAGQueryBlockInterpreter::executeAggregation(DAGPipeline & pipeline, const ExpressionActionsPtr & expr, Names & key_names,
TiDB::TiDBCollators & collators, AggregateDescriptions & aggregates)
=======
void DAGQueryBlockInterpreter::executeAggregation(
DAGPipeline & pipeline,
const ExpressionActionsPtr & expression_actions_ptr,
Names & key_names,
TiDB::TiDBCollators & collators,
AggregateDescriptions & aggregate_descriptions,
bool is_final_agg)
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
{
pipeline.transform([&](auto & stream) { stream = std::make_shared<ExpressionBlockInputStream>(stream, expr); });

Expand Down Expand Up @@ -609,7 +639,13 @@ void DAGQueryBlockInterpreter::executeAggregation(DAGPipeline & pipeline, const
Aggregator::Params params(header, keys, aggregates, false, settings.max_rows_to_group_by, settings.group_by_overflow_mode,
allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold : SettingUInt64(0),
allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold_bytes : SettingUInt64(0),
<<<<<<< HEAD
settings.max_bytes_before_external_group_by, settings.empty_result_for_aggregation_by_empty_set, context.getTemporaryPath(),
=======
settings.max_bytes_before_external_group_by,
!is_final_agg,
context.getTemporaryPath(),
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
has_collator ? collators : TiDB::dummy_collators);

/// If there are several sources, then we perform parallel aggregation
Expand Down Expand Up @@ -1017,8 +1053,12 @@ void DAGQueryBlockInterpreter::executeImpl(DAGPipeline & pipeline)
if (res.need_aggregate)
{
// execute aggregation
<<<<<<< HEAD
executeAggregation(pipeline, res.before_aggregation, res.aggregation_keys, res.aggregation_collators, res.aggregate_descriptions);
recordProfileStreams(pipeline, query_block.aggregation_name);
=======
executeAggregation(pipeline, res.before_aggregation, res.aggregation_keys, res.aggregation_collators, res.aggregate_descriptions, res.is_final_agg);
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
}
if (res.has_having)
{
Expand Down
10 changes: 10 additions & 0 deletions dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,18 @@ class DAGQueryBlockInterpreter
void executeExpression(DAGPipeline & pipeline, const ExpressionActionsPtr & expressionActionsPtr);
void executeOrder(DAGPipeline & pipeline, std::vector<NameAndTypePair> & order_columns);
void executeLimit(DAGPipeline & pipeline);
<<<<<<< HEAD
void executeAggregation(DAGPipeline & pipeline, const ExpressionActionsPtr & expressionActionsPtr, Names & aggregation_keys,
TiDB::TiDBCollators & collators, AggregateDescriptions & aggregate_descriptions);
=======
void executeAggregation(
DAGPipeline & pipeline,
const ExpressionActionsPtr & expression_actions_ptr,
Names & key_names,
TiDB::TiDBCollators & collators,
AggregateDescriptions & aggregate_descriptions,
bool is_final_agg);
>>>>>>> 5bd08d6040 (set empty_result_for_aggregation_by_empty_set according to AggregateFuncMode (#3822))
void executeProject(DAGPipeline & pipeline, NamesWithAliases & project_cols);

SortDescription getSortDescription(std::vector<NameAndTypePair> & order_columns);
Expand Down
43 changes: 43 additions & 0 deletions tests/delta-merge-test/query/mpp/aggregation_empty_input.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Preparation.
=> DBGInvoke __enable_schema_sync_service('true')

=> DBGInvoke __drop_tidb_table(default, test)
=> drop table if exists default.test

=> DBGInvoke __set_flush_threshold(1000000, 1000000)

# Data.
=> DBGInvoke __mock_tidb_table(default, test, 'col_1 String, col_2 Int64')
=> DBGInvoke __refresh_schemas()
=> DBGInvoke __put_region(4, 0, 100, default, test)
=> DBGInvoke __put_region(5, 100, 200, default, test)
=> DBGInvoke __put_region(6, 200, 300, default, test)

# shuffle agg with empty table
=> DBGInvoke tidb_query('select count(col_1) from default.test', 4,'mpp_query:true,mpp_partition_num:3')
┌─exchange_receiver_0─┐
│ 0 │
└─────────────────────┘

=> DBGInvoke __raft_insert_row(default, test, 4, 50, 'test1', 666)
=> DBGInvoke __raft_insert_row(default, test, 4, 51, 'test2', 666)
=> DBGInvoke __raft_insert_row(default, test, 4, 52, 'test3', 777)
=> DBGInvoke __raft_insert_row(default, test, 4, 53, 'test4', 888)
=> DBGInvoke __raft_insert_row(default, test, 5, 150, 'test1', 666)
=> DBGInvoke __raft_insert_row(default, test, 5, 151, 'test2', 666)
=> DBGInvoke __raft_insert_row(default, test, 5, 152, 'test3', 777)
=> DBGInvoke __raft_insert_row(default, test, 5, 153, 'test4', 888)
=> DBGInvoke __raft_insert_row(default, test, 6, 250, 'test1', 666)
=> DBGInvoke __raft_insert_row(default, test, 6, 251, 'test2', 666)
=> DBGInvoke __raft_insert_row(default, test, 6, 252, 'test3', 777)
=> DBGInvoke __raft_insert_row(default, test, 6, 253, 'test4', 999)

# shuffle agg
=> DBGInvoke tidb_query('select count(col_1), first_row(col_2) from default.test where col_2 = 999', 4,'mpp_query:true,mpp_partition_num:3')
┌─exchange_receiver_0─┬─exchange_receiver_1─┐
│ 1 │ 999 │
└─────────────────────┴─────────────────────┘

# Clean up.
=> DBGInvoke __drop_tidb_table(default, test)
=> drop table if exists default.test
Loading

0 comments on commit 3fdd446

Please sign in to comment.