Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor, planner: fix some cases for natural_using_join (#20977) #21021

Merged
merged 3 commits into from
Nov 17, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions executor/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,52 @@ func (s *testSuiteJoin1) TestUsing(c *C) {
tk.MustExec("create table tt(b bigint, a int)")
// Check whether this sql can execute successfully.
tk.MustExec("select * from t join tt using(a)")

tk.MustExec("drop table if exists t, s")
tk.MustExec("create table t(a int, b int)")
tk.MustExec("create table s(b int, a int)")
tk.MustExec("insert into t values(1,1), (2,2), (3,3), (null,null)")
tk.MustExec("insert into s values(1,1), (3,3), (null,null)")

// For issue 20477
tk.MustQuery("select t.*, s.* from t join s using(a)").Check(testkit.Rows("1 1 1 1", "3 3 3 3"))
tk.MustQuery("select s.a from t join s using(a)").Check(testkit.Rows("1", "3"))
tk.MustQuery("select s.a from t join s using(a) where s.a > 1").Check(testkit.Rows("3"))
tk.MustQuery("select s.a from t join s using(a) order by s.a").Check(testkit.Rows("1", "3"))
tk.MustQuery("select s.a from t join s using(a) where s.a > 1 order by s.a").Check(testkit.Rows("3"))
tk.MustQuery("select s.a from t join s using(a) where s.a > 1 order by s.a limit 2").Check(testkit.Rows("3"))

// For issue 20441
tk.MustExec(`DROP TABLE if exists t1, t2, t3`)
tk.MustExec(`create table t1 (i int)`)
tk.MustExec(`create table t2 (i int)`)
tk.MustExec(`create table t3 (i int)`)
tk.MustExec(`select * from t1,t2 natural left join t3 order by t1.i,t2.i,t3.i`)
tk.MustExec(`select t1.i,t2.i,t3.i from t2 natural left join t3,t1 order by t1.i,t2.i,t3.i`)
tk.MustExec(`select * from t1,t2 natural right join t3 order by t1.i,t2.i,t3.i`)
tk.MustExec(`select t1.i,t2.i,t3.i from t2 natural right join t3,t1 order by t1.i,t2.i,t3.i`)

// For issue 15844
tk.MustExec(`DROP TABLE if exists t0, t1`)
tk.MustExec(`CREATE TABLE t0(c0 INT)`)
tk.MustExec(`CREATE TABLE t1(c0 INT)`)
tk.MustExec(`SELECT t0.c0 FROM t0 NATURAL RIGHT JOIN t1 WHERE t1.c0`)

// For issue 20958
tk.MustExec(`DROP TABLE if exists t1, t2`)
tk.MustExec(`create table t1(id int, name varchar(20));`)
tk.MustExec(`create table t2(id int, address varchar(30));`)
tk.MustExec(`insert into t1 values(1,'gangshen');`)
tk.MustExec(`insert into t2 values(1,'HangZhou');`)
tk.MustQuery(`select t2.* from t1 inner join t2 using (id) limit 1;`).Check(testkit.Rows("1 HangZhou"))
tk.MustQuery(`select t2.* from t1 inner join t2 on t1.id = t2.id limit 1;`).Check(testkit.Rows("1 HangZhou"))

// For issue 20476
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(a int)")
tk.MustExec("insert into t1 (a) values(1)")
tk.MustQuery("select t1.*, t2.* from t1 join t1 t2 using(a)").Check(testkit.Rows("1 1"))
tk.MustQuery("select * from t1 join t1 t2 using(a)").Check(testkit.Rows("1"))
}

func (s *testSuiteJoin1) TestNaturalJoin(c *C) {
Expand Down
37 changes: 26 additions & 11 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1624,27 +1624,42 @@ func (er *expressionRewriter) toColumn(v *ast.ColumnName) {
return
}
}
if join, ok := er.p.(*LogicalJoin); ok && join.redundantSchema != nil {
idx, err := expression.FindFieldName(join.redundantNames, v)
if err != nil {
er.err = err
return
}
if idx >= 0 {
er.ctxStackAppend(join.redundantSchema.Columns[idx], join.redundantNames[idx])
return
}
}
if _, ok := er.p.(*LogicalUnionAll); ok && v.Table.O != "" {
er.err = ErrTablenameNotAllowedHere.GenWithStackByArgs(v.Table.O, "SELECT", clauseMsg[er.b.curClause])
return
}
col, name, err := findFieldNameFromNaturalUsingJoin(er.p, v)
if err != nil {
er.err = err
return
} else if col != nil {
er.ctxStackAppend(col, name)
return
}
if er.b.curClause == globalOrderByClause {
er.b.curClause = orderByClause
}
er.err = ErrUnknownColumn.GenWithStackByArgs(v.String(), clauseMsg[er.b.curClause])
}

