From 62286b4d101e252b98bbad7865d3c944560cadb3 Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Thu, 12 Nov 2020 18:03:58 +0800 Subject: [PATCH] cherry pick #20977 to release-4.0 Signed-off-by: ti-srebot --- executor/join_test.go | 46 ++++++++++ planner/core/expression_rewriter.go | 37 +++++--- planner/core/logical_plan_builder.go | 124 +++++++++++++++++++-------- planner/core/logical_plans.go | 2 +- 4 files changed, 163 insertions(+), 46 deletions(-) diff --git a/executor/join_test.go b/executor/join_test.go index 949b908b8d718..50311f4adc840 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -584,6 +584,52 @@ func (s *testSuiteJoin1) TestUsing(c *C) { tk.MustExec("create table tt(b bigint, a int)") // Check whether this sql can execute successfully. tk.MustExec("select * from t join tt using(a)") + + tk.MustExec("drop table if exists t, s") + tk.MustExec("create table t(a int, b int)") + tk.MustExec("create table s(b int, a int)") + tk.MustExec("insert into t values(1,1), (2,2), (3,3), (null,null)") + tk.MustExec("insert into s values(1,1), (3,3), (null,null)") + + // For issue 20477 + tk.MustQuery("select t.*, s.* from t join s using(a)").Check(testkit.Rows("1 1 1 1", "3 3 3 3")) + tk.MustQuery("select s.a from t join s using(a)").Check(testkit.Rows("1", "3")) + tk.MustQuery("select s.a from t join s using(a) where s.a > 1").Check(testkit.Rows("3")) + tk.MustQuery("select s.a from t join s using(a) order by s.a").Check(testkit.Rows("1", "3")) + tk.MustQuery("select s.a from t join s using(a) where s.a > 1 order by s.a").Check(testkit.Rows("3")) + tk.MustQuery("select s.a from t join s using(a) where s.a > 1 order by s.a limit 2").Check(testkit.Rows("3")) + + // For issue 20441 + tk.MustExec(`DROP TABLE if exists t1, t2, t3`) + tk.MustExec(`create table t1 (i int)`) + tk.MustExec(`create table t2 (i int)`) + tk.MustExec(`create table t3 (i int)`) + tk.MustExec(`select * from t1,t2 natural left join t3 order by t1.i,t2.i,t3.i`) + 
tk.MustExec(`select t1.i,t2.i,t3.i from t2 natural left join t3,t1 order by t1.i,t2.i,t3.i`) + tk.MustExec(`select * from t1,t2 natural right join t3 order by t1.i,t2.i,t3.i`) + tk.MustExec(`select t1.i,t2.i,t3.i from t2 natural right join t3,t1 order by t1.i,t2.i,t3.i`) + + // For issue 15844 + tk.MustExec(`DROP TABLE if exists t0, t1`) + tk.MustExec(`CREATE TABLE t0(c0 INT)`) + tk.MustExec(`CREATE TABLE t1(c0 INT)`) + tk.MustExec(`SELECT t0.c0 FROM t0 NATURAL RIGHT JOIN t1 WHERE t1.c0`) + + // For issue 20958 + tk.MustExec(`DROP TABLE if exists t1, t2`) + tk.MustExec(`create table t1(id int, name varchar(20));`) + tk.MustExec(`create table t2(id int, address varchar(30));`) + tk.MustExec(`insert into t1 values(1,'gangshen');`) + tk.MustExec(`insert into t2 values(1,'HangZhou');`) + tk.MustQuery(`select t2.* from t1 inner join t2 using (id) limit 1;`).Check(testkit.Rows("1 HangZhou")) + tk.MustQuery(`select t2.* from t1 inner join t2 on t1.id = t2.id limit 1;`).Check(testkit.Rows("1 HangZhou")) + + // For issue 20476 + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(a int)") + tk.MustExec("insert into t1 (a) values(1)") + tk.MustQuery("select t1.*, t2.* from t1 join t1 t2 using(a)").Check(testkit.Rows("1 1")) + tk.MustQuery("select * from t1 join t1 t2 using(a)").Check(testkit.Rows("1")) } func (s *testSuiteJoin1) TestNaturalJoin(c *C) { diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 834e4c01c2c97..830d6b321e67b 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -1624,27 +1624,42 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) { return } } - if join, ok := er.p.(*LogicalJoin); ok && join.redundantSchema != nil { - idx, err := expression.FindFieldName(join.redundantNames, v) - if err != nil { - er.err = err - return - } - if idx >= 0 { - er.ctxStackAppend(join.redundantSchema.Columns[idx], join.redundantNames[idx]) - return - } - } if _, ok := 
er.p.(*LogicalUnionAll); ok && v.Table.O != "" { er.err = ErrTablenameNotAllowedHere.GenWithStackByArgs(v.Table.O, "SELECT", clauseMsg[er.b.curClause]) return } + col, name, err := findFieldNameFromNaturalUsingJoin(er.p, v) + if err != nil { + er.err = err + return + } else if col != nil { + er.ctxStackAppend(col, name) + return + } if er.b.curClause == globalOrderByClause { er.b.curClause = orderByClause } er.err = ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[er.b.curClause]) } +func findFieldNameFromNaturalUsingJoin(p LogicalPlan, v *ast.ColumnName) (col *expression.Column, name *types.FieldName, err error) { + switch x := p.(type) { + case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow: + return findFieldNameFromNaturalUsingJoin(p.Children()[0], v) + case *LogicalJoin: + if x.redundantSchema != nil { + idx, err := expression.FindFieldName(x.redundantNames, v) + if err != nil { + return nil, nil, err + } + if idx >= 0 { + return x.redundantSchema.Columns[idx], x.redundantNames[idx], nil + } + } + } + return nil, nil, nil +} + func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) { stkLen := len(er.ctxStack) name := er.ctxNameStk[stkLen-1] diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 0bef4438aedbe..5d5e11982b517 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -812,8 +812,19 @@ func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan p.SetSchema(expression.NewSchema(schemaCols...)) p.names = names - p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rColumns[:commonLen]...)) - p.redundantNames = append(p.redundantNames.Shallow(), rNames[:commonLen]...) 
+ if joinTp == ast.RightJoin { + leftPlan, rightPlan = rightPlan, leftPlan + } + // We record the full `rightPlan.Schema` as `redundantSchema` in order to + // record the redundant column in `rightPlan` and the output columns order + // of the `rightPlan`. + // For SQL like `select t1.*, t2.* from t1 left join t2 using(a)`, we can + // retrieve the column order of `t2.*` from the `redundantSchema`. + p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rightPlan.Schema().Clone().Columns...)) + p.redundantNames = append(p.redundantNames.Shallow(), rightPlan.OutputNames().Shallow()...) + if joinTp == ast.RightJoin || joinTp == ast.LeftJoin { + resetNotNullFlag(p.redundantSchema, 0, p.redundantSchema.Len()) + } p.OtherConditions = append(conds, p.OtherConditions...) return nil @@ -954,12 +965,8 @@ func (b *PlanBuilder) buildProjectionField(ctx context.Context, p LogicalPlan, f idx := p.Schema().ColumnIndex(col) var name *types.FieldName // The column maybe the one from join's redundant part. - // TODO: Fully support USING/NATURAL JOIN, refactor here. if idx == -1 { - if join, ok := p.(*LogicalJoin); ok { - idx = join.redundantSchema.ColumnIndex(col) - name = join.redundantNames[idx] - } + name = findColFromNaturalUsingJoin(p, col) } else { name = p.OutputNames()[idx] } @@ -991,6 +998,25 @@ func (b *PlanBuilder) buildProjectionField(ctx context.Context, p LogicalPlan, f return newCol, name, nil } +// findColFromNaturalUsingJoin is used to recursively find the column from the +// underlying natural-using-join. +// e.g. For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0`, the +// plan will be `join->selection->projection`. The schema of the `selection` +// will be `[t1.a]`, thus we need to recursively retrieve the `t2.a` from the +// underlying join. 
+func findColFromNaturalUsingJoin(p LogicalPlan, col *expression.Column) (name *types.FieldName) { + switch x := p.(type) { + case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow: + return findColFromNaturalUsingJoin(p.Children()[0], col) + case *LogicalJoin: + if x.redundantSchema != nil { + idx := x.redundantSchema.ColumnIndex(col) + return x.redundantNames[idx] + } + } + return nil +} + // buildProjection returns a Projection plan and non-aux columns length. func (b *PlanBuilder) buildProjection(ctx context.Context, p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, considerWindow bool, expandGenerateColumn bool) (LogicalPlan, int, error) { b.optFlag |= flagEliminateProjection @@ -1504,14 +1530,30 @@ func (a *havingWindowAndOrderbyExprResolver) resolveFromPlan(v *ast.ColumnNameEx if err != nil { return -1, err } + schemaCols, outputNames := p.Schema().Columns, p.OutputNames() if idx < 0 { - return -1, nil + // For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0 + // order by t2.a`, the query plan will be `join->selection->sort`. The + // schema of selection will be `[t1.a]`, thus we need to recursively + // retrieve the `t2.a` from the underlying join. 
+ switch x := p.(type) { + case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow: + return a.resolveFromPlan(v, p.Children()[0]) + case *LogicalJoin: + if len(x.redundantNames) != 0 { + idx, err = expression.FindFieldName(x.redundantNames, v.Name) + schemaCols, outputNames = x.redundantSchema.Columns, x.redundantNames + } + } + if err != nil || idx < 0 { + return -1, err + } } - col := p.Schema().Columns[idx] + col := schemaCols[idx] if col.IsHidden { return -1, ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[a.curClause]) } - name := p.OutputNames()[idx] + name := outputNames[idx] newColName := &ast.ColumnName{ Schema: name.DBName, Table: name.TblName, @@ -2259,6 +2301,7 @@ func (b *PlanBuilder) resolveGbyExprs(ctx context.Context, p LogicalPlan, gby *a } func (b *PlanBuilder) unfoldWildStar(p LogicalPlan, selectFields []*ast.SelectField) (resultList []*ast.SelectField, err error) { + join, isJoin := p.(*LogicalJoin) for i, field := range selectFields { if field.WildCard == nil { resultList = append(resultList, field) @@ -2267,37 +2310,50 @@ func (b *PlanBuilder) unfoldWildStar(p LogicalPlan, selectFields []*ast.SelectFi if field.WildCard.Table.L == "" && i > 0 { return nil, ErrInvalidWildCard } - dbName := field.WildCard.Schema - tblName := field.WildCard.Table - findTblNameInSchema := false - for i, name := range p.OutputNames() { - col := p.Schema().Columns[i] - if col.IsHidden { - continue - } - if (dbName.L == "" || dbName.L == name.DBName.L) && - (tblName.L == "" || tblName.L == name.TblName.L) && - col.ID != model.ExtraHandleID { - findTblNameInSchema = true - colName := &ast.ColumnNameExpr{ - Name: &ast.ColumnName{ - Schema: name.DBName, - Table: name.TblName, - Name: name.ColName, - }} - colName.SetType(col.GetType()) - field := &ast.SelectField{Expr: colName} - field.SetText(name.ColName.O) - resultList = append(resultList, field) + list := unfoldWildStar(field, p.OutputNames(), p.Schema().Columns) + // For sql like 
`select t1.*, t2.* from t1 join t2 using(a)`, we should + // not coalesce the `t2.a` in the output result. Thus we need to unfold + // the wildstar from the underlying join.redundantSchema. + if isJoin && join.redundantSchema != nil && field.WildCard.Table.L != "" { + redundantList := unfoldWildStar(field, join.redundantNames, join.redundantSchema.Columns) + if len(redundantList) > len(list) { + list = redundantList } } - if !findTblNameInSchema { - return nil, ErrBadTable.GenWithStackByArgs(tblName) + if len(list) == 0 { + return nil, ErrBadTable.GenWithStackByArgs(field.WildCard.Table) } + resultList = append(resultList, list...) } return resultList, nil } +func unfoldWildStar(field *ast.SelectField, outputName types.NameSlice, column []*expression.Column) (resultList []*ast.SelectField) { + dbName := field.WildCard.Schema + tblName := field.WildCard.Table + for i, name := range outputName { + col := column[i] + if col.IsHidden { + continue + } + if (dbName.L == "" || dbName.L == name.DBName.L) && + (tblName.L == "" || tblName.L == name.TblName.L) && + col.ID != model.ExtraHandleID { + colName := &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Schema: name.DBName, + Table: name.TblName, + Name: name.ColName, + }} + colName.SetType(col.GetType()) + field := &ast.SelectField{Expr: colName} + field.SetText(name.ColName.O) + resultList = append(resultList, field) + } + } + return resultList +} + func (b *PlanBuilder) pushHintWithoutTableWarning(hint *ast.TableOptimizerHint) { var sb strings.Builder ctx := format.NewRestoreCtx(0, &sb) diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index bf1a946b699a9..4cad55402ce8c 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -144,7 +144,7 @@ type LogicalJoin struct { DefaultValues []types.Datum // redundantSchema contains columns which are eliminated in join. - // For select * from a join b using (c); a.c will in output schema, and b.c will in redundantSchema. 
+ // For select * from a join b using (c); a.c will be in the output schema, and b.c will only be in redundantSchema. redundantSchema *expression.Schema redundantNames types.NameSlice