diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index ee4682a04e7..90503f46c2c 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -46,6 +46,7 @@ type ( SetLimit(*Limit) SetLock(lock Lock) MakeDistinct() + GetColumnCount() int } // DDLStatement represents any DDL Statement diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 5029e6d5691..7f8b090c3b9 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -744,6 +744,11 @@ func (node *Select) MakeDistinct() { node.Distinct = true } +// GetColumnCount return SelectExprs count. +func (node *Select) GetColumnCount() int { + return len(node.SelectExprs) +} + // AddWhere adds the boolean expression to the // WHERE clause as an AND condition. func (node *Select) AddWhere(expr Expr) { @@ -796,6 +801,11 @@ func (node *ParenSelect) MakeDistinct() { node.Select.MakeDistinct() } +// GetColumnCount implements the SelectStatement interface +func (node *ParenSelect) GetColumnCount() int { + return node.Select.GetColumnCount() +} + // AddWhere adds the boolean expression to the // WHERE clause as an AND condition. func (node *Update) AddWhere(expr Expr) { @@ -832,6 +842,11 @@ func (node *Union) MakeDistinct() { node.UnionSelects[len(node.UnionSelects)-1].Distinct = true } +// GetColumnCount implements the SelectStatement interface +func (node *Union) GetColumnCount() int { + return node.FirstStatement.GetColumnCount() +} + //Unionize returns a UNION, either creating one or adding SELECT to an existing one func Unionize(lhs, rhs SelectStatement, distinct bool, by OrderBy, limit *Limit, lock Lock) *Union { union, isUnion := lhs.(*Union) diff --git a/go/vt/vtgate/planbuilder/expand_star.go b/go/vt/vtgate/planbuilder/expand_star.go new file mode 100644 index 00000000000..0b5977fa99b --- /dev/null +++ b/go/vt/vtgate/planbuilder/expand_star.go @@ -0,0 +1,113 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +type ( + expandStarInfo struct { + proceed bool + tblColMap map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs + } + starRewriter struct { + err error + semTable *semantics.SemTable + } +) + +func (sr *starRewriter) starRewrite(cursor *sqlparser.Cursor) bool { + switch node := cursor.Node().(type) { + case *sqlparser.Select: + tables := sr.semTable.GetSelectTables(node) + var selExprs sqlparser.SelectExprs + for _, selectExpr := range node.SelectExprs { + starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr) + if !isStarExpr { + selExprs = append(selExprs, selectExpr) + continue + } + colNames, expStar, err := expandTableColumns(tables, starExpr) + if err != nil { + sr.err = err + return false + } + if !expStar.proceed { + selExprs = append(selExprs, selectExpr) + continue + } + selExprs = append(selExprs, colNames...) + for tbl, cols := range expStar.tblColMap { + sr.semTable.AddExprs(tbl, cols) + } + } + node.SelectExprs = selExprs + } + return true +} + +func expandTableColumns(tables []semantics.TableInfo, starExpr *sqlparser.StarExpr) (sqlparser.SelectExprs, *expandStarInfo, error) { + unknownTbl := true + var colNames sqlparser.SelectExprs + expStar := &expandStarInfo{ + tblColMap: map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs{}, + } + + for _, tbl := range tables { + if !starExpr.TableName.IsEmpty() && !tbl.Matches(starExpr.TableName) { + continue + } + unknownTbl = false + if !tbl.Authoritative() { + expStar.proceed = false + break + } + expStar.proceed = true + tblName, err := tbl.Name() + if err != nil { + return nil, nil, err + } + for _, col := range tbl.GetColumns() { + colNames = append(colNames, &sqlparser.AliasedExpr{ + Expr: sqlparser.NewColNameWithQualifier(col.Name, tblName), + As: sqlparser.NewColIdent(col.Name), + }) + } + expStar.tblColMap[tbl.GetExpr()] = colNames + } + + if unknownTbl { + // This will only happen for case when starExpr has qualifier. + return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName)) + } + return colNames, expStar, nil +} + +func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) (*sqlparser.Select, error) { + // TODO we could store in semTable whether there are any * in the query that needs expanding or not + sr := &starRewriter{semTable: semTable} + + _ = sqlparser.Rewrite(sel, sr.starRewrite, nil) + if sr.err != nil { + return nil, sr.err + } + return sel, nil +} diff --git a/go/vt/vtgate/planbuilder/expand_star_test.go b/go/vt/vtgate/planbuilder/expand_star_test.go new file mode 100644 index 00000000000..62ffb23a4d5 --- /dev/null +++ b/go/vt/vtgate/planbuilder/expand_star_test.go @@ -0,0 +1,180 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/semantics" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +func TestExpandStar(t *testing.T) { + schemaInfo := &semantics.FakeSI{ + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewTableIdent("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("a"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewColIdent("b"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewColIdent("c"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }, + "t2": { + Name: sqlparser.NewTableIdent("t2"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("c1"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewColIdent("c2"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }, + "t3": { // non authoritative table. + Name: sqlparser.NewTableIdent("t3"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("col"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: false, + }, + }, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + expErr string + }{{ + sql: "select * from t1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1", + }, { + sql: "select t1.* from t1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1", + }, { + sql: "select *, 42, t1.* from t1", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, 42, t1.a as a, t1.b as b, t1.c as c from t1", + }, { + sql: "select 42, t1.* from t1", + expSQL: "select 42, t1.a as a, t1.b as b, t1.c as c from t1", + }, { + sql: "select * from t1, t2", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2", + }, { + sql: "select t1.* from t1, t2", + expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2", + }, { + sql: "select *, t1.* from t1, t2", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2", + }, { // aliased table + sql: "select * from t1 a, t2 b", + expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b", + }, { // t3 is non-authoritative table + sql: "select * from t3", + expSQL: "select * from t3", + }, { // t3 is non-authoritative table + sql: "select * from t1, t2, t3", + expSQL: "select * from t1, t2, t3", + }, { // t3 is non-authoritative table + sql: "select t1.*, t2.*, t3.* from t1, t2, t3", + expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3", + }, { + sql: "select foo.* from t1, t2", + expErr: "Unknown table 'foo'", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.Parse(tcase.sql) + require.NoError(t, err) + semTable, err := semantics.Analyze(ast, cDB, schemaInfo) + require.NoError(t, err) + expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable) + if tcase.expErr == "" { + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) + } else { + require.EqualError(t, err, tcase.expErr) + } + }) + } +} + +func TestSemTableDependenciesAfterExpandStar(t *testing.T) { + schemaInfo := &semantics.FakeSI{Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewTableIdent("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewColIdent("a"), + Type: sqltypes.VarChar, + }}, + ColumnListAuthoritative: true, + }}} + tcases := []struct { + sql string + expSQL string + sameTbl int + otherTbl int + expandedCol int + }{{ + sql: "select a, * from t1", + expSQL: "select a, t1.a as a from t1", + otherTbl: -1, sameTbl: 0, expandedCol: 1, + }, { + sql: "select t2.a, t1.a, t1.* from t1, t2", + expSQL: "select t2.a, t1.a, t1.a as a from t1, t2", + otherTbl: 0, sameTbl: 1, expandedCol: 2, + }, { + sql: "select t2.a, t.a, t.* from t1 t, t2", + expSQL: "select t2.a, t.a, t.a as a from t1 as t, t2", + otherTbl: 0, sameTbl: 1, expandedCol: 2, + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.Parse(tcase.sql) + require.NoError(t, err) + semTable, err := semantics.Analyze(ast, "", schemaInfo) + require.NoError(t, err) + expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable) + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) + if tcase.otherTbl != -1 { + assert.NotEqual(t, + semTable.Dependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr), + semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), + ) + } + if tcase.sameTbl != -1 { + assert.Equal(t, + semTable.Dependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr), + semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), + ) + } + }) + } +} diff --git a/go/vt/vtgate/planbuilder/grouping.go b/go/vt/vtgate/planbuilder/grouping.go index a1b8614a64b..fec8ad38411 100644 --- a/go/vt/vtgate/planbuilder/grouping.go +++ b/go/vt/vtgate/planbuilder/grouping.go @@ -68,7 +68,7 @@ func planGroupBy(pb *primitiveBuilder, input logicalPlan, groupBy sqlparser.Grou return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: in scatter query: group by column must reference column in SELECT list") } case *sqlparser.Literal: - num, err := ResultFromNumber(node.resultColumns, e) + num, err := ResultFromNumber(node.resultColumns, e, "group statement") if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/memory_sort.go b/go/vt/vtgate/planbuilder/memory_sort.go index 8f06423699f..4b9890ce0cf 100644 --- a/go/vt/vtgate/planbuilder/memory_sort.go +++ b/go/vt/vtgate/planbuilder/memory_sort.go @@ -51,7 +51,7 @@ func newMemorySort(plan logicalPlan, orderBy sqlparser.OrderBy) (*memorySort, er switch expr := order.Expr.(type) { case *sqlparser.Literal: var err error - if colNumber, err = ResultFromNumber(ms.ResultColumns(), expr); err != nil { + if colNumber, err = ResultFromNumber(ms.ResultColumns(), expr, "order clause"); err != nil { return nil, err } case *sqlparser.ColName: diff --git a/go/vt/vtgate/planbuilder/memory_sort_gen4.go b/go/vt/vtgate/planbuilder/memory_sort_gen4.go new file mode 100644 index 00000000000..9f93df3a49b --- /dev/null +++ b/go/vt/vtgate/planbuilder/memory_sort_gen4.go @@ -0,0 +1,91 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +var _ logicalPlan = (*memorySortGen4)(nil) + +type memorySortGen4 struct { + orderBy []engine.OrderbyParams + input logicalPlan + truncateColumnCount int +} + +func (m *memorySortGen4) Order() int { + panic("implement me") +} + +func (m *memorySortGen4) ResultColumns() []*resultColumn { + panic("implement me") +} + +func (m *memorySortGen4) Reorder(i int) { + panic("implement me") +} + +func (m *memorySortGen4) Wireup(lp logicalPlan, jt *jointab) error { + panic("implement me") +} + +func (m *memorySortGen4) WireupGen4(semTable *semantics.SemTable) error { + return m.input.WireupGen4(semTable) +} + +func (m *memorySortGen4) SupplyVar(from, to int, col *sqlparser.ColName, varname string) { + panic("implement me") +} + +func (m *memorySortGen4) SupplyCol(col *sqlparser.ColName) (rc *resultColumn, colNumber int) { + panic("implement me") +} + +func (m *memorySortGen4) SupplyWeightString(colNumber int) (weightcolNumber int, err error) { + panic("implement me") +} + +func (m *memorySortGen4) Primitive() engine.Primitive { + return &engine.MemorySort{ + UpperLimit: sqltypes.PlanValue{}, + OrderBy: m.orderBy, + Input: m.input.Primitive(), + TruncateColumnCount: m.truncateColumnCount, + } +} + +func (m *memorySortGen4) Inputs() []logicalPlan { + return []logicalPlan{m.input} +} + +func (m *memorySortGen4) Rewrite(inputs ...logicalPlan) error { + if len(inputs) != 1 { + return vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG]: expected only 1 input") + } + m.input = inputs[0] + return nil +} + +func (m *memorySortGen4) ContainsTables() semantics.TableSet { + return m.input.ContainsTables() +} diff --git a/go/vt/vtgate/planbuilder/ordering.go b/go/vt/vtgate/planbuilder/ordering.go index 57f3b268db6..d3d53f4d720 100644 --- a/go/vt/vtgate/planbuilder/ordering.go +++ b/go/vt/vtgate/planbuilder/ordering.go @@ -84,7 +84,7 @@ func planOAOrdering(pb *primitiveBuilder, orderBy sqlparser.OrderBy, oa *ordered var orderByCol *column switch expr := order.Expr.(type) { case *sqlparser.Literal: - num, err := ResultFromNumber(oa.resultColumns, expr) + num, err := ResultFromNumber(oa.resultColumns, expr, "order clause") if err != nil { return nil, err } @@ -183,7 +183,7 @@ func planJoinOrdering(pb *primitiveBuilder, orderBy sqlparser.OrderBy, node *joi if e, ok := order.Expr.(*sqlparser.Literal); ok { // This block handles constructs that use ordinals for 'ORDER BY'. For example: // SELECT a, b, c FROM t1, t2 ORDER BY 1, 2, 3. - num, err := ResultFromNumber(node.ResultColumns(), e) + num, err := ResultFromNumber(node.ResultColumns(), e, "order clause") if err != nil { return nil, err } @@ -258,7 +258,7 @@ func planRouteOrdering(orderBy sqlparser.OrderBy, node *route) (logicalPlan, err switch expr := order.Expr.(type) { case *sqlparser.Literal: var err error - if colNumber, err = ResultFromNumber(node.resultColumns, expr); err != nil { + if colNumber, err = ResultFromNumber(node.resultColumns, expr, "order clause"); err != nil { return nil, err } case *sqlparser.ColName: diff --git a/go/vt/vtgate/planbuilder/postprocess.go b/go/vt/vtgate/planbuilder/postprocess.go index 2df10473fe6..014305e80eb 100644 --- a/go/vt/vtgate/planbuilder/postprocess.go +++ b/go/vt/vtgate/planbuilder/postprocess.go @@ -97,7 +97,7 @@ var _ planVisitor = setUpperLimit func setUpperLimit(plan logicalPlan) (bool, logicalPlan, error) { arg := sqlparser.NewArgument("__upper_limit") switch node := plan.(type) { - case *join: + case *join, *joinGen4: return false, node, nil case *memorySort: pv, err := sqlparser.NewPlanValue(arg) diff --git a/go/vt/vtgate/planbuilder/queryprojection.go b/go/vt/vtgate/planbuilder/queryprojection.go index 032fe0e2830..2bfb5264815 100644 --- a/go/vt/vtgate/planbuilder/queryprojection.go +++ b/go/vt/vtgate/planbuilder/queryprojection.go @@ -17,6 +17,8 @@ limitations under the License. package planbuilder import ( + "strconv" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -24,21 +26,18 @@ import ( ) type queryProjection struct { - selectExprs []*sqlparser.AliasedExpr - aggrExprs []*sqlparser.AliasedExpr - groupOrderingCommonExpr map[sqlparser.Expr]*sqlparser.Order - - orderExprs sqlparser.OrderBy + selectExprs []*sqlparser.AliasedExpr + aggrExprs []*sqlparser.AliasedExpr + orderExprs []orderBy +} - // orderExprColMap keeps a map between the Order object and the offset into the select expressions list - orderExprColMap map[*sqlparser.Order]int +type orderBy struct { + inner *sqlparser.Order + weightStrExpr sqlparser.Expr } func newQueryProjection() *queryProjection { - return &queryProjection{ - groupOrderingCommonExpr: map[sqlparser.Expr]*sqlparser.Order{}, - orderExprColMap: map[*sqlparser.Order]int{}, - } + return &queryProjection{} } func createQPFromSelect(sel *sqlparser.Select) (*queryProjection, error) { @@ -63,22 +62,61 @@ func createQPFromSelect(sel *sqlparser.Select) (*queryProjection, error) { qp.selectExprs = append(qp.selectExprs, exp) } - qp.orderExprs = sel.OrderBy + if len(qp.selectExprs) > 0 && len(qp.aggrExprs) > 0 { + return nil, semantics.Gen4NotSupportedF("aggregation and non-aggregation expressions, together are not supported in cross-shard query") + } allExpr := append(qp.selectExprs, qp.aggrExprs...) + for _, order := range sel.OrderBy { - for offset, expr := range allExpr { - if sqlparser.EqualsExpr(order.Expr, expr.Expr) { - qp.orderExprColMap[order] = offset - break + qp.addOrderBy(order, allExpr) + } + return qp, nil +} + +func (qp *queryProjection) addOrderBy(order *sqlparser.Order, allExpr []*sqlparser.AliasedExpr) { + // Order by is the column offset to be used from the select expressions + // Eg - select id from music order by 1 + literalExpr, isLiteral := order.Expr.(*sqlparser.Literal) + if isLiteral && literalExpr.Type == sqlparser.IntVal { + num, _ := strconv.Atoi(literalExpr.Val) + aliasedExpr := allExpr[num-1] + expr := aliasedExpr.Expr + if !aliasedExpr.As.IsEmpty() { + // the column is aliased, so we'll add an expression ordering by the alias and not the underlying expression + expr = &sqlparser.ColName{ + Name: aliasedExpr.As, } - // TODO: handle alias and column offset } + qp.orderExprs = append(qp.orderExprs, orderBy{ + inner: &sqlparser.Order{ + Expr: expr, + Direction: order.Direction, + }, + weightStrExpr: aliasedExpr.Expr, + }) + return } - if sel.GroupBy == nil || sel.OrderBy == nil { - return qp, nil + // If the ORDER BY is against a column alias, we need to remember the expression + // behind the alias. The weightstring(.) calls needs to be done against that expression and not the alias. + // Eg - select music.foo as bar, weightstring(music.foo) from music order by bar + colExpr, isColName := order.Expr.(*sqlparser.ColName) + if isColName && colExpr.Qualifier.IsEmpty() { + for _, expr := range allExpr { + isAliasExpr := !expr.As.IsEmpty() + if isAliasExpr && colExpr.Name.Equal(expr.As) { + qp.orderExprs = append(qp.orderExprs, orderBy{ + inner: order, + weightStrExpr: expr.Expr, + }) + return + } + } } - return qp, nil + qp.orderExprs = append(qp.orderExprs, orderBy{ + inner: order, + weightStrExpr: order.Expr, + }) } diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index dc687c1b828..3af16a96e14 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -131,111 +131,6 @@ func optimizeQuery(opTree abstract.Operator, semTable *semantics.SemTable, vsche } } -type starRewriter struct { - err error - semTable *semantics.SemTable -} - -func (sr *starRewriter) starRewrite(cursor *sqlparser.Cursor) bool { - switch node := cursor.Node().(type) { - case *sqlparser.Select: - tables := sr.semTable.GetSelectTables(node) - var selExprs sqlparser.SelectExprs - for _, selectExpr := range node.SelectExprs { - starExpr, isStarExpr := selectExpr.(*sqlparser.StarExpr) - if !isStarExpr { - selExprs = append(selExprs, selectExpr) - continue - } - colNames, expStar, err := expandTableColumns(tables, starExpr) - if err != nil { - sr.err = err - return false - } - if !expStar.proceed { - selExprs = append(selExprs, selectExpr) - continue - } - selExprs = append(selExprs, colNames...) - for tbl, cols := range expStar.tblColMap { - sr.semTable.AddExprs(tbl, cols) - } - } - node.SelectExprs = selExprs - } - return true -} - -func expandTableColumns(tables []*semantics.TableInfo, starExpr *sqlparser.StarExpr) (sqlparser.SelectExprs, *expandStarInfo, error) { - unknownTbl := true - var colNames sqlparser.SelectExprs - expStar := &expandStarInfo{ - tblColMap: map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs{}, - } - - for _, tbl := range tables { - if !starExpr.TableName.IsEmpty() { - if !tbl.ASTNode.As.IsEmpty() { - if !starExpr.TableName.Qualifier.IsEmpty() { - continue - } - if starExpr.TableName.Name.String() != tbl.ASTNode.As.String() { - continue - } - } else { - if !starExpr.TableName.Qualifier.IsEmpty() { - if starExpr.TableName.Qualifier.String() != tbl.Table.Keyspace.Name { - continue - } - } - tblName := tbl.ASTNode.Expr.(sqlparser.TableName) - if starExpr.TableName.Name.String() != tblName.Name.String() { - continue - } - } - } - unknownTbl = false - if tbl.Table == nil || !tbl.Table.ColumnListAuthoritative { - expStar.proceed = false - break - } - expStar.proceed = true - tblName, err := tbl.ASTNode.TableName() - if err != nil { - return nil, nil, err - } - for _, col := range tbl.Table.Columns { - colNames = append(colNames, &sqlparser.AliasedExpr{ - Expr: sqlparser.NewColNameWithQualifier(col.Name.String(), tblName), - As: sqlparser.NewColIdent(col.Name.String()), - }) - } - expStar.tblColMap[tbl.ASTNode] = colNames - } - - if unknownTbl { - // This will only happen for case when starExpr has qualifier. - return nil, nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadDb, "Unknown table '%s'", sqlparser.String(starExpr.TableName)) - } - return colNames, expStar, nil -} - -type expandStarInfo struct { - proceed bool - tblColMap map[*sqlparser.AliasedTableExpr]sqlparser.SelectExprs -} - -func expandStar(sel *sqlparser.Select, semTable *semantics.SemTable) (*sqlparser.Select, error) { - // TODO we could store in semTable whether there are any * in the query that needs expanding or not - sr := &starRewriter{semTable: semTable} - - _ = sqlparser.Rewrite(sel, sr.starRewrite, nil) - if sr.err != nil { - return nil, sr.err - } - return sel, nil -} - func planLimit(limit *sqlparser.Limit, plan logicalPlan) (logicalPlan, error) { if limit == nil { return plan, nil @@ -279,7 +174,7 @@ func planHorizon(sel *sqlparser.Select, plan logicalPlan, semTable *semantics.Se return nil, err } for _, e := range qp.selectExprs { - if _, err := pushProjection(e, plan, semTable, true); err != nil { + if _, _, err := pushProjection(e, plan, semTable, true); err != nil { return nil, err } } @@ -301,7 +196,7 @@ func planHorizon(sel *sqlparser.Select, plan logicalPlan, semTable *semantics.Se } } if len(sel.OrderBy) > 0 { - plan, err = planOrderBy(qp, plan, semTable) + plan, err = planOrderBy(qp, qp.orderExprs, plan, semTable) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/route_planning_test.go b/go/vt/vtgate/planbuilder/route_planning_test.go index e3d8517c7f6..75b656c88cc 100644 --- a/go/vt/vtgate/planbuilder/route_planning_test.go +++ b/go/vt/vtgate/planbuilder/route_planning_test.go @@ -22,10 +22,6 @@ import ( "vitess.io/vitess/go/vt/vtgate/planbuilder/abstract" - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/vtgate/semantics" "github.com/stretchr/testify/assert" @@ -189,154 +185,3 @@ func colName(table, column string) *sqlparser.ColName { func tableName(name string) sqlparser.TableName { return sqlparser.TableName{Name: sqlparser.NewTableIdent(name)} } - -func TestExpandStar(t *testing.T) { - schemaInfo := &semantics.FakeSI{ - Tables: map[string]*vindexes.Table{ - "t1": { - Name: sqlparser.NewTableIdent("t1"), - Columns: []vindexes.Column{{ - Name: sqlparser.NewColIdent("a"), - Type: sqltypes.VarChar, - }, { - Name: sqlparser.NewColIdent("b"), - Type: sqltypes.VarChar, - }, { - Name: sqlparser.NewColIdent("c"), - Type: sqltypes.VarChar, - }}, - ColumnListAuthoritative: true, - }, - "t2": { - Name: sqlparser.NewTableIdent("t2"), - Columns: []vindexes.Column{{ - Name: sqlparser.NewColIdent("c1"), - Type: sqltypes.VarChar, - }, { - Name: sqlparser.NewColIdent("c2"), - Type: sqltypes.VarChar, - }}, - ColumnListAuthoritative: true, - }, - "t3": { // non authoritative table. - Name: sqlparser.NewTableIdent("t3"), - Columns: []vindexes.Column{{ - Name: sqlparser.NewColIdent("col"), - Type: sqltypes.VarChar, - }}, - ColumnListAuthoritative: false, - }, - }, - } - cDB := "db" - tcases := []struct { - sql string - expSQL string - expErr string - }{{ - sql: "select * from t1", - expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1", - }, { - sql: "select t1.* from t1", - expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1", - }, { - sql: "select *, 42, t1.* from t1", - expSQL: "select t1.a as a, t1.b as b, t1.c as c, 42, t1.a as a, t1.b as b, t1.c as c from t1", - }, { - sql: "select 42, t1.* from t1", - expSQL: "select 42, t1.a as a, t1.b as b, t1.c as c from t1", - }, { - sql: "select * from t1, t2", - expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2 from t1, t2", - }, { - sql: "select t1.* from t1, t2", - expSQL: "select t1.a as a, t1.b as b, t1.c as c from t1, t2", - }, { - sql: "select *, t1.* from t1, t2", - expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t1.a as a, t1.b as b, t1.c as c from t1, t2", - }, { // aliased table - sql: "select * from t1 a, t2 b", - expSQL: "select a.a as a, a.b as b, a.c as c, b.c1 as c1, b.c2 as c2 from t1 as a, t2 as b", - }, { // t3 is non-authoritative table - sql: "select * from t3", - expSQL: "select * from t3", - }, { // t3 is non-authoritative table - sql: "select * from t1, t2, t3", - expSQL: "select * from t1, t2, t3", - }, { // t3 is non-authoritative table - sql: "select t1.*, t2.*, t3.* from t1, t2, t3", - expSQL: "select t1.a as a, t1.b as b, t1.c as c, t2.c1 as c1, t2.c2 as c2, t3.* from t1, t2, t3", - }, { - sql: "select foo.* from t1, t2", - expErr: "Unknown table 'foo'", - }} - for _, tcase := range tcases { - t.Run(tcase.sql, func(t *testing.T) { - ast, err := sqlparser.Parse(tcase.sql) - require.NoError(t, err) - semTable, err := semantics.Analyze(ast, cDB, schemaInfo) - require.NoError(t, err) - expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable) - if tcase.expErr == "" { - require.NoError(t, err) - assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) - } else { - require.EqualError(t, err, tcase.expErr) - } - }) - } -} - -func TestSemTableDependenciesAfterExpandStar(t *testing.T) { - schemaInfo := &semantics.FakeSI{Tables: map[string]*vindexes.Table{ - "t1": { - Name: sqlparser.NewTableIdent("t1"), - Columns: []vindexes.Column{{ - Name: sqlparser.NewColIdent("a"), - Type: sqltypes.VarChar, - }}, - ColumnListAuthoritative: true, - }}} - tcases := []struct { - sql string - expSQL string - sameTbl int - otherTbl int - expandedCol int - }{{ - sql: "select a, * from t1", - expSQL: "select a, t1.a as a from t1", - otherTbl: -1, sameTbl: 0, expandedCol: 1, - }, { - sql: "select t2.a, t1.a, t1.* from t1, t2", - expSQL: "select t2.a, t1.a, t1.a as a from t1, t2", - otherTbl: 0, sameTbl: 1, expandedCol: 2, - }, { - sql: "select t2.a, t.a, t.* from t1 t, t2", - expSQL: "select t2.a, t.a, t.a as a from t1 as t, t2", - otherTbl: 0, sameTbl: 1, expandedCol: 2, - }} - for _, tcase := range tcases { - t.Run(tcase.sql, func(t *testing.T) { - ast, err := sqlparser.Parse(tcase.sql) - require.NoError(t, err) - semTable, err := semantics.Analyze(ast, "", schemaInfo) - require.NoError(t, err) - expandedSelect, err := expandStar(ast.(*sqlparser.Select), semTable) - require.NoError(t, err) - assert.Equal(t, tcase.expSQL, sqlparser.String(expandedSelect)) - if tcase.otherTbl != -1 { - assert.NotEqual(t, - semTable.Dependencies(expandedSelect.SelectExprs[tcase.otherTbl].(*sqlparser.AliasedExpr).Expr), - semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), - ) - } - if tcase.sameTbl != -1 { - assert.Equal(t, - semTable.Dependencies(expandedSelect.SelectExprs[tcase.sameTbl].(*sqlparser.AliasedExpr).Expr), - semTable.Dependencies(expandedSelect.SelectExprs[tcase.expandedCol].(*sqlparser.AliasedExpr).Expr), - ) - } - }) - } -} diff --git a/go/vt/vtgate/planbuilder/selectGen4.go b/go/vt/vtgate/planbuilder/selectGen4.go index 9e803d13895..42520e04403 100644 --- a/go/vt/vtgate/planbuilder/selectGen4.go +++ b/go/vt/vtgate/planbuilder/selectGen4.go @@ -28,51 +28,65 @@ import ( "vitess.io/vitess/go/vt/vtgate/engine" ) -func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *semantics.SemTable, inner bool) (int, error) { +func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *semantics.SemTable, inner bool) (int, bool, error) { switch node := plan.(type) { case *route: value, err := makePlanValue(expr.Expr) if err != nil { - return 0, err + return 0, false, err } _, isColName := expr.Expr.(*sqlparser.ColName) badExpr := value == nil && !isColName if !inner && badExpr { - return 0, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: cross-shard left join and column expressions") + return 0, false, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: cross-shard left join and column expressions") } sel := node.Select.(*sqlparser.Select) i := checkIfAlreadyExists(expr, sel) if i != -1 { - return i, nil + return i, false, nil } expr = removeQualifierFromColName(expr) offset := len(sel.SelectExprs) sel.SelectExprs = append(sel.SelectExprs, expr) - return offset, nil + return offset, true, nil case *joinGen4: lhsSolves := node.Left.ContainsTables() rhsSolves := node.Right.ContainsTables() deps := semTable.Dependencies(expr.Expr) + var column int + var appended bool switch { case deps.IsSolvedBy(lhsSolves): - offset, err := pushProjection(expr, node.Left, semTable, inner) + offset, added, err := pushProjection(expr, node.Left, semTable, inner) if err != nil { - return 0, err + return 0, false, err } - node.Cols = append(node.Cols, -(offset + 1)) + column = -(offset + 1) + appended = added case deps.IsSolvedBy(rhsSolves): - offset, err := pushProjection(expr, node.Right, semTable, inner && node.Opcode != engine.LeftJoin) + offset, added, err := pushProjection(expr, node.Right, semTable, inner && node.Opcode != engine.LeftJoin) if err != nil { - return 0, err + return 0, false, err } - node.Cols = append(node.Cols, offset+1) + column = offset + 1 + appended = added default: - return 0, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unknown dependencies for %s", sqlparser.String(expr)) + return 0, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unknown dependencies for %s", sqlparser.String(expr)) } - return len(node.Cols) - 1, nil + if !appended { + for idx, col := range node.Cols { + if column == col { + return idx, false, nil + } + } + // the column was not appended to either child, but we could not find it in out cols list, + // so we'll still add it + } + node.Cols = append(node.Cols, column) + return len(node.Cols) - 1, true, nil default: - return 0, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%T not yet supported", node) + return 0, false, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%T not yet supported", node) } } @@ -88,9 +102,18 @@ func removeQualifierFromColName(expr *sqlparser.AliasedExpr) *sqlparser.AliasedE func checkIfAlreadyExists(expr *sqlparser.AliasedExpr, sel *sqlparser.Select) int { for i, selectExpr := range sel.SelectExprs { if selectExpr, ok := selectExpr.(*sqlparser.AliasedExpr); ok { - if sqlparser.EqualsExpr(selectExpr.Expr, expr.Expr) { - return i + if selectExpr.As.IsEmpty() { + // we don't have an alias, so we can compare the expressions + if sqlparser.EqualsExpr(selectExpr.Expr, expr.Expr) { + return i + } + // we have an aliased column, so let's check if the expression is matching the alias + } else if colName, ok := expr.Expr.(*sqlparser.ColName); ok { + if selectExpr.As.Equal(colName.Name) { + return i + } } + } } return -1 @@ -99,11 +122,15 @@ func checkIfAlreadyExists(expr *sqlparser.AliasedExpr, sel *sqlparser.Select) in func planAggregations(qp *queryProjection, plan logicalPlan, semTable *semantics.SemTable) (logicalPlan, error) { eaggr := &engine.OrderedAggregate{} oa := &orderedAggregate{ - resultsBuilder: newResultsBuilder(plan, eaggr), - eaggr: eaggr, + resultsBuilder: resultsBuilder{ + logicalPlanCommon: newBuilderCommon(plan), + weightStrings: make(map[*resultColumn]int), + truncater: eaggr, + }, + eaggr: eaggr, } for _, e := range qp.aggrExprs { - offset, err := pushProjection(e, plan, semTable, true) + offset, _, err := pushProjection(e, plan, semTable, true) if err != nil { return nil, err } @@ -117,75 +144,149 @@ func planAggregations(qp *queryProjection, plan logicalPlan, semTable *semantics return oa, nil } -func planOrderBy(qp *queryProjection, plan logicalPlan, semTable *semantics.SemTable) (logicalPlan, error) { +func planOrderBy(qp *queryProjection, orderExprs []orderBy, plan logicalPlan, semTable *semantics.SemTable) (logicalPlan, error) { switch plan := plan.(type) { case *route: - additionalColAdded := false - for _, order := range qp.orderExprs { - offset, exists := qp.orderExprColMap[order] - colName, ok := order.Expr.(*sqlparser.ColName) - if !ok { - return nil, semantics.Gen4NotSupportedF("order by non-column expression") - } - if !exists { - expr := &sqlparser.AliasedExpr{ - Expr: order.Expr, - } - var err error - offset, err = pushProjection(expr, plan, semTable, true) - if err != nil { - return nil, err - } - additionalColAdded = true - } + return planOrderByForRoute(orderExprs, plan, semTable) + case *joinGen4: + return planOrderByForJoin(qp, orderExprs, plan, semTable) + default: + return nil, semantics.Gen4NotSupportedF("ordering on complex query") + } +} - table := semTable.Dependencies(colName) - tableInfo, err := semTable.TableInfoFor(table) +func planOrderByForRoute(orderExprs []orderBy, plan *route, semTable *semantics.SemTable) (logicalPlan, error) { + origColCount := plan.Select.GetColumnCount() + for _, order := range orderExprs { + expr := order.inner.Expr + offset, err := wrapExprAndPush(expr, plan, semTable) + if err != nil { + return nil, err + } + colName, ok := order.inner.Expr.(*sqlparser.ColName) + if !ok { + return nil, semantics.Gen4NotSupportedF("order by non-column expression") + } + + table := semTable.Dependencies(colName) + tbl, err := semTable.TableInfoFor(table) + if err != nil { + return nil, err + } + weightStringNeeded := needsWeightString(tbl, colName) + + weightStringOffset := -1 + if weightStringNeeded { + weightStringOffset, err = wrapExprAndPush(weightStringFor(order.weightStrExpr), plan, semTable) if err != nil { return nil, err } - weightStringNeeded := true - for _, c := range tableInfo.Table.Columns { - if colName.Name.Equal(c.Name) { - if sqltypes.IsNumber(c.Type) { - weightStringNeeded = false - } - break - } - } + } - weightStringOffset := -1 - if weightStringNeeded { - expr := &sqlparser.AliasedExpr{ - Expr: &sqlparser.FuncExpr{ - Name: sqlparser.NewColIdent("weight_string"), - Exprs: []sqlparser.SelectExpr{ - &sqlparser.AliasedExpr{ - Expr: order.Expr, - }, - }, - }, - } - weightStringOffset, err = pushProjection(expr, plan, semTable, true) - if err != nil { - return nil, err - } - additionalColAdded = true - } + plan.eroute.OrderBy = append(plan.eroute.OrderBy, engine.OrderbyParams{ + Col: offset, + WeightStringCol: weightStringOffset, + Desc: order.inner.Direction == sqlparser.DescOrder, + }) + plan.Select.AddOrder(order.inner) + } + if origColCount != plan.Select.GetColumnCount() { + plan.eroute.TruncateColumnCount = origColCount + } - plan.eroute.OrderBy = append(plan.eroute.OrderBy, engine.OrderbyParams{ - Col: offset, - WeightStringCol: weightStringOffset, - Desc: order.Direction == sqlparser.DescOrder, - }) - plan.Select.AddOrder(order) - } - if additionalColAdded { - plan.eroute.TruncateColumnCount = len(qp.selectExprs) + len(qp.aggrExprs) + return plan, nil +} + +func weightStringFor(expr sqlparser.Expr) sqlparser.Expr { + return &sqlparser.FuncExpr{ + Name: sqlparser.NewColIdent("weight_string"), + Exprs: []sqlparser.SelectExpr{ + &sqlparser.AliasedExpr{ + Expr: expr, + }, + }, + } + +} + +func needsWeightString(tbl semantics.TableInfo, colName *sqlparser.ColName) bool { + for _, c := range tbl.GetColumns() { + if colName.Name.String() == c.Name { + return !sqltypes.IsNumber(c.Type) } + } + return true // we didn't find the column. better to add just to be safe1 +} +func wrapExprAndPush(exp sqlparser.Expr, plan logicalPlan, semTable *semantics.SemTable) (int, error) { + aliasedExpr := &sqlparser.AliasedExpr{Expr: exp} + offset, _, err := pushProjection(aliasedExpr, plan, semTable, true) + return offset, err +} + +func planOrderByForJoin(qp *queryProjection, orderExprs []orderBy, plan *joinGen4, semTable *semantics.SemTable) (logicalPlan, error) { + if allLeft(orderExprs, semTable, plan.Left.ContainsTables()) { + newLeft, err := planOrderBy(qp, orderExprs, plan.Left, semTable) + if err != nil { + return nil, err + } + plan.Left = newLeft return plan, nil - default: - return nil, semantics.Gen4NotSupportedF("ordering on complex query") } + + primitive := &engine.MemorySort{} + ms := &memorySort{ + resultsBuilder: resultsBuilder{ + logicalPlanCommon: newBuilderCommon(plan), + weightStrings: make(map[*resultColumn]int), + truncater: primitive, + }, + eMemorySort: primitive, + } + + for _, order := range orderExprs { + expr := order.inner.Expr + offset, err := wrapExprAndPush(expr, plan, semTable) + if err != nil { + return nil, err + } + + table := semTable.Dependencies(expr) + tbl, err := semTable.TableInfoFor(table) + if err != nil { + return nil, err + } + col, isCol := expr.(*sqlparser.ColName) + if !isCol { + return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "order by complex expression not supported") + } + + weightStringOffset := -1 + if needsWeightString(tbl, col) { + weightStringOffset, err = wrapExprAndPush(weightStringFor(order.weightStrExpr), plan, semTable) + if err != nil { + return nil, err + } + } + + ms.eMemorySort.OrderBy = append(ms.eMemorySort.OrderBy, engine.OrderbyParams{ + Col: offset, + WeightStringCol: weightStringOffset, + Desc: order.inner.Direction == sqlparser.DescOrder, + StarColFixedIndex: offset, + }) + } + + return ms, nil + +} + +func allLeft(orderExprs []orderBy, semTable *semantics.SemTable, lhsTables semantics.TableSet) bool { + for _, expr := range orderExprs { + exprDependencies := semTable.Dependencies(expr.inner.Expr) + if !exprDependencies.IsSolvedBy(lhsTables) { + return false + } + } + return true } diff --git a/go/vt/vtgate/planbuilder/symtab.go b/go/vt/vtgate/planbuilder/symtab.go index 658f52d5c9e..5c5c9c67b14 100644 --- a/go/vt/vtgate/planbuilder/symtab.go +++ b/go/vt/vtgate/planbuilder/symtab.go @@ -379,7 +379,7 @@ func (st *symtab) searchTables(col *sqlparser.ColName) (*column, error) { // ResultFromNumber returns the result column index based on the column // order expression. -func ResultFromNumber(rcs []*resultColumn, val *sqlparser.Literal) (int, error) { +func ResultFromNumber(rcs []*resultColumn, val *sqlparser.Literal, caller string) (int, error) { if val.Type != sqlparser.IntVal { return 0, errors.New("column number is not an int") } @@ -388,7 +388,7 @@ func ResultFromNumber(rcs []*resultColumn, val *sqlparser.Literal) (int, error) return 0, fmt.Errorf("error parsing column number: %s", sqlparser.String(val)) } if num < 1 || num > int64(len(rcs)) { - return 0, fmt.Errorf("column number out of range: %d", num) + return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in '%s'", num, caller) } return int(num - 1), nil } diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt index b109137932f..343ddb4d82f 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.txt @@ -602,7 +602,7 @@ Gen4 plan same as above ] } } -Gen4 plan same as above +"gen4 does not yet support: aggregation and non-aggregation expressions, together are not supported in cross-shard query" # scatter aggregate using distinct "select distinct col from user" @@ -995,7 +995,7 @@ Gen4 plan same as above # scatter aggregate group by invalid column number "select col from user group by 2" -"column number out of range: 2" +"Unknown column '2' in 'group statement'" # scatter aggregate order by null "select count(*) from user order by null" @@ -1168,7 +1168,8 @@ Gen4 plan same as above # invalid order by column numner for scatter "select col, count(*) from user group by col order by 5 limit 10" -"column number out of range: 5" +"Unknown column '5' in 'order clause'" +Gen4 plan same as above # aggregate with limit "select col, count(*) from user group by col limit 10" @@ -1337,7 +1338,7 @@ Gen4 plan same as above # Group by out of range column number (code is duplicated from symab). "select id from user group by 2" -"column number out of range: 2" +"Unknown column '2' in 'group statement'" # syntax error detected by planbuilder "select count(distinct *) from user" diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.txt b/go/vt/vtgate/planbuilder/testdata/from_cases.txt index 3a0e1b308de..d80acba2589 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.txt @@ -1755,6 +1755,29 @@ Gen4 plan same as above } } +# recursive derived table lookups +"select id from (select id from (select id from user) as u) as t where id = 5" +{ + "QueryType": "SELECT", + "Original": "select id from (select id from (select id from user) as u) as t where id = 5", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id from (select id from (select id from `user` where 1 != 1) as u where 1 != 1) as t where 1 != 1", + "Query": "select id from (select id from (select id from `user`) as u) as t where id = 5", + "Table": "`user`", + "Values": [ + 5 + ], + "Vindex": "user_index" + } +} + + # merge subqueries with single-shard routes "select u.col, e.col from (select col from user where id = 5) as u join (select col from user_extra where user_id = 5) as e" { diff --git a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt index 5b256260111..dd16a534f33 100644 --- a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.txt @@ -319,6 +319,55 @@ ] } } +{ + "QueryType": "SELECT", + "Original": "select user.col1 as a, user.col2 b, music.col3 c from user, music where user.id = music.id and user.id = 1 order by c", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "2 ASC", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-2,-3,1,2", + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.col1 as a, `user`.col2 as b from `user` where 1 != 1", + "Query": "select `user`.id, `user`.col1 as a, `user`.col2 as b from `user` where `user`.id = 1", + "Table": "`user`", + "Values": [ + 1 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.col3 as c, weight_string(music.col3) from music where 1 != 1", + "Query": "select music.col3 as c, weight_string(music.col3) from music where music.id = :user_id", + "Table": "music", + "Values": [ + ":user_id" + ], + "Vindex": "music_user_map" + } + ] + } + ] + } +} # Order by for join, with mixed cross-shard ordering "select user.col1 as a, user.col2, music.col3 from user join music on user.id = music.id where user.id = 1 order by 1 asc, 3 desc, 2 asc" @@ -371,6 +420,55 @@ ] } } +{ + "QueryType": "SELECT", + "Original": "select user.col1 as a, user.col2, music.col3 from user join music on user.id = music.id where user.id = 1 order by 1 asc, 3 desc, 2 asc", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "0 ASC, 2 DESC, 1 ASC", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-2,-3,1,-4,2,-5", + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.col1 as a, `user`.col2, weight_string(`user`.col1), weight_string(`user`.col2) from `user` where 1 != 1", + "Query": "select `user`.id, `user`.col1 as a, `user`.col2, weight_string(`user`.col1), weight_string(`user`.col2) from `user` where `user`.id = 1", + "Table": "`user`", + "Values": [ + 1 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.col3, weight_string(music.col3) from music where 1 != 1", + "Query": "select music.col3, weight_string(music.col3) from music where music.id = :user_id", + "Table": "music", + "Values": [ + ":user_id" + ], + "Vindex": "music_user_map" + } + ] + } + ] + } +} # Order by for join, on text column in LHS. "select u.a, u.textcol1, un.col2 from user u join unsharded un order by u.textcol1, un.col2" @@ -415,6 +513,7 @@ ] } } +Gen4 plan same as above # Order by for join, on text column in RHS. "select u.a, u.textcol1, un.col2 from unsharded un join user u order by u.textcol1, un.col2" @@ -459,6 +558,7 @@ ] } } +Gen4 plan same as above # order by for vindex func "select id, keyspace_id, range_start, range_end from user_index where id = :id order by range_start" diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt index e204b6fea3d..90720212366 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.txt @@ -299,6 +299,23 @@ Gen4 plan same as above "Table": "authoritative" } } +{ + "QueryType": "SELECT", + "Original": "select * from authoritative order by user_id", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select authoritative.user_id as user_id, authoritative.col1 as col1, authoritative.col2 as col2, weight_string(authoritative.user_id) from authoritative where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select authoritative.user_id as user_id, authoritative.col1 as col1, authoritative.col2 as col2, weight_string(authoritative.user_id) from authoritative order by user_id asc", + "ResultColumns": 3, + "Table": "authoritative" + } +} # ORDER BY works for select * from authoritative table "select * from authoritative order by col1" @@ -319,6 +336,23 @@ Gen4 plan same as above "Table": "authoritative" } } +{ + "QueryType": "SELECT", + "Original": "select * from authoritative order by col1", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select authoritative.user_id as user_id, authoritative.col1 as col1, authoritative.col2 as col2, weight_string(authoritative.col1) from authoritative where 1 != 1", + "OrderBy": "1 ASC", + "Query": "select authoritative.user_id as user_id, authoritative.col1 as col1, authoritative.col2 as col2, weight_string(authoritative.col1) from authoritative order by col1 asc", + "ResultColumns": 3, + "Table": "authoritative" + } +} # ORDER BY on scatter with text column "select a, textcol1, b from user order by a, textcol1, b" @@ -341,7 +375,7 @@ Gen4 plan same as above } Gen4 plan same as above -# ORDER BY on scatter with text column, qualified name +# ORDER BY on scatter with text column, qualified name TODO: can plan better "select a, user.textcol1, b from user order by a, textcol1, b" { "QueryType": "SELECT", @@ -360,6 +394,23 @@ Gen4 plan same as above "Table": "`user`" } } +{ + "QueryType": "SELECT", + "Original": "select a, user.textcol1, b from user order by a, textcol1, b", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select a, `user`.textcol1, b, weight_string(a), textcol1, weight_string(textcol1), weight_string(b) from `user` where 1 != 1", + "OrderBy": "0 ASC, 4 ASC, 2 ASC", + "Query": "select a, `user`.textcol1, b, weight_string(a), textcol1, weight_string(textcol1), weight_string(b) from `user` order by a asc, textcol1 asc, b asc", + "ResultColumns": 3, + "Table": "`user`" + } +} # ORDER BY on scatter with multiple text columns "select a, textcol1, b, textcol2 from user order by a, textcol1, b, textcol2" @@ -384,7 +435,45 @@ Gen4 plan same as above # ORDER BY invalid col number on scatter "select col from user order by 2" -"column number out of range: 2" +"Unknown column '2' in 'order clause'" +Gen4 plan same as above + +# ORDER BY column offset +"select id as foo from music order by 1" +{ + "QueryType": "SELECT", + "Original": "select id as foo from music order by 1", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id as foo, weight_string(id) from music where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select id as foo, weight_string(id) from music order by 1 asc", + "ResultColumns": 1, + "Table": "music" + } +} +{ + "QueryType": "SELECT", + "Original": "select id as foo from music order by 1", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id as foo, weight_string(id) from music where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select id as foo, weight_string(id) from music order by foo asc", + "ResultColumns": 1, + "Table": "music" + } +} # ORDER BY NULL "select col from user order by null" @@ -575,6 +664,50 @@ Gen4 plan same as above ] } } +{ + "QueryType": "SELECT", + "Original": "select user.col1 as a, user.col2, music.col3 from user, music where user.id = music.id and user.id = 1 order by a", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-2,-3,1", + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.col1 as a, `user`.col2, weight_string(`user`.col1) from `user` where 1 != 1", + "OrderBy": "1 ASC", + "Query": "select `user`.id, `user`.col1 as a, `user`.col2, weight_string(`user`.col1) from `user` where `user`.id = 1 order by a asc", + "ResultColumns": 3, + "Table": "`user`", + "Values": [ + 1 + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "SelectEqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.col3 from music where 1 != 1", + "Query": "select music.col3 from music where music.id = :user_id", + "Table": "music", + "Values": [ + ":user_id" + ], + "Vindex": "music_user_map" + } + ] + } +} # ORDER BY NULL after pull-out subquery "select col from user where col in (select col2 from user) order by null" @@ -830,7 +963,8 @@ Gen4 plan same as above # Order by, out of range column number "select col from user order by 2" -"column number out of range: 2" +"Unknown column '2' in 'order clause'" +Gen4 plan same as above # Order by, '*' expression with qualified reference and using collate "select * from user where id = 5 order by user.col collate utf8_general_ci" @@ -1035,6 +1169,7 @@ Gen4 plan same as above ] } } +Gen4 plan same as above # limit for scatter "select col from user limit 1" @@ -1174,3 +1309,210 @@ Gen4 plan same as above "select id from user limit 1+1" "unexpected expression in LIMIT: expression is too complex '1 + 1'" Gen4 plan same as above + +# order by column alias +"select id as foo from music order by foo" +{ + "QueryType": "SELECT", + "Original": "select id as foo from music order by foo", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id as foo, weight_string(id) from music where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select id as foo, weight_string(id) from music order by foo asc", + "ResultColumns": 1, + "Table": "music" + } +} +Gen4 plan same as above + +# column alias for a table column in order by +"select id as foo, id2 as id from music order by id" +{ + "QueryType": "SELECT", + "Original": "select id as foo, id2 as id from music order by id", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id as foo, id2 as id, weight_string(id2) from music where 1 != 1", + "OrderBy": "1 ASC", + "Query": "select id as foo, id2 as id, weight_string(id2) from music order by id asc", + "ResultColumns": 2, + "Table": "music" + } +} +Gen4 plan same as above + +# ordering on the left side of the join +"select name from user, music order by name" +{ + "QueryType": "SELECT", + "Original": "select name from user, music order by name", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1", + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `name`, weight_string(`name`) from `user` where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select `name`, weight_string(`name`) from `user` order by `name` asc", + "ResultColumns": 1, + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from music where 1 != 1", + "Query": "select 1 from music", + "Table": "music" + } + ] + } +} +Gen4 plan same as above + +# aggregation and non-aggregations column without group by +"select count(id), num from user" +{ + "QueryType": "SELECT", + "Original": "select count(id), num from user", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count(0)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(id), num from `user` where 1 != 1", + "Query": "select count(id), num from `user`", + "Table": "`user`" + } + ] + } +} +"gen4 does not yet support: aggregation and non-aggregation expressions, together are not supported in cross-shard query" + +# aggregation and non-aggregations column with order by +"select count(id), num from user order by 2" +{ + "QueryType": "SELECT", + "Original": "select count(id), num from user order by 2", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "1 ASC", + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count(0)", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(id), num, weight_string(num) from `user` where 1 != 1", + "Query": "select count(id), num, weight_string(num) from `user`", + "Table": "`user`" + } + ] + } + ] + } +} +"gen4 does not yet support: aggregation and non-aggregation expressions, together are not supported in cross-shard query" + +# aggregation and non-aggregations column with group by +"select count(id), num from user group by 2" +{ + "QueryType": "SELECT", + "Original": "select count(id), num from user group by 2", + "Instructions": { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count(0)", + "GroupBy": "1", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(id), num, weight_string(num) from `user` where 1 != 1 group by 2", + "OrderBy": "1 ASC", + "Query": "select count(id), num, weight_string(num) from `user` group by 2 order by num asc", + "ResultColumns": 2, + "Table": "`user`" + } + ] + } +} +"gen4 does not yet support: aggregation and non-aggregation expressions, together are not supported in cross-shard query" + +# aggregation and non-aggregations column with group by and order by +"select count(id), num from user group by 2 order by 1" +{ + "QueryType": "SELECT", + "Original": "select count(id), num from user group by 2 order by 1", + "Instructions": { + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "0 ASC", + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "count(0)", + "GroupBy": "1", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select count(id), num, weight_string(num) from `user` where 1 != 1 group by 2", + "OrderBy": "1 ASC", + "Query": "select count(id), num, weight_string(num) from `user` group by 2 order by num asc", + "ResultColumns": 2, + "Table": "`user`" + } + ] + } + ] + } +} +"gen4 does not yet support: aggregation and non-aggregation expressions, together are not supported in cross-shard query" + diff --git a/go/vt/vtgate/planbuilder/testdata/show_cases_no_default_keyspace.txt b/go/vt/vtgate/planbuilder/testdata/show_cases_no_default_keyspace.txt index 4e60edf2363..be60be9680b 100644 --- a/go/vt/vtgate/planbuilder/testdata/show_cases_no_default_keyspace.txt +++ b/go/vt/vtgate/planbuilder/testdata/show_cases_no_default_keyspace.txt @@ -14,6 +14,7 @@ "SingleShardOnly": true } } +Gen4 plan same as above # show columns from routed table "show full fields from `route1`" @@ -31,6 +32,7 @@ "SingleShardOnly": true } } +Gen4 plan same as above # show variables "show variables" @@ -53,6 +55,7 @@ ] } } +Gen4 plan same as above # show full columns from system schema "show full columns from sys.sys_config" @@ -70,6 +73,7 @@ "SingleShardOnly": true } } +Gen4 plan same as above # show full columns from system schema replacing qualifier "show full columns from x.sys_config from sys" @@ -87,3 +91,4 @@ "SingleShardOnly": true } } +Gen4 plan same as above diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt index 19e52bc5890..26349d4c255 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.txt @@ -104,6 +104,7 @@ Gen4 plan same as above # Multi-value aggregates not supported "select count(a,b) from user" "unsupported: only one expression allowed inside aggregates: count(a, b)" +"aggregate functions take a single argument 'count(a, b)'" # Cannot have more than one aggr(distinct... "select count(distinct a), count(distinct b) from user" @@ -120,6 +121,23 @@ Gen4 plan same as above # order by with ambiguous column reference ; valid in MySQL "select id, id from user order by id" "ambiguous symbol reference: id" +{ + "QueryType": "SELECT", + "Original": "select id, id from user order by id", + "Instructions": { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", + "OrderBy": "0 ASC", + "Query": "select id, weight_string(id) from `user` order by id asc", + "ResultColumns": 1, + "Table": "`user`" + } +} # scatter aggregate with ambiguous aliases "select distinct a, b as a from user" @@ -446,6 +464,7 @@ Gen4 plan same as above # create view with Cannot auto-resolve for cross-shard joins "create view user.view_a as select col from user join user_extra" "symbol col not found" +"Complex select queries are not supported in create or alter view statements" # create view with join that cannot be served in each shard separately "create view user.view_a as select user_extra.id from user join user_extra" diff --git a/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt b/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt index 80ca3f3367a..0671c868a24 100644 --- a/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/wireup_cases.txt @@ -893,7 +893,7 @@ "Sharded": true }, "FieldQuery": "select u.col, u.id from `user` as u where 1 != 1", - "Query": "select u.col, u.id from `user` as u limit :__upper_limit", + "Query": "select u.col, u.id from `user` as u", "Table": "`user`" }, { @@ -904,7 +904,7 @@ "Sharded": true }, "FieldQuery": "select e.id from user_extra as e where 1 != 1", - "Query": "select e.id from user_extra as e where e.id = :u_col limit :__upper_limit", + "Query": "select e.id from user_extra as e where e.id = :u_col", "Table": "user_extra" } ] @@ -912,7 +912,6 @@ ] } } - # Wire-up in subquery "select 1 from user where id in (select u.id, e.id from user u join user_extra e where e.id = u.col limit 10)" { diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 0f0c7b49815..7eb9e4321fb 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -18,6 +18,9 @@ package semantics import ( "fmt" + "strconv" + + "vitess.io/vitess/go/vt/vtgate/vindexes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" @@ -29,25 +32,27 @@ type ( analyzer struct { si SchemaInformation - Tables []*TableInfo + Tables []TableInfo scopes []*scope - exprDeps map[sqlparser.Expr]TableSet + exprDeps ExprDependencies err error currentDb string inProjection []bool - selectScope map[*sqlparser.Select]*scope - projErr error + rScope map[*sqlparser.Select]*scope + wScope map[*sqlparser.Select]*scope + projErr error } ) // newAnalyzer create the semantic analyzer func newAnalyzer(dbName string, si SchemaInformation) *analyzer { return &analyzer{ - exprDeps: map[sqlparser.Expr]TableSet{}, - selectScope: map[*sqlparser.Select]*scope{}, - currentDb: dbName, - si: si, + exprDeps: map[sqlparser.Expr]TableSet{}, + rScope: map[*sqlparser.Select]*scope{}, + wScope: map[*sqlparser.Select]*scope{}, + currentDb: dbName, + si: si, } } @@ -59,7 +64,7 @@ func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformati if err != nil { return nil, err } - return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.selectScope, ProjectionErr: analyzer.projErr}, nil + return &SemTable{exprDependencies: analyzer.exprDeps, Tables: analyzer.Tables, selectScope: analyzer.rScope, ProjectionErr: analyzer.projErr}, nil } func (a *analyzer) setError(err error) { @@ -81,17 +86,19 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { current := a.currentScope() n := cursor.Node() switch node := n.(type) { - case sqlparser.SelectExprs: - if isParentSelect(cursor) { - a.inProjection = append(a.inProjection, true) - fmt.Println(len(a.inProjection)) - } case *sqlparser.Select: if node.Having != nil { a.setError(Gen4NotSupportedF("HAVING")) } - a.push(newScope(current)) - a.selectScope[node] = a.currentScope() + + currScope := newScope(current) + a.push(currScope) + + // Needed for order by with Literal to find the Expression. + currScope.selectExprs = node.SelectExprs + + a.rScope[node] = currScope + a.wScope[node] = newScope(nil) case *sqlparser.DerivedTable: a.setError(Gen4NotSupportedF("derived tables")) case *sqlparser.Subquery: @@ -114,6 +121,70 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { } case *sqlparser.Union: a.push(newScope(current)) + case sqlparser.SelectExprs: + if isParentSelect(cursor) { + a.inProjection = append(a.inProjection, true) + } + sel, ok := cursor.Parent().(*sqlparser.Select) + if !ok { + break + } + + wScope, exists := a.wScope[sel] + if !exists { + break + } + + vTbl := &vTableInfo{} + for _, selectExpr := range node { + expr, ok := selectExpr.(*sqlparser.AliasedExpr) + if !ok { + continue + } + vTbl.cols = append(vTbl.cols, expr.Expr) + if !expr.As.IsEmpty() { + vTbl.columnNames = append(vTbl.columnNames, expr.As.String()) + } else { + vTbl.columnNames = append(vTbl.columnNames, sqlparser.String(expr)) + } + } + wScope.tables = append(wScope.tables, vTbl) + case sqlparser.OrderBy: + a.changeScopeForOrderBy(cursor) + case *sqlparser.Order: + l, ok := node.Expr.(*sqlparser.Literal) + if !ok { + break + } + if l.Type != sqlparser.IntVal { + break + } + currScope := a.currentScope() + num, err := strconv.Atoi(l.Val) + if err != nil { + a.err = err + break + } + if num < 1 || num > len(currScope.selectExprs) { + a.err = vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown column '%d' in 'order clause'", num) + break + } + + expr, ok := currScope.selectExprs[num-1].(*sqlparser.AliasedExpr) + if !ok { + break + } + + var deps TableSet + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + expr, ok := node.(sqlparser.Expr) + if ok { + deps = deps.Merge(a.exprDeps[expr]) + } + return true, nil + }, expr.Expr) + + a.exprDeps[node.Expr] = deps case *sqlparser.ColName: t, err := a.resolveColumn(node, current) if err != nil { @@ -141,37 +212,50 @@ func (a *analyzer) analyzeDown(cursor *sqlparser.Cursor) bool { return true } +func (a *analyzer) changeScopeForOrderBy(cursor *sqlparser.Cursor) { + sel, ok := cursor.Parent().(*sqlparser.Select) + if !ok { + return + } + // In ORDER BY, we can see both the scope in the FROM part of the query, and the SELECT columns created + // so before walking the rest of the tree, we change the scope to match this behaviour + incomingScope := a.currentScope() + nScope := newScope(incomingScope) + a.push(nScope) + wScope := a.wScope[sel] + nScope.tables = append(nScope.tables, wScope.tables...) + nScope.selectExprs = incomingScope.selectExprs + + if a.rScope[sel] != incomingScope { + panic("BUG: scope counts did not match") + } +} + func isParentSelect(cursor *sqlparser.Cursor) bool { _, isSelect := cursor.Parent().(*sqlparser.Select) return isSelect } func (a *analyzer) resolveColumn(colName *sqlparser.ColName, current *scope) (TableSet, error) { - var t *TableInfo - var err error if colName.Qualifier.IsEmpty() { - t, err = a.resolveUnQualifiedColumn(current, colName) - } else { - t, err = a.resolveQualifiedColumn(current, colName) + return a.resolveUnQualifiedColumn(current, colName) } + t, err := a.resolveQualifiedColumn(current, colName) if err != nil { return 0, err } if t == nil { return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(colName))) } - return a.tableSetFor(t.ASTNode), nil + return a.tableSetFor(t.GetExpr()), nil } // resolveQualifiedColumn handles `tabl.col` expressions -func (a *analyzer) resolveQualifiedColumn(current *scope, expr *sqlparser.ColName) (*TableInfo, error) { +func (a *analyzer) resolveQualifiedColumn(current *scope, expr *sqlparser.ColName) (TableInfo, error) { // search up the scope stack until we find a match for current != nil { - dbName := expr.Qualifier.Qualifier.String() - tableName := expr.Qualifier.Name.String() for _, table := range current.tables { - if tableName == table.tableName && - (dbName == table.dbName || (dbName == "" && (table.dbName == a.currentDb || a.currentDb == ""))) { + if table.Matches(expr.Qualifier) { return table, nil } } @@ -180,40 +264,62 @@ func (a *analyzer) resolveQualifiedColumn(current *scope, expr *sqlparser.ColNam return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "symbol %s not found", sqlparser.String(expr)) } -// resolveUnQualifiedColumn -func (a *analyzer) resolveUnQualifiedColumn(current *scope, expr *sqlparser.ColName) (*TableInfo, error) { - if len(current.tables) == 1 { - for _, tableExpr := range current.tables { - return tableExpr, nil - } - } +type originable interface { + tableSetFor(t *sqlparser.AliasedTableExpr) TableSet + depsForExpr(expr sqlparser.Expr) TableSet +} - var tblInfo *TableInfo +func (a *analyzer) depsForExpr(expr sqlparser.Expr) TableSet { + return a.exprDeps.Dependencies(expr) +} + +// resolveUnQualifiedColumn +func (a *analyzer) resolveUnQualifiedColumn(current *scope, expr *sqlparser.ColName) (TableSet, error) { + var tsp *TableSet for _, tbl := range current.tables { - if tbl.Table == nil { - return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr))) + ts := tbl.DepsFor(expr, a, len(current.tables) == 1) + if ts != nil && tsp != nil { + return 0, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr))) } - for _, col := range tbl.Table.Columns { - if expr.Name.Equal(col.Name) { - if tblInfo != nil { - return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqError, fmt.Sprintf("Column '%s' in field list is ambiguous", sqlparser.String(expr))) - } - tblInfo = tbl - } + if ts != nil { + tsp = ts } } - return tblInfo, nil + if tsp == nil { + return 0, nil + } + return *tsp, nil } func (a *analyzer) tableSetFor(t *sqlparser.AliasedTableExpr) TableSet { for i, t2 := range a.Tables { - if t == t2.ASTNode { + if t == t2.GetExpr() { return TableSet(1 << i) } } panic("unknown table") } +func (a *analyzer) createTable(t sqlparser.TableName, alias *sqlparser.AliasedTableExpr, tbl *vindexes.Table) TableInfo { + dbName := t.Qualifier.String() + if dbName == "" { + dbName = a.currentDb + } + if alias.As.IsEmpty() { + return &RealTable{ + dbName: dbName, + tableName: t.Name.String(), + ASTNode: alias, + Table: tbl, + } + } + return &AliasedTable{ + tableName: alias.As.String(), + ASTNode: alias, + Table: tbl, + } +} + func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.SimpleTableExpr) error { switch t := expr.(type) { case *sqlparser.DerivedTable: @@ -230,24 +336,10 @@ func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.S return Gen4NotSupportedF("vindex in FROM") } scope := a.currentScope() - dbName := t.Qualifier.String() - if dbName == "" { - dbName = a.currentDb - } - var tableName string - if alias.As.IsEmpty() { - tableName = t.Name.String() - } else { - tableName = alias.As.String() - } - table := &TableInfo{ - dbName: dbName, - tableName: tableName, - ASTNode: alias, - Table: tbl, - } - a.Tables = append(a.Tables, table) - return scope.addTable(table) + tableInfo := a.createTable(t, alias, tbl) + + a.Tables = append(a.Tables, tableInfo) + return scope.addTable(tableInfo) } return nil } @@ -266,7 +358,7 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { if isParentSelect(cursor) { a.popProjection() } - case *sqlparser.Union, *sqlparser.Select: + case *sqlparser.Union, *sqlparser.Select, sqlparser.OrderBy: a.popScope() case sqlparser.TableExpr: if isParentSelect(cursor) { @@ -283,6 +375,7 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { } } } + return a.shouldContinue() } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index ee1c2cad29b..8ad7712a9a8 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -31,13 +31,15 @@ import ( "vitess.io/vitess/go/vt/sqlparser" ) +const T0 TableSet = 0 + const ( // Just here to make outputs more readable - T0 TableSet = 1 << iota - T1 + T1 TableSet = 1 << iota T2 - _ // T3 is not used in the tests - T4 + T3 + _ // T4 is not used in the tests + T5 ) func extract(in *sqlparser.Select, idx int) sqlparser.Expr { @@ -59,7 +61,7 @@ from x as t` s1 := semTable.Dependencies(extract(sel2, 0)) // if scoping works as expected, we should be able to see the inner table being used by the inner expression - assert.Equal(t, T1, s1) + assert.Equal(t, T2, s1) } func TestBindingSingleTable(t *testing.T) { @@ -81,7 +83,7 @@ func TestBindingSingleTable(t *testing.T) { assert.EqualValues(t, 1, ts) d := semTable.Dependencies(extract(sel, 0)) - require.Equal(t, T0, d, query) + require.Equal(t, T1, d, query) }) } }) @@ -104,6 +106,43 @@ func TestBindingSingleTable(t *testing.T) { }) } +func TestOrderByBindingSingleTable(t *testing.T) { + t.Run("positive tests", func(t *testing.T) { + tcases := []struct { + sql string + deps TableSet + }{{ + "select col from tabl order by col", + T1, + }, { + "select col from tabl order by tabl.col", + T1, + }, { + + "select col from tabl order by d.tabl.col", + T1, + }, { + "select col from tabl order by 1", + T1, + }, { + "select col as c from tabl order by c", + T1, + }, { + "select 1 as c from tabl order by c", + T0, + }} + for _, tc := range tcases { + t.Run(tc.sql, func(t *testing.T) { + stmt, semTable := parseAndAnalyze(t, tc.sql, "d") + sel, _ := stmt.(*sqlparser.Select) + order := sel.OrderBy[0].Expr + d := semTable.Dependencies(order) + require.Equal(t, tc.deps, d, tc.sql) + }) + } + }) +} + func TestBindingSingleAliasedTable(t *testing.T) { t.Run("positive tests", func(t *testing.T) { queries := []string{ @@ -111,7 +150,6 @@ func TestBindingSingleAliasedTable(t *testing.T) { "select tabl.col from X as tabl", "select col from d.X as tabl", "select tabl.col from d.X as tabl", - "select d.tabl.col from d.X as tabl", } for _, query := range queries { t.Run(query, func(t *testing.T) { @@ -122,7 +160,7 @@ func TestBindingSingleAliasedTable(t *testing.T) { assert.EqualValues(t, 1, ts) d := semTable.Dependencies(extract(sel, 0)) - require.Equal(t, T0, d, query) + require.Equal(t, T1, d, query) }) } }) @@ -132,6 +170,7 @@ func TestBindingSingleAliasedTable(t *testing.T) { "select d.X.col from d.X as tabl", "select d.tabl.col from X as tabl", "select d.tabl.col from ks.X as tabl", + "select d.tabl.col from d.X as tabl", } for _, query := range queries { t.Run(query, func(t *testing.T) { @@ -165,8 +204,8 @@ func TestUnion(t *testing.T) { d1 := semTable.Dependencies(extract(sel1, 0)) d2 := semTable.Dependencies(extract(sel2, 0)) - assert.Equal(t, T0, d1) - assert.Equal(t, T1, d2) + assert.Equal(t, T1, d1) + assert.Equal(t, T2, d2) } func TestBindingMultiTable(t *testing.T) { @@ -178,44 +217,44 @@ func TestBindingMultiTable(t *testing.T) { } queries := []testCase{{ query: "select t.col from t, s", - deps: T0, + deps: T1, }, { query: "select s.col from t join s", - deps: T1, + deps: T2, }, { query: "select max(t.col+s.col) from t, s", - deps: T0 | T1, + deps: T1 | T2, }, { query: "select max(t.col+s.col) from t join s", - deps: T0 | T1, + deps: T1 | T2, }, { query: "select case t.col when s.col then r.col else u.col end from t, s, r, w, u", - deps: T0 | T1 | T2 | T4, + deps: T1 | T2 | T3 | T5, // }, { // // make sure that we don't let sub-query Dependencies leak out by mistake // query: "select t.col + (select 42 from s) from t", - // deps: T0, + // deps: T1, // }, { // query: "select (select 42 from s where r.id = s.id) from r", - // deps: T0 | T1, + // deps: T1 | T2, }, { query: "select X.col from t as X, s as S", - deps: T0, + deps: T1, }, { query: "select X.col+S.col from t as X, s as S", - deps: T0 | T1, + deps: T1 | T2, }, { query: "select max(X.col+S.col) from t as X, s as S", - deps: T0 | T1, + deps: T1 | T2, }, { query: "select max(X.col+s.col) from t as X, s", - deps: T0 | T1, + deps: T1 | T2, }, { query: "select b.t.col from b.t, t", - deps: T0, + deps: T1, }, { query: "select u1.a + u2.a from u1, u2", - deps: T0 | T1, + deps: T1 | T2, }} for _, query := range queries { t.Run(query.query, func(t *testing.T) { @@ -228,7 +267,6 @@ func TestBindingMultiTable(t *testing.T) { t.Run("negative tests", func(t *testing.T) { queries := []string{ - "select 1 from d.tabl, d.foo as tabl", "select 1 from d.tabl, d.tabl", "select 1 from d.tabl, tabl", "select 1 from user join user_extra user", @@ -256,7 +294,7 @@ func TestBindingSingleDepPerTable(t *testing.T) { d := semTable.Dependencies(extract(sel, 0)) assert.Equal(t, 1, d.NumberOfTables(), "size wrong") - assert.Equal(t, T0, d) + assert.Equal(t, T1, d) } func TestNotUniqueTableName(t *testing.T) { @@ -336,12 +374,12 @@ func TestUnknownColumnMap2(t *testing.T) { { name: "no info about tables", schema: map[string]*vindexes.Table{"a": {}, "b": {}}, - err: true, + err: false, }, { name: "non authoritative columns", schema: map[string]*vindexes.Table{"a": &nonAuthoritativeTblA, "b": &nonAuthoritativeTblA}, - err: true, + err: false, }, { name: "non authoritative columns - one authoritative and one not", @@ -398,7 +436,7 @@ func TestUnknownPredicate(t *testing.T) { { name: "no info about tables", schema: map[string]*vindexes.Table{"a": authoritativeTblA, "b": authoritativeTblB}, - err: true, + err: false, }, } for _, test := range tests { diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 5694f80607b..c2287c77846 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -18,6 +18,7 @@ package semantics import ( "vitess.io/vitess/go/vt/key" + querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -27,31 +28,63 @@ import ( ) type ( - // TableInfo contains the alias table expr and vindex table - TableInfo struct { + // TableInfo contains information about tables + TableInfo interface { + Matches(name sqlparser.TableName) bool + Authoritative() bool + Name() (sqlparser.TableName, error) + GetExpr() *sqlparser.AliasedTableExpr + GetColumns() []ColumnInfo + IsVirtual() bool + DepsFor(col *sqlparser.ColName, org originable, single bool) *TableSet + } + + // ColumnInfo contains information about columns + ColumnInfo struct { + Name string + Type querypb.Type + } + + // RealTable contains the alias table expr and vindex table + RealTable struct { dbName, tableName string ASTNode *sqlparser.AliasedTableExpr Table *vindexes.Table } + // AliasedTable contains the alias table expr and vindex table + AliasedTable struct { + tableName string + ASTNode *sqlparser.AliasedTableExpr + Table *vindexes.Table + } + + vTableInfo struct { + columnNames []string + cols []sqlparser.Expr + } + // TableSet is how a set of tables is expressed. // Tables get unique bits assigned in the order that they are encountered during semantic analysis TableSet uint64 // we can only join 64 tables with this underlying data type // TODO : change uint64 to struct to support arbitrary number of tables. + ExprDependencies map[sqlparser.Expr]TableSet + // SemTable contains semantic analysis information about the query. SemTable struct { - Tables []*TableInfo + Tables []TableInfo // ProjectionErr stores the error that we got during the semantic analysis of the SelectExprs. // This is only a real error if we are unable to plan the query as a single route ProjectionErr error - exprDependencies map[sqlparser.Expr]TableSet + exprDependencies ExprDependencies selectScope map[*sqlparser.Select]*scope } scope struct { - parent *scope - tables []*TableInfo + parent *scope + selectExprs sqlparser.SelectExprs + tables []TableInfo } // SchemaInformation is used tp provide table information from Vschema. @@ -60,15 +93,167 @@ type ( } ) +func (v *vTableInfo) DepsFor(col *sqlparser.ColName, org originable, single bool) *TableSet { + if !col.Qualifier.IsEmpty() { + return nil + } + for i, colName := range v.columnNames { + if col.Name.String() == colName { + ts := org.depsForExpr(v.cols[i]) + return &ts + } + } + return nil +} + +func (a *AliasedTable) DepsFor(col *sqlparser.ColName, org originable, single bool) *TableSet { + if single { + ts := org.tableSetFor(a.ASTNode) + return &ts + } + for _, info := range a.GetColumns() { + if col.Name.String() == info.Name { + ts := org.tableSetFor(a.ASTNode) + return &ts + } + } + return nil +} + +func (r *RealTable) DepsFor(col *sqlparser.ColName, org originable, single bool) *TableSet { + if single { + ts := org.tableSetFor(r.ASTNode) + return &ts + } + for _, info := range r.GetColumns() { + if col.Name.String() == info.Name { + ts := org.tableSetFor(r.ASTNode) + return &ts + } + } + return nil +} + +func (v *vTableInfo) IsVirtual() bool { + return true +} + +func (a *AliasedTable) IsVirtual() bool { + return false +} + +func (r *RealTable) IsVirtual() bool { + return false +} + +var _ TableInfo = (*RealTable)(nil) +var _ TableInfo = (*AliasedTable)(nil) +var _ TableInfo = (*vTableInfo)(nil) + +func (v *vTableInfo) Matches(name sqlparser.TableName) bool { + return false +} + +func (v *vTableInfo) Authoritative() bool { + return true +} + +func (v *vTableInfo) Name() (sqlparser.TableName, error) { + return sqlparser.TableName{}, nil +} + +func (v *vTableInfo) GetExpr() *sqlparser.AliasedTableExpr { + return nil +} + +func (v *vTableInfo) GetColumns() []ColumnInfo { + cols := make([]ColumnInfo, 0, len(v.columnNames)) + for _, col := range v.columnNames { + cols = append(cols, ColumnInfo{ + Name: col, + }) + } + return cols +} + +func vindexTableToColumnInfo(tbl *vindexes.Table) []ColumnInfo { + if tbl == nil { + return nil + } + cols := make([]ColumnInfo, 0, len(tbl.Columns)) + for _, col := range tbl.Columns { + cols = append(cols, ColumnInfo{ + Name: col.Name.String(), + Type: col.Type, + }) + } + return cols +} + +// GetColumns implements the TableInfo interface +func (a *AliasedTable) GetColumns() []ColumnInfo { + return vindexTableToColumnInfo(a.Table) +} + +// GetExpr implements the TableInfo interface +func (a *AliasedTable) GetExpr() *sqlparser.AliasedTableExpr { + return a.ASTNode +} + +// Name implements the TableInfo interface +func (a *AliasedTable) Name() (sqlparser.TableName, error) { + return a.ASTNode.TableName() +} + +// Authoritative implements the TableInfo interface +func (a *AliasedTable) Authoritative() bool { + return a.Table != nil && a.Table.ColumnListAuthoritative +} + +// Matches implements the TableInfo interface +func (a *AliasedTable) Matches(name sqlparser.TableName) bool { + return a.tableName == name.Name.String() && name.Qualifier.IsEmpty() +} + +// GetColumns implements the TableInfo interface +func (r *RealTable) GetColumns() []ColumnInfo { + return vindexTableToColumnInfo(r.Table) +} + +// GetExpr implements the TableInfo interface +func (r *RealTable) GetExpr() *sqlparser.AliasedTableExpr { + return r.ASTNode +} + +// Name implements the TableInfo interface +func (r *RealTable) Name() (sqlparser.TableName, error) { + return r.ASTNode.TableName() +} + +// Authoritative implements the TableInfo interface +func (r *RealTable) Authoritative() bool { + return r.Table != nil && r.Table.ColumnListAuthoritative +} + +// Matches implements the TableInfo interface +func (r *RealTable) Matches(name sqlparser.TableName) bool { + if !name.Qualifier.IsEmpty() { + if r.dbName != name.Qualifier.String() { + return false + } + } + return r.tableName == name.Name.String() +} + // NewSemTable creates a new empty SemTable func NewSemTable() *SemTable { return &SemTable{exprDependencies: map[sqlparser.Expr]TableSet{}} } -// TableSetFor returns the bitmask for this particular tableshoe +// TableSetFor returns the bitmask for this particular table func (st *SemTable) TableSetFor(t *sqlparser.AliasedTableExpr) TableSet { for idx, t2 := range st.Tables { - if t == t2.ASTNode { + if t == t2.GetExpr() { return 1 << idx } } @@ -76,7 +261,7 @@ func (st *SemTable) TableSetFor(t *sqlparser.AliasedTableExpr) TableSet { } // TableInfoFor returns the table info for the table set. It should contains only single table. -func (st *SemTable) TableInfoFor(id TableSet) (*TableInfo, error) { +func (st *SemTable) TableInfoFor(id TableSet) (TableInfo, error) { if id.NumberOfTables() > 1 { return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] should only be used for single tables") } @@ -85,7 +270,11 @@ func (st *SemTable) TableInfoFor(id TableSet) (*TableInfo, error) { // Dependencies return the table dependencies of the expression. func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet { - deps, found := st.exprDependencies[expr] + return st.exprDependencies.Dependencies(expr) +} + +func (d ExprDependencies) Dependencies(expr sqlparser.Expr) TableSet { + deps, found := d[expr] if found { return deps } @@ -93,19 +282,17 @@ func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet { _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { colName, ok := node.(*sqlparser.ColName) if ok { - set := st.exprDependencies[colName] + set := d[colName] deps |= set } return true, nil }, expr) - - st.exprDependencies[expr] = deps - + d[expr] = deps return deps } // GetSelectTables returns the table in the select. -func (st *SemTable) GetSelectTables(node *sqlparser.Select) []*TableInfo { +func (st *SemTable) GetSelectTables(node *sqlparser.Select) []TableInfo { scope := st.selectScope[node] return scope.tables } @@ -122,16 +309,17 @@ func newScope(parent *scope) *scope { return &scope{parent: parent} } -func (s *scope) addTable(table *TableInfo) error { +func (s *scope) addTable(info TableInfo) error { for _, scopeTable := range s.tables { - b := scopeTable.tableName == table.tableName - b2 := scopeTable.dbName == table.dbName - if b && b2 { - return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqTable, "Not unique table/alias: '%s'", table.tableName) + scopeTableName, err := scopeTable.Name() + if err != nil { + return err + } + if info.Matches(scopeTableName) { + return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqTable, "Not unique table/alias: '%s'", scopeTableName.Name.String()) } } - - s.tables = append(s.tables, table) + s.tables = append(s.tables, info) return nil }