From fbcbdc08539d21a8f38f0bbe59e397f20a28a80a Mon Sep 17 00:00:00 2001 From: xufei Date: Tue, 15 Oct 2019 18:02:55 +0800 Subject: [PATCH] remove duplicate agg funcs (#283) * 1. remove duplicate agg funcs, 2. for column ref expr, change column_id to column_index since the value stored in column ref expr is not column id * bug fix --- .../Coprocessor/DAGExpressionAnalyzer.cpp | 37 ++++++++++--------- .../Flash/Coprocessor/DAGExpressionAnalyzer.h | 2 +- dbms/src/Flash/Coprocessor/DAGQueryInfo.h | 10 ++--- dbms/src/Flash/Coprocessor/DAGUtils.cpp | 18 ++++----- dbms/src/Flash/Coprocessor/DAGUtils.h | 2 +- dbms/src/Flash/Coprocessor/InterpreterDAG.cpp | 2 +- dbms/src/Storages/MergeTree/KeyCondition.cpp | 2 +- dbms/src/Storages/MergeTree/RPNBuilder.cpp | 13 ++----- dbms/src/Storages/MergeTree/RPNBuilder.h | 4 +- tests/mutable-test/txn_dag/aggregation.test | 6 +++ 10 files changed, 46 insertions(+), 50 deletions(-) diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp index 4850400e44d..1a362d688e4 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.cpp @@ -47,8 +47,8 @@ static String genFuncString(const String & func_name, const Names & argument_nam return ss.str(); } -DAGExpressionAnalyzer::DAGExpressionAnalyzer(const std::vector && source_columns_, const Context & context_) - : source_columns(source_columns_), +DAGExpressionAnalyzer::DAGExpressionAnalyzer(std::vector && source_columns_, const Context & context_) + : source_columns(std::move(source_columns_)), context(context_), after_agg(false), implicit_cast_count(0), @@ -68,7 +68,6 @@ void DAGExpressionAnalyzer::appendAggregation( initChain(chain, getCurrentInputColumns()); ExpressionActionsChain::Step & step = chain.steps.back(); - Names agg_argument_names; for (const tipb::Expr & expr : agg.agg_func()) { const String & agg_func_name = getAggFunctionName(expr); @@ -78,13 +77,24 @@ void DAGExpressionAnalyzer::appendAggregation( for (Int32 i = 0; i < expr.children_size(); i++) { String arg_name = getActions(expr.children(i), step.actions); - agg_argument_names.push_back(arg_name); types[i] = step.actions->getSampleBlock().getByName(arg_name).type; aggregate.argument_names[i] = arg_name; + step.required_output.push_back(arg_name); } - String func_string = genFuncString(agg_func_name, agg_argument_names); + String func_string = genFuncString(agg_func_name, aggregate.argument_names); + bool duplicate = false; + for (const auto & pre_agg : aggregate_descriptions) + { + if (pre_agg.column_name == func_string) + { + aggregated_columns.emplace_back(func_string, pre_agg.function->getReturnType()); + duplicate = true; + break; + } + } + if (duplicate) + continue; aggregate.column_name = func_string; - //todo de-duplicate aggregation column aggregate.parameters = Array(); aggregate.function = AggregateFunctionFactory::instance().get(agg_func_name, types); aggregate_descriptions.push_back(aggregate); @@ -93,8 +103,6 @@ void DAGExpressionAnalyzer::appendAggregation( aggregated_columns.emplace_back(func_string, result_type); } - std::move(agg_argument_names.begin(), agg_argument_names.end(), std::back_inserter(step.required_output)); - for (const tipb::Expr & expr : agg.group_by()) { String name = getActions(expr, step.actions); @@ -286,7 +294,7 @@ void DAGExpressionAnalyzer::appendAggSelect( { initChain(chain, getCurrentInputColumns()); bool need_update_aggregated_columns = false; - NamesAndTypesList updated_aggregated_columns; + std::vector updated_aggregated_columns; ExpressionActionsChain::Step step = chain.steps.back(); bool need_append_timezone_cast = hasMeaningfulTZInfo(rqst); tipb::Expr tz_expr; @@ -344,12 +352,10 @@ void DAGExpressionAnalyzer::appendAggSelect( if (need_update_aggregated_columns) { - auto updated_agg_col_names = updated_aggregated_columns.getNames(); - auto updated_agg_col_types = updated_aggregated_columns.getTypes(); aggregated_columns.clear(); for (size_t i = 0; i < updated_aggregated_columns.size(); i++) { - aggregated_columns.emplace_back(updated_agg_col_names[i], updated_agg_col_types[i]); + aggregated_columns.emplace_back(updated_aggregated_columns[i].name, updated_aggregated_columns[i].type); } } } @@ -471,13 +477,8 @@ String DAGExpressionAnalyzer::getActions(const tipb::Expr & expr, ExpressionActi } else if (isColumnExpr(expr)) { - ColumnID column_id = getColumnID(expr); - if (column_id < 0 || column_id >= (ColumnID)getCurrentInputColumns().size()) - { - throw Exception("column id out of bound", ErrorCodes::COP_BAD_DAG_REQUEST); - } //todo check if the column type need to be cast to field type - return getCurrentInputColumns()[column_id].name; + return getColumnNameForColumnExpr(expr, getCurrentInputColumns()); } else if (isFunctionExpr(expr)) { diff --git a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h index 1b5b65f0ff0..d2a6b5751be 100644 --- a/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h +++ b/dbms/src/Flash/Coprocessor/DAGExpressionAnalyzer.h @@ -37,7 +37,7 @@ class DAGExpressionAnalyzer : private boost::noncopyable Poco::Logger * log; public: - DAGExpressionAnalyzer(const std::vector && source_columns_, const Context & context_); + DAGExpressionAnalyzer(std::vector && source_columns_, const Context & context_); void appendWhere(ExpressionActionsChain & chain, const tipb::Selection & sel, String & filter_column_name); void appendOrderBy(ExpressionActionsChain & chain, const tipb::TopN & topN, Strings & order_column_names); void appendAggregation(ExpressionActionsChain & chain, const tipb::Aggregation & agg, Names & aggregate_keys, diff --git a/dbms/src/Flash/Coprocessor/DAGQueryInfo.h b/dbms/src/Flash/Coprocessor/DAGQueryInfo.h index 20274503782..cb01768d473 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryInfo.h +++ b/dbms/src/Flash/Coprocessor/DAGQueryInfo.h @@ -10,14 +10,10 @@ namespace DB struct DAGQueryInfo { - DAGQueryInfo(const DAGQuerySource & dag_, DAGPreparedSets dag_sets_, std::vector & source_columns_) - : dag(dag_), dag_sets(std::move(dag_sets_)) - { - for (auto & c : source_columns_) - source_columns.emplace_back(c.name, c.type); - }; + DAGQueryInfo(const DAGQuerySource & dag_, DAGPreparedSets dag_sets_, const std::vector & source_columns_) + : dag(dag_), dag_sets(std::move(dag_sets_)), source_columns(source_columns_){}; const DAGQuerySource & dag; DAGPreparedSets dag_sets; - NamesAndTypesList source_columns; + const std::vector & source_columns; }; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.cpp b/dbms/src/Flash/Coprocessor/DAGUtils.cpp index 09fc5d0b87a..0f196df8db1 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.cpp +++ b/dbms/src/Flash/Coprocessor/DAGUtils.cpp @@ -52,7 +52,6 @@ const String & getFunctionName(const tipb::Expr & expr) String exprToString(const tipb::Expr & expr, const std::vector & input_col) { std::stringstream ss; - Int64 column_id = 0; String func_name; Field f; switch (expr.tp()) @@ -94,12 +93,7 @@ String exprToString(const tipb::Expr & expr, const std::vector return std::to_string(TiDB::DatumFlat(t, static_cast(expr.field_type().tp())).field().get()); } case tipb::ExprType::ColumnRef: - column_id = decodeDAGInt64(expr.val()); - if (column_id < 0 || column_id >= (ColumnID)input_col.size()) - { - throw Exception("Column id out of bound", ErrorCodes::COP_BAD_DAG_REQUEST); - } - return input_col[column_id].name; + return getColumnNameForColumnExpr(expr, input_col); case tipb::ExprType::Count: case tipb::ExprType::Sum: case tipb::ExprType::Avg: @@ -247,10 +241,14 @@ Field decodeLiteral(const tipb::Expr & expr) } } -ColumnID getColumnID(const tipb::Expr & expr) +String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector & input_col) { - auto column_id = decodeDAGInt64(expr.val()); - return column_id; + auto column_index = decodeDAGInt64(expr.val()); + if (column_index < 0 || column_index >= (Int64)input_col.size()) + { + throw Exception("Column index out of bound", ErrorCodes::COP_BAD_DAG_REQUEST); + } + return input_col[column_index].name; } bool isInOrGlobalInOperator(const String & name) { return name == "in" || name == "notIn" || name == "globalIn" || name == "globalNotIn"; } diff --git a/dbms/src/Flash/Coprocessor/DAGUtils.h b/dbms/src/Flash/Coprocessor/DAGUtils.h index 709b7602dba..b45c12680c3 100644 --- a/dbms/src/Flash/Coprocessor/DAGUtils.h +++ b/dbms/src/Flash/Coprocessor/DAGUtils.h @@ -22,7 +22,7 @@ bool isAggFunctionExpr(const tipb::Expr & expr); const String & getFunctionName(const tipb::Expr & expr); const String & getAggFunctionName(const tipb::Expr & expr); bool isColumnExpr(const tipb::Expr & expr); -ColumnID getColumnID(const tipb::Expr & expr); +String getColumnNameForColumnExpr(const tipb::Expr & expr, const std::vector & input_col); const String & getTypeName(const tipb::Expr & expr); String exprToString(const tipb::Expr & expr, const std::vector & input_col); bool isInOrGlobalInOperator(const String & name); diff --git a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp index d20016f21ad..f0321b3bd7b 100644 --- a/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp +++ b/dbms/src/Flash/Coprocessor/InterpreterDAG.cpp @@ -250,7 +250,7 @@ void InterpreterDAG::executeTS(const tipb::TableScan & ts, Pipeline & pipeline) SelectQueryInfo query_info; // set query to avoid unexpected NPE query_info.query = dag.getAST(); - query_info.dag_query = std::make_unique(dag, analyzer->getPreparedSets(), source_columns); + query_info.dag_query = std::make_unique(dag, analyzer->getPreparedSets(), analyzer->getCurrentInputColumns()); query_info.mvcc_query_info = std::make_unique(); query_info.mvcc_query_info->resolve_locks = true; query_info.mvcc_query_info->read_tso = settings.read_tso; diff --git a/dbms/src/Storages/MergeTree/KeyCondition.cpp b/dbms/src/Storages/MergeTree/KeyCondition.cpp index b23b9340b35..1e1062df6b9 100644 --- a/dbms/src/Storages/MergeTree/KeyCondition.cpp +++ b/dbms/src/Storages/MergeTree/KeyCondition.cpp @@ -282,7 +282,7 @@ KeyCondition::KeyCondition( if (query_info.fromAST()) { - RPNBuilder rpn_builder(key_expr_, key_columns, all_columns); + RPNBuilder rpn_builder(key_expr_, key_columns, {}); PreparedSets sets(query_info.sets); /** Evaluation of expressions that depend only on constants. diff --git a/dbms/src/Storages/MergeTree/RPNBuilder.cpp b/dbms/src/Storages/MergeTree/RPNBuilder.cpp index 9a2830612b9..52e458dea20 100644 --- a/dbms/src/Storages/MergeTree/RPNBuilder.cpp +++ b/dbms/src/Storages/MergeTree/RPNBuilder.cpp @@ -58,19 +58,14 @@ const String getFuncName(const ASTPtr & node) return ""; } -const String getColumnName(const tipb::Expr & node, const NamesAndTypesList & source_columns) +const String getColumnName(const tipb::Expr & node, const std::vector & source_columns) { - if (node.tp() == tipb::ExprType::ColumnRef) - { - auto col_id = getColumnID(node); - if (col_id < 0 || col_id >= (Int64)source_columns.size()) - return ""; - return source_columns.getNames()[col_id]; - } + if (isColumnExpr(node)) + return getColumnNameForColumnExpr(node, source_columns); return ""; } -const String getColumnName(const ASTPtr & node, const NamesAndTypesList &) { return node->getColumnName(); } +const String getColumnName(const ASTPtr & node, const std::vector &) { return node->getColumnName(); } bool isFuncNode(const ASTPtr & node) { return typeid_cast(node.get()); } diff --git a/dbms/src/Storages/MergeTree/RPNBuilder.h b/dbms/src/Storages/MergeTree/RPNBuilder.h index f9eaf263cf5..5eafac0d704 100644 --- a/dbms/src/Storages/MergeTree/RPNBuilder.h +++ b/dbms/src/Storages/MergeTree/RPNBuilder.h @@ -21,7 +21,7 @@ template class RPNBuilder { public: - RPNBuilder(const ExpressionActionsPtr & key_expr_, ColumnIndices & key_columns_, const NamesAndTypesList & source_columns_) + RPNBuilder(const ExpressionActionsPtr & key_expr_, ColumnIndices & key_columns_, const std::vector & source_columns_) : key_expr(key_expr_), key_columns(key_columns_), source_columns(source_columns_) {} @@ -62,6 +62,6 @@ class RPNBuilder protected: const ExpressionActionsPtr & key_expr; ColumnIndices & key_columns; - const NamesAndTypesList & source_columns; + const std::vector & source_columns; }; } // namespace DB diff --git a/tests/mutable-test/txn_dag/aggregation.test b/tests/mutable-test/txn_dag/aggregation.test index 0f8ec4c30e3..ef1905a5dc4 100644 --- a/tests/mutable-test/txn_dag/aggregation.test +++ b/tests/mutable-test/txn_dag/aggregation.test @@ -21,6 +21,12 @@ │ 1 │ 777 │ └──────────────┴───────┘ +=> DBGInvoke dag('select count(col_1),count(col_1) from default.test group by col_2') +┌─count(col_1)─┬─count(col_1)─┬─col_2─┐ +│ 2 │ 2 │ 666 │ +│ 1 │ 1 │ 777 │ +└──────────────┴──────────────┴───────┘ + # DAG read by explicitly specifying region id, where + group by. => DBGInvoke dag('select count(col_1) from default.test where col_2 = 666 group by col_2', 4) ┌─count(col_1)─┬─col_2─┐