func findFieldNameFromNaturalUsingJoin(p LogicalPlan, v *ast.ColumnName) (col *expression.Column, name *types.FieldName, err error) {
switch x := p.(type) {
case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
return findFieldNameFromNaturalUsingJoin(p.Children()[0], v)
case *LogicalJoin:
if x.redundantSchema != nil {
idx, err := expression.FindFieldName(x.redundantNames, v)
if err != nil {
return nil, nil, err
}
if idx >= 0 {
return x.redundantSchema.Columns[idx], x.redundantNames[idx], nil
}
}
}
return nil, nil, nil
}

func (er *expressionRewriter) evalDefaultExpr(v *ast.DefaultExpr) {
stkLen := len(er.ctxStack)
name := er.ctxNameStk[stkLen-1]
Expand Down
124 changes: 90 additions & 34 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,19 @@ func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan

p.SetSchema(expression.NewSchema(schemaCols...))
p.names = names
p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rColumns[:commonLen]...))
p.redundantNames = append(p.redundantNames.Shallow(), rNames[:commonLen]...)
if joinTp == ast.RightJoin {
leftPlan, rightPlan = rightPlan, leftPlan
}
// We record the full `rightPlan.Schema` as `redundantSchema` in order to
// record the redundant column in `rightPlan` and the output columns order
// of the `rightPlan`.
// For SQL like `select t1.*, t2.* from t1 left join t2 using(a)`, we can
// retrieve the column order of `t2.*` from the `redundantSchema`.
p.redundantSchema = expression.MergeSchema(p.redundantSchema, expression.NewSchema(rightPlan.Schema().Clone().Columns...))
p.redundantNames = append(p.redundantNames.Shallow(), rightPlan.OutputNames().Shallow()...)
if joinTp == ast.RightJoin || joinTp == ast.LeftJoin {
resetNotNullFlag(p.redundantSchema, 0, p.redundantSchema.Len())
}
p.OtherConditions = append(conds, p.OtherConditions...)

