diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 794c7894bc6..a651df58b1c 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -17,8 +17,6 @@ limitations under the License. package planbuilder import ( - "strings" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/vtgate/planbuilder/abstract" @@ -126,7 +124,7 @@ func pushProjections(ctx *planningContext, plan logicalPlan, selectExprs []abstr if err != nil { return err } - if _, _, err := pushProjection(aliasExpr, plan, ctx.semTable, true, false); err != nil { + if _, _, err := pushProjection(aliasExpr, plan, ctx.semTable, true, false, false); err != nil { return err } } @@ -157,7 +155,7 @@ func (hp *horizonPlanning) truncateColumnsIfNeeded(plan logicalPlan) error { } // pushProjection pushes a projection to the plan. -func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *semantics.SemTable, inner bool, reuseCol bool) (offset int, added bool, err error) { +func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *semantics.SemTable, inner, reuseCol, hasAggregation bool) (offset int, added bool, err error) { switch node := plan.(type) { case *route: value, err := makePlanValue(expr.Expr) @@ -203,21 +201,49 @@ func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *sem } switch { case deps.IsSolvedBy(lhsSolves): - offset, added, err := pushProjection(expr, node.Left, semTable, inner, passDownReuseCol) + offset, added, err := pushProjection(expr, node.Left, semTable, inner, passDownReuseCol, hasAggregation) if err != nil { return 0, false, err } column = -(offset + 1) appended = added case deps.IsSolvedBy(rhsSolves): - offset, added, err := pushProjection(expr, node.Right, semTable, inner && node.Opcode != engine.LeftJoin, passDownReuseCol) + offset, added, err := pushProjection(expr, node.Right, semTable, inner && node.Opcode != engine.LeftJoin, passDownReuseCol, hasAggregation) if err != nil { return 0, false, err } column = offset + 1 appended = added default: - return 0, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unknown dependencies for %s", sqlparser.String(expr)) + // if an expression has aggregation, then it should not be split up and pushed to both sides, + // for example an expression like count(*) will have dependencies on both sides, but we should not push it + // instead we should return an error + if hasAggregation { + return 0, false, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: cross-shard query with aggregates") + } + // now we break the expression into left and right side dependencies and rewrite the left ones to bind variables + bvName, cols, rewrittenExpr, err := breakExpressionInLHSandRHS(expr.Expr, semTable, lhsSolves) + if err != nil { + return 0, false, err + } + // go over all the columns coming from the left side of the tree and push them down. While at it, also update the bind variable map. + // It is okay to reuse the columns on the left side since + // the final expression which will be selected will be pushed into the right side. + for i, col := range cols { + colOffset, _, err := pushProjection(&sqlparser.AliasedExpr{Expr: col}, node.Left, semTable, inner, true, false) + if err != nil { + return 0, false, err + } + node.Vars[bvName[i]] = colOffset + } + // push the rewritten expression on the right side of the tree. Here we should take care whether we want to reuse the expression or not. + expr.Expr = rewrittenExpr + offset, added, err := pushProjection(expr, node.Right, semTable, inner && node.Opcode != engine.LeftJoin, passDownReuseCol, false) + if err != nil { + return 0, false, err + } + column = offset + 1 + appended = added } if reuseCol && !appended { for idx, col := range node.Cols { @@ -232,9 +258,9 @@ func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *sem return len(node.Cols) - 1, true, nil case *pulloutSubquery: // push projection to the outer query - return pushProjection(expr, node.underlying, semTable, inner, reuseCol) + return pushProjection(expr, node.underlying, semTable, inner, reuseCol, hasAggregation) case *simpleProjection: - offset, _, err := pushProjection(expr, node.input, semTable, inner, true) + offset, _, err := pushProjection(expr, node.input, semTable, inner, true, hasAggregation) if err != nil { return 0, false, err } @@ -266,9 +292,9 @@ func pushProjection(expr *sqlparser.AliasedExpr, plan logicalPlan, semTable *sem } return i /* col added */, len(node.eVindexFunc.Cols) > colsBefore, nil case *limit: - return pushProjection(expr, node.input, semTable, inner, reuseCol) + return pushProjection(expr, node.input, semTable, inner, reuseCol, hasAggregation) case *distinct: - return pushProjection(expr, node.input, semTable, inner, reuseCol) + return pushProjection(expr, node.input, semTable, inner, reuseCol, hasAggregation) default: return 0, false, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "[BUG] push projection does not yet support: %T", node) } @@ -371,7 +397,7 @@ func (hp *horizonPlanning) planAggregations(ctx *planningContext, plan logicalPl // push all expression if they are non-aggregating or the plan is not ordered aggregated plan. if !e.Aggr || oa == nil { - _, _, err := pushProjection(aliasExpr, plan, ctx.semTable, true, false) + _, _, err := pushProjection(aliasExpr, plan, ctx.semTable, true, false, false) if err != nil { return nil, err } @@ -392,11 +418,8 @@ func (hp *horizonPlanning) planAggregations(ctx *planningContext, plan logicalPl } pushExpr, alias, opcode := hp.createPushExprAndAlias(e, handleDistinct, innerAliased, opcode, oa) - offset, _, err := pushProjection(pushExpr, plan, ctx.semTable, true, false) + offset, _, err := pushProjection(pushExpr, plan, ctx.semTable, true, false, true) if err != nil { - if strings.HasPrefix(err.Error(), "unknown dependencies for") { - return nil, vterrors.New(vtrpcpb.Code_UNIMPLEMENTED, "unsupported: cross-shard query with aggregates") - } return nil, err } oa.eaggr.Aggregates = append(oa.eaggr.Aggregates, &engine.AggregateParams{ @@ -696,7 +719,7 @@ func checkOrderExprCanBePlannedInScatter(plan *route, order abstract.OrderBy, ha // wrapAndPushExpr pushes the expression and weighted_string function to the plan using semantics.SemTable // It returns (expr offset, weight_string offset, new_column added, error) func wrapAndPushExpr(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan logicalPlan, semTable *semantics.SemTable) (int, int, bool, error) { - offset, added, err := pushProjection(&sqlparser.AliasedExpr{Expr: expr}, plan, semTable, true, true) + offset, added, err := pushProjection(&sqlparser.AliasedExpr{Expr: expr}, plan, semTable, true, true, false) if err != nil { return 0, 0, false, err } @@ -720,7 +743,7 @@ func wrapAndPushExpr(expr sqlparser.Expr, weightStrExpr sqlparser.Expr, plan log weightStringOffset := -1 var wAdded bool if wsNeeded { - weightStringOffset, wAdded, err = pushProjection(&sqlparser.AliasedExpr{Expr: weightStringFor(weightStrExpr)}, plan, semTable, true, true) + weightStringOffset, wAdded, err = pushProjection(&sqlparser.AliasedExpr{Expr: weightStringFor(weightStrExpr)}, plan, semTable, true, true, false) if err != nil { return 0, 0, false, err } diff --git a/go/vt/vtgate/planbuilder/route_planning.go b/go/vt/vtgate/planbuilder/route_planning.go index 7d8ef25c746..f542a7464cc 100644 --- a/go/vt/vtgate/planbuilder/route_planning.go +++ b/go/vt/vtgate/planbuilder/route_planning.go @@ -479,7 +479,7 @@ func pushJoinPredicateOnJoin(ctx *planningContext, exprs []sqlparser.Expr, node continue } - bvName, cols, predicate, err := breakPredicateInLHSandRHS(expr, ctx.semTable, node.lhs.tableID()) + bvName, cols, predicate, err := breakExpressionInLHSandRHS(expr, ctx.semTable, node.lhs.tableID()) if err != nil { return nil, err } @@ -514,13 +514,13 @@ func pushJoinPredicateOnJoin(ctx *planningContext, exprs []sqlparser.Expr, node }, nil } -func breakPredicateInLHSandRHS( +func breakExpressionInLHSandRHS( expr sqlparser.Expr, semTable *semantics.SemTable, lhs semantics.TableSet, -) (bvNames []string, columns []*sqlparser.ColName, predicate sqlparser.Expr, err error) { - predicate = sqlparser.CloneExpr(expr) - _ = sqlparser.Rewrite(predicate, nil, func(cursor *sqlparser.Cursor) bool { +) (bvNames []string, columns []*sqlparser.ColName, rewrittenExpr sqlparser.Expr, err error) { + rewrittenExpr = sqlparser.CloneExpr(expr) + _ = sqlparser.Rewrite(rewrittenExpr, nil, func(cursor *sqlparser.Cursor) bool { switch node := cursor.Node().(type) { case *sqlparser.ColName: deps := semTable.RecursiveDeps(node) diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.txt b/go/vt/vtgate/planbuilder/testdata/select_cases.txt index c6d6cc48d68..6f869b2844c 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.txt +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.txt @@ -2495,3 +2495,44 @@ Gen4 plan same as above } } Gen4 plan same as above + +# select expression having dependencies on both sides of a join +"select user.id * user_id as amount from user, user_extra" +{ + "QueryType": "SELECT", + "Original": "select user.id * user_id as amount from user, user_extra", + "Instructions": { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "1", + "JoinVars": { + "user_id": 0 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id from `user` where 1 != 1", + "Query": "select `user`.id from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "SelectScatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select :user_id * user_id as amount from user_extra where 1 != 1", + "Query": "select :user_id * user_id as amount from user_extra", + "Table": "user_extra" + } + ] + } +} +Gen4 plan same as above