From 481e317a36fa6f1215aaa70a801c99486de453c1 Mon Sep 17 00:00:00 2001 From: Florent Poinsard Date: Thu, 15 Apr 2021 10:25:30 +0200 Subject: [PATCH] Implemented table aliases in semantic analysis Signed-off-by: Florent Poinsard --- go/vt/vtgate/planbuilder/querygraph_test.go | 4 +- go/vt/vtgate/planbuilder/route_planning.go | 2 +- go/vt/vtgate/semantics/analyzer.go | 57 ++++++++++----------- go/vt/vtgate/semantics/analyzer_test.go | 26 +++++----- go/vt/vtgate/semantics/semantic_state.go | 31 +++++------ 5 files changed, 56 insertions(+), 64 deletions(-) diff --git a/go/vt/vtgate/planbuilder/querygraph_test.go b/go/vt/vtgate/planbuilder/querygraph_test.go index 551ec9bb033..76a6052756f 100644 --- a/go/vt/vtgate/planbuilder/querygraph_test.go +++ b/go/vt/vtgate/planbuilder/querygraph_test.go @@ -123,7 +123,7 @@ func TestQueryGraph(t *testing.T) { t.Run(fmt.Sprintf("%d %s", i, sql), func(t *testing.T) { tree, err := sqlparser.Parse(sql) require.NoError(t, err) - semTable, err := semantics.Analyse(tree, "", &fakeSI{}) + semTable, err := semantics.Analyze(tree, "", &fakeSI{}) require.NoError(t, err) qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable) require.NoError(t, err) @@ -136,7 +136,7 @@ func TestQueryGraph(t *testing.T) { func TestString(t *testing.T) { tree, err := sqlparser.Parse("select * from a,b join c on b.id = c.id where a.id = b.id and b.col IN (select 42) and func() = 'foo'") require.NoError(t, err) - semTable, err := semantics.Analyse(tree, "", &fakeSI{}) + semTable, err := semantics.Analyze(tree, "", &fakeSI{}) require.NoError(t, err) qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable) require.NoError(t, err) diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index 23a913dfdf2..956ec6de241 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -50,7 +50,7 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P if err != nil { return nil, err } - semTable, err := semantics.Analyse(sel, keyspace.Name, vschema) + semTable, err := semantics.Analyze(sel, keyspace.Name, vschema) if err != nil { return nil, err } diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index fcd35342f9f..6932ac6759a 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -46,8 +46,8 @@ func newAnalyzer(dbName string, si SchemaInformation) *analyzer { } } -// Analyse analyzes the parsed query. -func Analyse(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) { +// Analyze analyzes the parsed query. +func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) { analyzer := newAnalyzer(currentDb, si) // Initial scope err := analyzer.analyze(statement) @@ -135,31 +135,18 @@ func (a *analyzer) analyzeTableExpr(tableExpr sqlparser.TableExpr) error { // resolveQualifiedColumn handles `tabl.col` expressions func (a *analyzer) resolveQualifiedColumn(current *scope, expr *sqlparser.ColName) (*TableInfo, error) { - id := tableID{ - dbName: expr.Qualifier.Qualifier.String(), - tableName: expr.Qualifier.Name.String(), - } - id2 := tableID{ - dbName: a.currentDb, - tableName: expr.Qualifier.Name.String(), - } - checkCurrentDB := id.dbName == "" && id != id2 - // search up the scope stack until we find a match for current != nil { - tableExpr, found := current.tables[id] - if found { - return tableExpr, nil - } - if checkCurrentDB { - tableExpr, found := current.tables[id2] - if found { - return tableExpr, nil + dbName := expr.Qualifier.Qualifier.String() + tableName := expr.Qualifier.Name.String() + for _, table := range current.tables { + if tableName == table.tableName && + (dbName == table.dbName || (dbName == "" && (table.dbName == a.currentDb || a.currentDb == ""))) { + return table, nil } } current = current.parent } - return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown table referenced by '%s'", sqlparser.String(expr)) } @@ -206,8 +193,8 @@ func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.S } a.popScope() scope := a.currentScope() - dbName := "" // derived tables are always referenced only by their alias - they cannot be found using a fully qualified name - return scope.addTable(dbName, alias.As.String(), &TableInfo{alias, nil}) + //dbName := "" // derived tables are always referenced only by their alias - they cannot be found using a fully qualified name + return scope.addTable(&TableInfo{}) case sqlparser.TableName: tbl, vdx, _, _, _, err := a.si.FindTableOrVindex(t) if err != nil { @@ -217,16 +204,24 @@ func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.S return Gen4NotSupportedF("vindex in FROM") } scope := a.currentScope() - table := &TableInfo{alias, tbl} - a.Tables = append(a.Tables, table) + dbName := t.Qualifier.String() + if dbName == "" { + dbName = a.currentDb + } + var tableName string if alias.As.IsEmpty() { - dbName := t.Qualifier.String() - if dbName == "" { - dbName = a.currentDb - } - return scope.addTable(dbName, t.Name.String(), table) + tableName = t.Name.String() + } else { + tableName = alias.As.String() + } + table := &TableInfo{ + dbName: dbName, + tableName: tableName, + ASTNode: alias, + Table: tbl, } - return scope.addTable("", alias.As.String(), table) + a.Tables = append(a.Tables, table) + return scope.addTable(table) } return nil } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index b662f876965..cc4d31dcf36 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -73,6 +73,7 @@ func TestBindingSingleTable(t *testing.T) { "select col from d.tabl", "select tabl.col from d.tabl", "select d.tabl.col from d.tabl", + "select d.tabl.col from X as tabl", } for _, query := range queries { t.Run(query, func(t *testing.T) { @@ -90,7 +91,6 @@ func TestBindingSingleTable(t *testing.T) { t.Run("negative tests", func(t *testing.T) { queries := []string{ "select foo.col from tabl", - "select d.tabl.col from X as tabl", "select tabl.col from d.tabl as X", "select d.tabl.col from d.tabl as X", } @@ -98,7 +98,7 @@ func TestBindingSingleTable(t *testing.T) { t.Run(query, func(t *testing.T) { parse, err := sqlparser.Parse(query) require.NoError(t, err) - _, err = Analyse(parse, "d", &fakeSI{}) + _, err = Analyze(parse, "d", &fakeSI{}) require.Error(t, err) }) } @@ -112,6 +112,7 @@ func TestBindingSingleAliasedTable(t *testing.T) { "select tabl.col from X as tabl", "select col from d.X as tabl", "select tabl.col from d.X as tabl", + "select d.tabl.col from d.X as tabl", } for _, query := range queries { t.Run(query, func(t *testing.T) { @@ -131,13 +132,12 @@ func TestBindingSingleAliasedTable(t *testing.T) { "select tabl.col from tabl as X", "select X.col from X as tabl", "select d.X.col from d.X as tabl", - "select d.tabl.col from d.X as tabl", } for _, query := range queries { t.Run(query, func(t *testing.T) { parse, err := sqlparser.Parse(query) require.NoError(t, err) - _, err = Analyse(parse, "", &fakeSI{ + _, err = Analyze(parse, "", &fakeSI{ tables: map[string]*vindexes.Table{ "t": {Name: sqlparser.NewTableIdent("t")}, }, @@ -191,11 +191,11 @@ func TestBindingMultiTable(t *testing.T) { }, { query: "select case t.col when s.col then r.col else u.col end from t, s, r, w, u", deps: T0 | T1 | T2 | T4, - //}, { + // }, { // // make sure that we don't let sub-query Dependencies leak out by mistake // query: "select t.col + (select 42 from s) from t", // deps: T0, - //}, { + // }, { // query: "select (select 42 from s where r.id = s.id) from r", // deps: T0 | T1, }} @@ -209,7 +209,6 @@ func TestBindingMultiTable(t *testing.T) { }) t.Run("negative tests", func(t *testing.T) { - t.Skip("implement me!") queries := []string{ "select 1 from d.tabl, d.foo as tabl", } @@ -217,9 +216,10 @@ func TestBindingMultiTable(t *testing.T) { t.Run(query, func(t *testing.T) { parse, err := sqlparser.Parse(query) require.NoError(t, err) - _, err = Analyse(parse, "", &fakeSI{ + _, err = Analyze(parse, "d", &fakeSI{ tables: map[string]*vindexes.Table{ - "t": {Name: sqlparser.NewTableIdent("t")}, + "tabl": {Name: sqlparser.NewTableIdent("tabl")}, + "foo": {Name: sqlparser.NewTableIdent("foo")}, }, }) require.Error(t, err) @@ -252,7 +252,7 @@ func TestNotUniqueTableName(t *testing.T) { t.Skip("derived tables not implemented") } parse, _ := sqlparser.Parse(query) - _, err := Analyse(parse, "test", &fakeSI{}) + _, err := Analyze(parse, "test", &fakeSI{}) require.Error(t, err) require.Contains(t, err.Error(), "Not unique table/alias") }) @@ -267,7 +267,7 @@ func TestMissingTable(t *testing.T) { for _, query := range queries { t.Run(query, func(t *testing.T) { parse, _ := sqlparser.Parse(query) - _, err := Analyse(parse, "", &fakeSI{}) + _, err := Analyze(parse, "", &fakeSI{}) require.Error(t, err) require.Contains(t, err.Error(), "Unknown table") }) @@ -346,7 +346,7 @@ func TestUnknownColumnMap2(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { si := &fakeSI{tables: test.schema} - _, err := Analyse(parse, "", si) + _, err := Analyze(parse, "", si) if test.err { require.Error(t, err) } else { @@ -360,7 +360,7 @@ func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, * t.Helper() parse, err := sqlparser.Parse(query) require.NoError(t, err) - semTable, err := Analyse(parse, dbName, &fakeSI{ + semTable, err := Analyze(parse, dbName, &fakeSI{ tables: map[string]*vindexes.Table{ "t": {Name: sqlparser.NewTableIdent("t")}, }, diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 2e76c12fff1..3fbfb749ba4 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -29,8 +29,9 @@ import ( type ( // TableInfo contains the alias table expr and vindex table TableInfo struct { - ASTNode *sqlparser.AliasedTableExpr - Table *vindexes.Table + dbName, tableName string + ASTNode *sqlparser.AliasedTableExpr + Table *vindexes.Table } // TableSet is how a set of tables is expressed. @@ -44,13 +45,9 @@ type ( exprDependencies map[sqlparser.Expr]TableSet } - tableID struct { - dbName, tableName string - } - scope struct { parent *scope - tables map[tableID]*TableInfo + tables []*TableInfo } // SchemaInformation is used tp provide table information from Vschema. @@ -104,19 +101,19 @@ func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet { } func newScope(parent *scope) *scope { - return &scope{tables: map[tableID]*TableInfo{}, parent: parent} + return &scope{parent: parent} } -func (s *scope) addTable(dbName, tableName string, table *TableInfo) error { - id := tableID{ - dbName: dbName, - tableName: tableName, - } - _, found := s.tables[id] - if found { - return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqTable, "Not unique table/alias: '%s'", tableName) +func (s *scope) addTable(table *TableInfo) error { + for _, scopeTable := range s.tables { + b := scopeTable.tableName == table.tableName + b2 := scopeTable.dbName == table.dbName + if b && b2 { + return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqTable, "Not unique table/alias: '%s'", table.tableName) + } } - s.tables[id] = table + + s.tables = append(s.tables, table) return nil }