return nil
Expand Down Expand Up @@ -954,12 +965,8 @@ func (b *PlanBuilder) buildProjectionField(ctx context.Context, p LogicalPlan, f
idx := p.Schema().ColumnIndex(col)
var name *types.FieldName
// The column maybe the one from join's redundant part.
// TODO: Fully support USING/NATURAL JOIN, refactor here.
if idx == -1 {
if join, ok := p.(*LogicalJoin); ok {
idx = join.redundantSchema.ColumnIndex(col)
name = join.redundantNames[idx]
}
name = findColFromNaturalUsingJoin(p, col)
} else {
name = p.OutputNames()[idx]
}
Expand Down Expand Up @@ -991,6 +998,25 @@ func (b *PlanBuilder) buildProjectionField(ctx context.Context, p LogicalPlan, f
return newCol, name, nil
}

// findColFromNaturalUsingJoin is used to recursively find the column from the
// underlying natural-using-join.
// e.g. For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0`, the
// plan will be `join->selection->projection`. The schema of the `selection`
// will be `[t1.a]`, thus we need to recursively retrieve the `t2.a` from the
// underlying join.
func findColFromNaturalUsingJoin(p LogicalPlan, col *expression.Column) (name *types.FieldName) {
switch x := p.(type) {
case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
return findColFromNaturalUsingJoin(p.Children()[0], col)
case *LogicalJoin:
if x.redundantSchema != nil {
idx := x.redundantSchema.ColumnIndex(col)
return x.redundantNames[idx]
}
}
return nil
}

// buildProjection returns a Projection plan and non-aux columns length.
func (b *PlanBuilder) buildProjection(ctx context.Context, p LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, considerWindow bool, expandGenerateColumn bool) (LogicalPlan, int, error) {
b.optFlag |= flagEliminateProjection
Expand Down Expand Up @@ -1504,14 +1530,30 @@ func (a *havingWindowAndOrderbyExprResolver) resolveFromPlan(v *ast.ColumnNameEx
if err != nil {
return -1, err
}
schemaCols, outputNames := p.Schema().Columns, p.OutputNames()
if idx < 0 {
return -1, nil
// For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0
// order by t2.a`, the query plan will be `join->selection->sort`. The
// schema of selection will be `[t1.a]`, thus we need to recursively
// retrieve the `t2.a` from the underlying join.
switch x := p.(type) {
case *LogicalLimit, *LogicalSelection, *LogicalTopN, *LogicalSort, *LogicalMaxOneRow:
return a.resolveFromPlan(v, p.Children()[0])
case *LogicalJoin:
if len(x.redundantNames) != 0 {
idx, err = expression.FindFieldName(x.redundantNames, v.Name)
schemaCols, outputNames = x.redundantSchema.Columns, x.redundantNames
}
}
if err != nil || idx < 0 {
return -1, err
}
}
col := p.Schema().Columns[idx]
col := schemaCols[idx]
if col.IsHidden {
return -1, ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[a.curClause])
}
name := p.OutputNames()[idx]
name := outputNames[idx]
newColName := &ast.ColumnName{
Schema: name.DBName,
Table: name.TblName,
Expand Down Expand Up @@ -2259,6 +2301,7 @@ func (b *PlanBuilder) resolveGbyExprs(ctx context.Context, p LogicalPlan, gby *a
}

func (b *PlanBuilder) unfoldWildStar(p LogicalPlan, selectFields []*ast.SelectField) (resultList []*ast.SelectField, err error) {
join, isJoin := p.(*LogicalJoin)
for i, field := range selectFields {
if field.WildCard == nil {
resultList = append(resultList, field)
Expand All @@ -2267,37 +2310,50 @@ func (b *PlanBuilder) unfoldWildStar(p LogicalPlan, selectFields []*ast.SelectFi
if field.WildCard.Table.L == "" && i > 0 {
return nil, ErrInvalidWildCard
}
dbName := field.WildCard.Schema
tblName := field.WildCard.Table
findTblNameInSchema := false
for i, name := range p.OutputNames() {
col := p.Schema().Columns[i]
if col.IsHidden {
continue
}
if (dbName.L == "" || dbName.L == name.DBName.L) &&
(tblName.L == "" || tblName.L == name.TblName.L) &&
col.ID != model.ExtraHandleID {
findTblNameInSchema = true
colName := &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Schema: name.DBName,
Table: name.TblName,
Name: name.ColName,
}}
colName.SetType(col.GetType())
field := &ast.SelectField{Expr: colName}
field.SetText(name.ColName.O)
resultList = append(resultList, field)
list := unfoldWildStar(field, p.OutputNames(), p.Schema().Columns)
// For sql like `select t1.*, t2.* from t1 join t2 using(a)`, we should
// not coalesce the `t2.a` in the output result. Thus we need to unfold
// the wildstar from the underlying join.redundantSchema.
if isJoin && join.redundantSchema != nil && field.WildCard.Table.L != "" {
redundantList := unfoldWildStar(field, join.redundantNames, join.redundantSchema.Columns)
if len(redundantList) > len(list) {
list = redundantList
}
}
if !findTblNameInSchema {
return nil, ErrBadTable.GenWithStackByArgs(tblName)
if len(list) == 0 {
return nil, ErrBadTable.GenWithStackByArgs(field.WildCard.Table)
}
resultList = append(resultList, list...)
}
return resultList, nil
}

func unfoldWildStar(field *ast.SelectField, outputName types.NameSlice, column []*expression.Column) (resultList []*ast.SelectField) {
dbName := field.WildCard.Schema
tblName := field.WildCard.Table
for i, name := range outputName {
col := column[i]
if col.IsHidden {
continue
}
if (dbName.L == "" || dbName.L == name.DBName.L) &&
(tblName.L == "" || tblName.L == name.TblName.L) &&
col.ID != model.ExtraHandleID {
colName := &ast.ColumnNameExpr{
Name: &ast.ColumnName{
Schema: name.DBName,
Table: name.TblName,
Name: name.ColName,
}}
colName.SetType(col.GetType())
field := &ast.SelectField{Expr: colName}
field.SetText(name.ColName.O)
resultList = append(resultList, field)
}
}
return resultList
}

func (b *PlanBuilder) pushHintWithoutTableWarning(hint *ast.TableOptimizerHint) {
var sb strings.Builder
ctx := format.NewRestoreCtx(0, &sb)
Expand Down
2 changes: 1 addition & 1 deletion planner/core/logical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ type LogicalJoin struct {
DefaultValues []types.Datum

// redundantSchema contains columns which are eliminated in join.
// For select * from a join b using (c); a.c will in output schema, and b.c will in redundantSchema.
// For select * from a join b using (c); a.c will in output schema, and b.c will only in redundantSchema.
redundantSchema *expression.Schema
redundantNames types.NameSlice

Expand Down