Skip to content

Commit

Permalink
Implemented table aliases in semantic analysis
Browse files Browse the repository at this point in the history
Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
  • Loading branch information
frouioui committed Apr 15, 2021
1 parent 18e61f0 commit 481e317
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 64 deletions.
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/querygraph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func TestQueryGraph(t *testing.T) {
t.Run(fmt.Sprintf("%d %s", i, sql), func(t *testing.T) {
tree, err := sqlparser.Parse(sql)
require.NoError(t, err)
semTable, err := semantics.Analyse(tree, "", &fakeSI{})
semTable, err := semantics.Analyze(tree, "", &fakeSI{})
require.NoError(t, err)
qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable)
require.NoError(t, err)
Expand All @@ -136,7 +136,7 @@ func TestQueryGraph(t *testing.T) {
func TestString(t *testing.T) {
tree, err := sqlparser.Parse("select * from a,b join c on b.id = c.id where a.id = b.id and b.col IN (select 42) and func() = 'foo'")
require.NoError(t, err)
semTable, err := semantics.Analyse(tree, "", &fakeSI{})
semTable, err := semantics.Analyze(tree, "", &fakeSI{})
require.NoError(t, err)
qgraph, err := createQGFromSelect(tree.(*sqlparser.Select), semTable)
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func newBuildSelectPlan(sel *sqlparser.Select, vschema ContextVSchema) (engine.P
if err != nil {
return nil, err
}
semTable, err := semantics.Analyse(sel, keyspace.Name, vschema)
semTable, err := semantics.Analyze(sel, keyspace.Name, vschema)
if err != nil {
return nil, err
}
Expand Down
57 changes: 26 additions & 31 deletions go/vt/vtgate/semantics/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ func newAnalyzer(dbName string, si SchemaInformation) *analyzer {
}
}

// Analyse analyzes the parsed query.
func Analyse(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) {
// Analyze analyzes the parsed query.
func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) {
analyzer := newAnalyzer(currentDb, si)
// Initial scope
err := analyzer.analyze(statement)
Expand Down Expand Up @@ -135,31 +135,18 @@ func (a *analyzer) analyzeTableExpr(tableExpr sqlparser.TableExpr) error {

// resolveQualifiedColumn handles `tabl.col` expressions
func (a *analyzer) resolveQualifiedColumn(current *scope, expr *sqlparser.ColName) (*TableInfo, error) {
id := tableID{
dbName: expr.Qualifier.Qualifier.String(),
tableName: expr.Qualifier.Name.String(),
}
id2 := tableID{
dbName: a.currentDb,
tableName: expr.Qualifier.Name.String(),
}
checkCurrentDB := id.dbName == "" && id != id2

// search up the scope stack until we find a match
for current != nil {
tableExpr, found := current.tables[id]
if found {
return tableExpr, nil
}
if checkCurrentDB {
tableExpr, found := current.tables[id2]
if found {
return tableExpr, nil
dbName := expr.Qualifier.Qualifier.String()
tableName := expr.Qualifier.Name.String()
for _, table := range current.tables {
if tableName == table.tableName &&
(dbName == table.dbName || (dbName == "" && (table.dbName == a.currentDb || a.currentDb == ""))) {
return table, nil
}
}
current = current.parent
}

return nil, vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.BadFieldError, "Unknown table referenced by '%s'", sqlparser.String(expr))
}

Expand Down Expand Up @@ -206,8 +193,8 @@ func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.S
}
a.popScope()
scope := a.currentScope()
dbName := "" // derived tables are always referenced only by their alias - they cannot be found using a fully qualified name
return scope.addTable(dbName, alias.As.String(), &TableInfo{alias, nil})
//dbName := "" // derived tables are always referenced only by their alias - they cannot be found using a fully qualified name
return scope.addTable(&TableInfo{})
case sqlparser.TableName:
tbl, vdx, _, _, _, err := a.si.FindTableOrVindex(t)
if err != nil {
Expand All @@ -217,16 +204,24 @@ func (a *analyzer) bindTable(alias *sqlparser.AliasedTableExpr, expr sqlparser.S
return Gen4NotSupportedF("vindex in FROM")
}
scope := a.currentScope()
table := &TableInfo{alias, tbl}
a.Tables = append(a.Tables, table)
dbName := t.Qualifier.String()
if dbName == "" {
dbName = a.currentDb
}
var tableName string
if alias.As.IsEmpty() {
dbName := t.Qualifier.String()
if dbName == "" {
dbName = a.currentDb
}
return scope.addTable(dbName, t.Name.String(), table)
tableName = t.Name.String()
} else {
tableName = alias.As.String()
}
table := &TableInfo{
dbName: dbName,
tableName: tableName,
ASTNode: alias,
Table: tbl,
}
return scope.addTable("", alias.As.String(), table)
a.Tables = append(a.Tables, table)
return scope.addTable(table)
}
return nil
}
Expand Down
26 changes: 13 additions & 13 deletions go/vt/vtgate/semantics/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func TestBindingSingleTable(t *testing.T) {
"select col from d.tabl",
"select tabl.col from d.tabl",
"select d.tabl.col from d.tabl",
"select d.tabl.col from X as tabl",
}
for _, query := range queries {
t.Run(query, func(t *testing.T) {
Expand All @@ -90,15 +91,14 @@ func TestBindingSingleTable(t *testing.T) {
t.Run("negative tests", func(t *testing.T) {
queries := []string{
"select foo.col from tabl",
"select d.tabl.col from X as tabl",
"select tabl.col from d.tabl as X",
"select d.tabl.col from d.tabl as X",
}
for _, query := range queries {
t.Run(query, func(t *testing.T) {
parse, err := sqlparser.Parse(query)
require.NoError(t, err)
_, err = Analyse(parse, "d", &fakeSI{})
_, err = Analyze(parse, "d", &fakeSI{})
require.Error(t, err)
})
}
Expand All @@ -112,6 +112,7 @@ func TestBindingSingleAliasedTable(t *testing.T) {
"select tabl.col from X as tabl",
"select col from d.X as tabl",
"select tabl.col from d.X as tabl",
"select d.tabl.col from d.X as tabl",
}
for _, query := range queries {
t.Run(query, func(t *testing.T) {
Expand All @@ -131,13 +132,12 @@ func TestBindingSingleAliasedTable(t *testing.T) {
"select tabl.col from tabl as X",
"select X.col from X as tabl",
"select d.X.col from d.X as tabl",
"select d.tabl.col from d.X as tabl",
}
for _, query := range queries {
t.Run(query, func(t *testing.T) {
parse, err := sqlparser.Parse(query)
require.NoError(t, err)
_, err = Analyse(parse, "", &fakeSI{
_, err = Analyze(parse, "", &fakeSI{
tables: map[string]*vindexes.Table{
"t": {Name: sqlparser.NewTableIdent("t")},
},
Expand Down Expand Up @@ -191,11 +191,11 @@ func TestBindingMultiTable(t *testing.T) {
}, {
query: "select case t.col when s.col then r.col else u.col end from t, s, r, w, u",
deps: T0 | T1 | T2 | T4,
//}, {
// }, {
// // make sure that we don't let sub-query Dependencies leak out by mistake
// query: "select t.col + (select 42 from s) from t",
// deps: T0,
//}, {
// }, {
// query: "select (select 42 from s where r.id = s.id) from r",
// deps: T0 | T1,
}}
Expand All @@ -209,17 +209,17 @@ func TestBindingMultiTable(t *testing.T) {
})

t.Run("negative tests", func(t *testing.T) {
t.Skip("implement me!")
queries := []string{
"select 1 from d.tabl, d.foo as tabl",
}
for _, query := range queries {
t.Run(query, func(t *testing.T) {
parse, err := sqlparser.Parse(query)
require.NoError(t, err)
_, err = Analyse(parse, "", &fakeSI{
_, err = Analyze(parse, "d", &fakeSI{
tables: map[string]*vindexes.Table{
"t": {Name: sqlparser.NewTableIdent("t")},
"tabl": {Name: sqlparser.NewTableIdent("tabl")},
"foo": {Name: sqlparser.NewTableIdent("foo")},
},
})
require.Error(t, err)
Expand Down Expand Up @@ -252,7 +252,7 @@ func TestNotUniqueTableName(t *testing.T) {
t.Skip("derived tables not implemented")
}
parse, _ := sqlparser.Parse(query)
_, err := Analyse(parse, "test", &fakeSI{})
_, err := Analyze(parse, "test", &fakeSI{})
require.Error(t, err)
require.Contains(t, err.Error(), "Not unique table/alias")
})
Expand All @@ -267,7 +267,7 @@ func TestMissingTable(t *testing.T) {
for _, query := range queries {
t.Run(query, func(t *testing.T) {
parse, _ := sqlparser.Parse(query)
_, err := Analyse(parse, "", &fakeSI{})
_, err := Analyze(parse, "", &fakeSI{})
require.Error(t, err)
require.Contains(t, err.Error(), "Unknown table")
})
Expand Down Expand Up @@ -346,7 +346,7 @@ func TestUnknownColumnMap2(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
si := &fakeSI{tables: test.schema}
_, err := Analyse(parse, "", si)
_, err := Analyze(parse, "", si)
if test.err {
require.Error(t, err)
} else {
Expand All @@ -360,7 +360,7 @@ func parseAndAnalyze(t *testing.T, query, dbName string) (sqlparser.Statement, *
t.Helper()
parse, err := sqlparser.Parse(query)
require.NoError(t, err)
semTable, err := Analyse(parse, dbName, &fakeSI{
semTable, err := Analyze(parse, dbName, &fakeSI{
tables: map[string]*vindexes.Table{
"t": {Name: sqlparser.NewTableIdent("t")},
},
Expand Down
31 changes: 14 additions & 17 deletions go/vt/vtgate/semantics/semantic_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ import (
type (
// TableInfo contains the alias table expr and vindex table
TableInfo struct {
ASTNode *sqlparser.AliasedTableExpr
Table *vindexes.Table
dbName, tableName string
ASTNode *sqlparser.AliasedTableExpr
Table *vindexes.Table
}

// TableSet is how a set of tables is expressed.
Expand All @@ -44,13 +45,9 @@ type (
exprDependencies map[sqlparser.Expr]TableSet
}

tableID struct {
dbName, tableName string
}

scope struct {
parent *scope
tables map[tableID]*TableInfo
tables []*TableInfo
}

// SchemaInformation is used tp provide table information from Vschema.
Expand Down Expand Up @@ -104,19 +101,19 @@ func (st *SemTable) Dependencies(expr sqlparser.Expr) TableSet {
}

func newScope(parent *scope) *scope {
return &scope{tables: map[tableID]*TableInfo{}, parent: parent}
return &scope{parent: parent}
}

func (s *scope) addTable(dbName, tableName string, table *TableInfo) error {
id := tableID{
dbName: dbName,
tableName: tableName,
}
_, found := s.tables[id]
if found {
return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqTable, "Not unique table/alias: '%s'", tableName)
func (s *scope) addTable(table *TableInfo) error {
for _, scopeTable := range s.tables {
b := scopeTable.tableName == table.tableName
b2 := scopeTable.dbName == table.dbName
if b && b2 {
return vterrors.NewErrorf(vtrpcpb.Code_INVALID_ARGUMENT, vterrors.NonUniqTable, "Not unique table/alias: '%s'", table.tableName)
}
}
s.tables[id] = table

s.tables = append(s.tables, table)
return nil
}

Expand Down

0 comments on commit 481e317

Please sign in to comment.