diff --git a/go/sqltypes/testing.go b/go/sqltypes/testing.go index 48e83370404..3a82dafe3d2 100644 --- a/go/sqltypes/testing.go +++ b/go/sqltypes/testing.go @@ -76,31 +76,6 @@ func MakeTestResult(fields []*querypb.Field, rows ...string) *Result { return result } -// MakeTestResultNoFields builds a *sqltypes.Result object for testing. -// result := sqltypes.MakeTestResult( -// fields, -// " 1|a", -// "10|abcd", -// ) -// The field type values are set as the types for the rows built. -// Spaces are trimmed from row values. "null" is treated as NULL. -func MakeTestResultNoFields(fields []*querypb.Field, rows ...string) *Result { - result := &Result{} - if len(rows) > 0 { - result.Rows = make([][]Value, len(rows)) - } - for i, row := range rows { - result.Rows[i] = make([]Value, len(fields)) - for j, col := range split(row) { - if col == "null" { - continue - } - result.Rows[i][j] = MakeTrusted(fields[j].Type, []byte(col)) - } - } - return result -} - // MakeTestStreamingResults builds a list of results for streaming. // results := sqltypes.MakeStreamingResults( // fields, diff --git a/go/vt/vtgate/engine/semi_join_test.go b/go/vt/vtgate/engine/semi_join_test.go index ff85a12baa6..13d8bcfca99 100644 --- a/go/vt/vtgate/engine/semi_join_test.go +++ b/go/vt/vtgate/engine/semi_join_test.go @@ -120,26 +120,16 @@ func TestSemiJoinStreamExecute(t *testing.T) { "col4|col5|col6", "int64|varchar|varchar", ) - rightPrim := &fakePrimitive{ // we'll return non-empty results for rows 2 and 4 - results: []*sqltypes.Result{ - // First right query will always be a GetFields. - sqltypes.MakeTestResultNoFields( - rightFields, - ), - sqltypes.MakeTestResultNoFields( - rightFields, - "4|d|dd", - ), - sqltypes.MakeTestResultNoFields( - rightFields, - ), - sqltypes.MakeTestResultNoFields( - rightFields, - "5|e|ee", - "6|f|ff", - "7|g|gg", - ), - }, + rightPrim := &fakePrimitive{ + // we'll return non-empty results for rows 2 and 4 + results: sqltypes.MakeTestStreamingResults(rightFields, + "4|d|dd", + "---", + "---", + "5|e|ee", + "6|f|ff", + "7|g|gg", + ), } jn := &SemiJoin{ diff --git a/go/vt/vtgate/planbuilder/concantenatetree.go b/go/vt/vtgate/planbuilder/concantenatetree.go index c180293c0c1..ce4dbeeac97 100644 --- a/go/vt/vtgate/planbuilder/concantenatetree.go +++ b/go/vt/vtgate/planbuilder/concantenatetree.go @@ -17,6 +17,8 @@ limitations under the License. package planbuilder import ( + "fmt" + "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -64,9 +66,9 @@ func (c *concatenateTree) pushOutputColumns(columns []*sqlparser.ColName, semTab } func (c *concatenateTree) pushPredicate(ctx *planningContext, expr sqlparser.Expr) error { - return vterrors.New(vtrpc.Code_UNIMPLEMENTED, "pushPredicate does not work on concatenate trees") + return vterrors.New(vtrpc.Code_INTERNAL, fmt.Sprintf("add '%s' predicate not supported on concatenate trees", sqlparser.String(expr))) } func (c *concatenateTree) removePredicate(ctx *planningContext, expr sqlparser.Expr) error { - return vterrors.New(vtrpc.Code_UNIMPLEMENTED, "removePredicate does not work on concatenate trees") + return vterrors.New(vtrpc.Code_INTERNAL, fmt.Sprintf("remove '%s' predicate not supported on concatenate trees", sqlparser.String(expr))) } diff --git a/go/vt/vtgate/planbuilder/derivedtree.go b/go/vt/vtgate/planbuilder/derivedtree.go index 3971aa04953..b88e788e28b 100644 --- a/go/vt/vtgate/planbuilder/derivedtree.go +++ b/go/vt/vtgate/planbuilder/derivedtree.go @@ -17,6 +17,8 @@ limitations under the License. package planbuilder import ( + "fmt" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -75,11 +77,11 @@ func (d *derivedTree) pushOutputColumns(names []*sqlparser.ColName, semTable *se } func (d *derivedTree) pushPredicate(ctx *planningContext, expr sqlparser.Expr) error { - return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "pushPredicate does not work on derivedTrees") + return vterrors.New(vtrpcpb.Code_INTERNAL, fmt.Sprintf("add '%s' predicate not supported on derived trees", sqlparser.String(expr))) } func (d *derivedTree) removePredicate(ctx *planningContext, expr sqlparser.Expr) error { - return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "removePredicate does not work on derivedTrees") + return vterrors.New(vtrpcpb.Code_INTERNAL, fmt.Sprintf("remove '%s' predicate not supported on derived trees", sqlparser.String(expr))) } // findOutputColumn returns the index on which the given name is found in the slice of diff --git a/go/vt/vtgate/planbuilder/jointree.go b/go/vt/vtgate/planbuilder/jointree.go index a7e62c105a3..aeca1c6c4c0 100644 --- a/go/vt/vtgate/planbuilder/jointree.go +++ b/go/vt/vtgate/planbuilder/jointree.go @@ -17,6 +17,8 @@ limitations under the License. package planbuilder import ( + "fmt" + "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" @@ -81,7 +83,7 @@ func (jp *joinTree) pushOutputColumns(columns []*sqlparser.ColName, semTable *se outputColumns := make([]int, len(toTheLeft)) var l, r int for i, isLeft := range toTheLeft { - outputColumns[i] = i + outputColumns[i] = len(jp.columns) if isLeft { jp.columns = append(jp.columns, -lhsOffset[l]-1) l++ @@ -94,19 +96,41 @@ func (jp *joinTree) pushOutputColumns(columns []*sqlparser.ColName, semTable *se } func (jp *joinTree) pushPredicate(ctx *planningContext, expr sqlparser.Expr) error { + isPushed := false if ctx.semTable.RecursiveDeps(expr).IsSolvedBy(jp.lhs.tableID()) { - return jp.lhs.pushPredicate(ctx, expr) - } else if ctx.semTable.RecursiveDeps(expr).IsSolvedBy(jp.rhs.tableID()) { - return jp.rhs.pushPredicate(ctx, expr) + if err := jp.lhs.pushPredicate(ctx, expr); err != nil { + return err + } + isPushed = true + } + if ctx.semTable.RecursiveDeps(expr).IsSolvedBy(jp.rhs.tableID()) { + if err := jp.rhs.pushPredicate(ctx, expr); err != nil { + return err + } + isPushed = true + } + if isPushed { + return nil } - return vterrors.New(vtrpc.Code_UNIMPLEMENTED, "pushPredicate does not work on joinTrees with predicates having dependencies from both the sides") + return vterrors.New(vtrpc.Code_UNIMPLEMENTED, fmt.Sprintf("add '%s' predicate not supported on cross-shard join query", sqlparser.String(expr))) } func (jp *joinTree) removePredicate(ctx *planningContext, expr sqlparser.Expr) error { + isRemoved := false if ctx.semTable.RecursiveDeps(expr).IsSolvedBy(jp.lhs.tableID()) { - return jp.lhs.removePredicate(ctx, expr) - } else if ctx.semTable.RecursiveDeps(expr).IsSolvedBy(jp.rhs.tableID()) { - return jp.rhs.removePredicate(ctx, expr) + if err := jp.lhs.removePredicate(ctx, expr); err != nil { + return err + } + isRemoved = true + } + if ctx.semTable.RecursiveDeps(expr).IsSolvedBy(jp.rhs.tableID()) { + if err := jp.rhs.removePredicate(ctx, expr); err != nil { + return err + } + isRemoved = true + } + if isRemoved { + return nil } - return vterrors.New(vtrpc.Code_UNIMPLEMENTED, "removePredicate does not work on joinTrees with predicates having dependencies from both the sides") + return vterrors.New(vtrpc.Code_UNIMPLEMENTED, fmt.Sprintf("remove '%s' predicate not supported on cross-shard join query", sqlparser.String(expr))) } diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index b12fc634545..9e53994a08e 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -201,20 +201,29 @@ func createCorrelatedSubqueryTree(ctx *planningContext, innerTree, outerTree que } vars := map[string]int{} + bindVars := map[*sqlparser.ColName]string{} for _, pred := range preds { var rewriteError error sqlparser.Rewrite(pred, func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.ColName: if ctx.semTable.RecursiveDeps(node).IsSolvedBy(outerTree.tableID()) { + // check whether the bindVariable already exists in the map + // we do so by checking that the column names are the same and their recursive dependencies are the same + // so if the column names user.a and a would also be equal if the latter is also referencing the user table + for colName, bindVar := range bindVars { + if node.Name.Equal(colName.Name) && ctx.semTable.RecursiveDeps(node).Equals(ctx.semTable.RecursiveDeps(colName)) { + cursor.Replace(sqlparser.NewArgument(bindVar)) + return false + } + } + // get the bindVariable for that column name and replace it in the predicate bindVar := ctx.reservedVars.ReserveColName(node) cursor.Replace(sqlparser.NewArgument(bindVar)) - // check whether the bindVariable already exists in the map - _, alreadyExists := vars[bindVar] - if alreadyExists { - return false - } + // store it in the map for future comparisons + bindVars[node] = bindVar + // if it does not exist, then push this as an output column in the outerTree and add it to the joinVars columnIndexes, err := outerTree.pushOutputColumns([]*sqlparser.ColName{node}, ctx.semTable) if err != nil { diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index 6e4eb9b195a..0da8b8f5773 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -2580,3 +2580,103 @@ Gen4 plan same as above ] } } + +# correlated subquery having dependencies on two tables +"select 1 from user u1, user u2 where exists (select 1 from user_extra ue where ue.col = u1.col and ue.col = u2.col)" +"unsupported: cross-shard correlated subquery" +{ + "QueryType": "SELECT", + "Original": "select 1 from user u1, user u2 where exists (select 1 from user_extra ue where ue.col = u1.col and ue.col = u2.col)", + "Instructions": { + "OperatorType": "SemiJoin", + "JoinVars": { + "u1_col": 0, + "u2_col": 1 + }, + "ProjectedIndexes": "-3", + "TableName": "`user`_`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "-1,1,-2", + "TableName": "`user`_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u1.col, 1 from `user` as u1 where 1 != 1", + "Query": "select u1.col, 1 from `user` as u1", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u2.col from `user` as u2 where 1 != 1", + "Query": "select u2.col from `user` as u2", + "Table": "`user`" + } + ] + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra as ue where 1 != 1", + "Query": "select 1 from user_extra as ue where ue.col = :u1_col and ue.col = :u2_col", + "Table": "user_extra" + } + ] + } +} + +# correlated subquery using a column twice +"select 1 from user u where exists (select 1 from user_extra ue where ue.col = u.col and u.col = ue.col2)" +"unsupported: cross-shard correlated subquery" +{ + "QueryType": "SELECT", + "Original": "select 1 from user u where exists (select 1 from user_extra ue where ue.col = u.col and u.col = ue.col2)", + "Instructions": { + "OperatorType": "SemiJoin", + "JoinVars": { + "u_col": 0 + }, + "ProjectedIndexes": "-2", + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.col, 1 from `user` as u where 1 != 1", + "Query": "select u.col, 1 from `user` as u", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra as ue where 1 != 1", + "Query": "select 1 from user_extra as ue where ue.col = :u_col and ue.col2 = :u_col", + "Table": "user_extra" + } + ] + } +} diff --git a/go/vt/vtgate/planbuilder/vindextree.go b/go/vt/vtgate/planbuilder/vindextree.go index 4fa67aeb1e8..0684ecfc072 100644 --- a/go/vt/vtgate/planbuilder/vindextree.go +++ b/go/vt/vtgate/planbuilder/vindextree.go @@ -17,6 +17,8 @@ limitations under the License. package planbuilder import ( + "fmt" + "vitess.io/vitess/go/sqltypes" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" @@ -67,9 +69,9 @@ outer: } func (v *vindexTree) pushPredicate(ctx *planningContext, expr sqlparser.Expr) error { - return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "pushPredicate does not work on vindexTrees") + return vterrors.New(vtrpcpb.Code_INTERNAL, fmt.Sprintf("add '%s' predicate not supported on vindex trees", sqlparser.String(expr))) } func (v *vindexTree) removePredicate(ctx *planningContext, expr sqlparser.Expr) error { - return vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "removePredicate does not work on vindexTrees") + return vterrors.New(vtrpcpb.Code_INTERNAL, fmt.Sprintf("remove '%s' predicate not supported on vindex trees", sqlparser.String(expr))) } diff --git a/go/vt/vtgate/semantics/tabletset.go b/go/vt/vtgate/semantics/tabletset.go index 0bc0fc01a00..3a0290a2800 100644 --- a/go/vt/vtgate/semantics/tabletset.go +++ b/go/vt/vtgate/semantics/tabletset.go @@ -215,6 +215,11 @@ func (ts TableSet) IsSolvedBy(other TableSet) bool { } } +// Equals returns true if `ts` and `other` contain the same tables +func (ts TableSet) Equals(other TableSet) bool { + return ts.IsSolvedBy(other) && other.IsSolvedBy(ts) +} + // NumberOfTables returns the number of bits set func (ts TableSet) NumberOfTables() int { if ts.large == nil {