diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 3d8d027f12a..c4d0bdd3199 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -1150,6 +1150,22 @@ func (node *Update) AddWhere(expr Expr) { } } +// AddWhere adds the boolean expression to the +// WHERE clause as an AND condition. +func (node *Delete) AddWhere(expr Expr) { + if node.Where == nil { + node.Where = &Where{ + Type: WhereClause, + Expr: expr, + } + return + } + node.Where.Expr = &AndExpr{ + Left: node.Where.Expr, + Right: expr, + } +} + // AddOrder adds an order by element func (node *Union) AddOrder(order *Order) { node.OrderBy = append(node.OrderBy, order) diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 283ee147ba4..e17548c68b1 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -401,20 +401,29 @@ func newRoutingParams(ctx *plancontext.PlanningContext, opCode engine.Opcode) *e } func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) (logicalPlan, error) { - switch src := op.Source.(type) { - case *operators.Insert: - return transformInsertPlan(ctx, op, src) - case *operators.Update: - return transformUpdatePlan(ctx, op, src) - case *operators.Delete: - return transformDeletePlan(ctx, op, src) - } - condition := getVindexPredicate(ctx, op) - sel, err := operators.ToSQL(ctx, op.Source) + stmt, dmlOp, err := operators.ToSQL(ctx, op.Source) if err != nil { return nil, err } - replaceSubQuery(ctx, sel) + + replaceSubQuery(ctx, stmt) + + switch stmt := stmt.(type) { + case sqlparser.SelectStatement: + return buildRouteLogicalPlan(ctx, op, stmt) + case *sqlparser.Update: + return buildUpdateLogicalPlan(ctx, op, dmlOp, stmt) + case *sqlparser.Delete: + return buildDeleteLogicalPlan(ctx, op, dmlOp, stmt) + case *sqlparser.Insert: + return buildInsertLogicalPlan(ctx, op, dmlOp, stmt) + default: + return nil, vterrors.VT13001(fmt.Sprintf("dont know how to %T", stmt)) + } +} + +func buildRouteLogicalPlan(ctx *plancontext.PlanningContext, op *operators.Route, stmt sqlparser.SelectStatement) (logicalPlan, error) { + condition := getVindexPredicate(ctx, op) eroute, err := routeToEngineRoute(ctx, op) for _, order := range op.Ordering { typ, collation, _ := ctx.SemTable.TypeForExpr(order.AST) @@ -431,17 +440,17 @@ func transformRoutePlan(ctx *plancontext.PlanningContext, op *operators.Route) ( } return &route{ eroute: eroute, - Select: sel, + Select: stmt, tables: operators.TableID(op), condition: condition, }, nil - } -func transformInsertPlan(ctx *plancontext.PlanningContext, op *operators.Route, ins *operators.Insert) (i *insert, err error) { +func buildInsertLogicalPlan(ctx *plancontext.PlanningContext, rb *operators.Route, op ops.Operator, stmt *sqlparser.Insert) (logicalPlan, error) { + ins := op.(*operators.Insert) eins := &engine.Insert{ - Opcode: mapToInsertOpCode(op.Routing.OpCode(), ins.Input != nil), - Keyspace: op.Routing.Keyspace(), + Opcode: mapToInsertOpCode(rb.Routing.OpCode(), ins.Input != nil), + Keyspace: rb.Routing.Keyspace(), TableName: ins.VTable.Name.String(), Ignore: ins.Ignore, ForceNonStreaming: ins.ForceNonStreaming, @@ -450,7 +459,7 @@ func transformInsertPlan(ctx *plancontext.PlanningContext, op *operators.Route, VindexValues: ins.VindexValues, VindexValueOffset: ins.VindexValueOffset, } - i = &insert{eInsert: eins} + lp := &insert{eInsert: eins} // we would need to generate the query on the fly. The only exception here is // when unsharded query with autoincrement for that there is no input operator. @@ -459,15 +468,16 @@ func transformInsertPlan(ctx *plancontext.PlanningContext, op *operators.Route, } if ins.Input == nil { - eins.Query = generateQuery(ins.AST) + eins.Query = generateQuery(stmt) } else { - i.source, err = transformToLogicalPlan(ctx, ins.Input) + newSrc, err := transformToLogicalPlan(ctx, ins.Input) if err != nil { - return + return nil, err } + lp.source = newSrc } - return + return lp, nil } func mapToInsertOpCode(code engine.Opcode, insertSelect bool) engine.InsertOpcode { @@ -541,16 +551,20 @@ func dmlFormatter(buf *sqlparser.TrackedBuffer, node sqlparser.SQLNode) { node.Format(buf) } -func transformUpdatePlan(ctx *plancontext.PlanningContext, op *operators.Route, upd *operators.Update) (logicalPlan, error) { - ast := upd.AST - replaceSubQuery(ctx, ast) +func buildUpdateLogicalPlan( + ctx *plancontext.PlanningContext, + op *operators.Route, + dmlOp ops.Operator, + stmt *sqlparser.Update, +) (logicalPlan, error) { + upd := dmlOp.(*operators.Update) rp := newRoutingParams(ctx, op.Routing.OpCode()) err := op.Routing.UpdateRoutingParams(ctx, rp) if err != nil { return nil, err } edml := &engine.DML{ - Query: generateQuery(ast), + Query: generateQuery(stmt), TableNames: []string{upd.VTable.Name.String()}, Vindexes: upd.VTable.ColumnVindexes, OwnedVindexQuery: upd.OwnedVindexQuery, @@ -567,11 +581,15 @@ func transformUpdatePlan(ctx *plancontext.PlanningContext, op *operators.Route, return &primitiveWrapper{prim: e}, nil } -func transformDeletePlan(ctx *plancontext.PlanningContext, op *operators.Route, del *operators.Delete) (logicalPlan, error) { - ast := del.AST - replaceSubQuery(ctx, ast) - rp := newRoutingParams(ctx, op.Routing.OpCode()) - err := op.Routing.UpdateRoutingParams(ctx, rp) +func buildDeleteLogicalPlan( + ctx *plancontext.PlanningContext, + rb *operators.Route, + dmlOp ops.Operator, + ast *sqlparser.Delete, +) (logicalPlan, error) { + del := dmlOp.(*operators.Delete) + rp := newRoutingParams(ctx, rb.Routing.OpCode()) + err := rb.Routing.UpdateRoutingParams(ctx, rp) if err != nil { return nil, err } @@ -583,7 +601,7 @@ func transformDeletePlan(ctx *plancontext.PlanningContext, op *operators.Route, RoutingParameters: rp, } - transformDMLPlan(del.VTable, edml, op.Routing, del.OwnedVindexQuery != "") + transformDMLPlan(del.VTable, edml, rb.Routing, del.OwnedVindexQuery != "") e := &engine.Delete{ DML: edml, diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 6ea6dd43ff0..15e1833703c 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -30,20 +30,25 @@ import ( type ( queryBuilder struct { - ctx *plancontext.PlanningContext - sel sqlparser.SelectStatement - tableNames []string + ctx *plancontext.PlanningContext + stmt sqlparser.Statement + tableNames []string + dmlOperator ops.Operator } ) -func ToSQL(ctx *plancontext.PlanningContext, op ops.Operator) (sqlparser.SelectStatement, error) { +func (qb *queryBuilder) asSelectStatement() sqlparser.SelectStatement { + return qb.stmt.(sqlparser.SelectStatement) +} + +func ToSQL(ctx *plancontext.PlanningContext, op ops.Operator) (sqlparser.Statement, ops.Operator, error) { q := &queryBuilder{ctx: ctx} err := buildQuery(op, q) if err != nil { - return nil, err + return nil, nil, err } q.sortTables() - return q.sel, nil + return q.stmt, q.dmlOperator, nil } func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) { @@ -61,10 +66,10 @@ func (qb *queryBuilder) addTableExpr( hints sqlparser.IndexHints, columnAliases sqlparser.Columns, ) { - if qb.sel == nil { - qb.sel = &sqlparser.Select{} + if qb.stmt == nil { + qb.stmt = &sqlparser.Select{} } - sel := qb.sel.(*sqlparser.Select) + sel := qb.stmt.(*sqlparser.Select) elems := &sqlparser.AliasedTableExpr{ Expr: tblExpr, Partitions: nil, @@ -74,7 +79,7 @@ func (qb *queryBuilder) addTableExpr( } qb.ctx.SemTable.ReplaceTableSetFor(tableID, elems) sel.From = append(sel.From, elems) - qb.sel = sel + qb.stmt = sel qb.tableNames = append(qb.tableNames, tableName) } @@ -85,34 +90,43 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { return } - sel := qb.sel.(*sqlparser.Select) _, isSubQuery := expr.(*sqlparser.ExtractedSubquery) var addPred func(sqlparser.Expr) - if sqlparser.ContainsAggregation(expr) && !isSubQuery { - addPred = sel.AddHaving - } else { - addPred = sel.AddWhere + switch stmt := qb.stmt.(type) { + case *sqlparser.Select: + if sqlparser.ContainsAggregation(expr) && !isSubQuery { + addPred = stmt.AddHaving + } else { + addPred = stmt.AddWhere + } + case *sqlparser.Update: + addPred = stmt.AddWhere + case *sqlparser.Delete: + addPred = stmt.AddWhere + default: + panic(fmt.Sprintf("cant add WHERE to %T", qb.stmt)) } + for _, exp := range sqlparser.SplitAndExpression(nil, expr) { addPred(exp) } } func (qb *queryBuilder) addGroupBy(original sqlparser.Expr) { - sel := qb.sel.(*sqlparser.Select) + sel := qb.stmt.(*sqlparser.Select) sel.GroupBy = append(sel.GroupBy, original) } func (qb *queryBuilder) addProjection(projection *sqlparser.AliasedExpr) error { - switch stmt := qb.sel.(type) { + switch stmt := qb.stmt.(type) { case *sqlparser.Select: stmt.SelectExprs = append(stmt.SelectExprs, projection) return nil case *sqlparser.Union: switch expr := projection.Expr.(type) { case *sqlparser.ColName: - return checkUnionColumnByName(expr, qb.sel) + return checkUnionColumnByName(expr, stmt) default: // if there is more than just column names, we'll just push the UNION // inside a derived table and then recurse into this method again @@ -121,13 +135,14 @@ func (qb *queryBuilder) addProjection(projection *sqlparser.AliasedExpr) error { } } - return vterrors.VT13001(fmt.Sprintf("unknown select statement type: %T", qb.sel)) + return vterrors.VT13001(fmt.Sprintf("unknown select statement type: %T", qb.stmt)) } func (qb *queryBuilder) pushUnionInsideDerived() { + selStmt := qb.asSelectStatement() dt := &sqlparser.DerivedTable{ Lateral: false, - Select: qb.sel, + Select: selStmt, } sel := &sqlparser.Select{ From: []sqlparser.TableExpr{&sqlparser.AliasedTableExpr{ @@ -135,8 +150,8 @@ func (qb *queryBuilder) pushUnionInsideDerived() { As: sqlparser.NewIdentifierCS("dt"), }}, } - sel.SelectExprs = unionSelects(sqlparser.GetFirstSelect(qb.sel).SelectExprs) - qb.sel = sel + sel.SelectExprs = unionSelects(sqlparser.GetFirstSelect(selStmt).SelectExprs) + qb.stmt = sel } func unionSelects(exprs sqlparser.SelectExprs) (selectExprs sqlparser.SelectExprs) { @@ -172,7 +187,7 @@ func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.SelectState } func (qb *queryBuilder) clearProjections() { - sel, isSel := qb.sel.(*sqlparser.Select) + sel, isSel := qb.stmt.(*sqlparser.Select) if !isSel { return } @@ -180,16 +195,16 @@ func (qb *queryBuilder) clearProjections() { } func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { - qb.sel = &sqlparser.Union{ - Left: qb.sel, - Right: other.sel, + qb.stmt = &sqlparser.Union{ + Left: qb.asSelectStatement(), + Right: other.asSelectStatement(), Distinct: distinct, } } func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser.Expr) { - sel := qb.sel.(*sqlparser.Select) - otherSel := other.sel.(*sqlparser.Select) + sel := qb.stmt.(*sqlparser.Select) + otherSel := other.stmt.(*sqlparser.Select) sel.From = append(sel.From, otherSel.From...) sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...) @@ -210,8 +225,8 @@ func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser } func (qb *queryBuilder) joinOuterWith(other *queryBuilder, onCondition sqlparser.Expr) { - sel := qb.sel.(*sqlparser.Select) - otherSel := other.sel.(*sqlparser.Select) + sel := qb.stmt.(*sqlparser.Select) + otherSel := other.stmt.(*sqlparser.Select) var lhs sqlparser.TableExpr if len(sel.From) == 1 { lhs = sel.From[0] @@ -258,7 +273,7 @@ func (qb *queryBuilder) sortTables() { } sort.Sort(ts) return true, nil - }, qb.sel) + }, qb.stmt) } @@ -370,14 +385,29 @@ func buildQuery(op ops.Operator, qb *queryBuilder) error { if err != nil { return err } - qb.sel.MakeDistinct() - return nil + qb.asSelectStatement().MakeDistinct() + case *Update: + buildDML(op, qb) + case *Delete: + buildDML(op, qb) + case *Insert: + buildDML(op, qb) default: return vterrors.VT13001(fmt.Sprintf("unknown operator to convert to SQL: %T", op)) } return nil } +type OpWithAST interface { + ops.Operator + Statement() sqlparser.Statement +} + +func buildDML(op OpWithAST, qb *queryBuilder) { + qb.stmt = op.Statement() + qb.dmlOperator = op +} + func buildAggregation(op *Aggregator, qb *queryBuilder) error { err := buildQuery(op.Source, qb) if err != nil { @@ -415,7 +445,7 @@ func buildOrdering(op *Ordering, qb *queryBuilder) error { } for _, order := range op.Order { - qb.sel.AddOrder(order.Inner) + qb.asSelectStatement().AddOrder(order.Inner) } return nil } @@ -425,7 +455,7 @@ func buildLimit(op *Limit, qb *queryBuilder) error { if err != nil { return err } - qb.sel.SetLimit(op.AST) + qb.asSelectStatement().SetLimit(op.AST) return nil } @@ -453,7 +483,7 @@ func buildProjection(op *Projection, qb *queryBuilder) error { return err } - _, isSel := qb.sel.(*sqlparser.Select) + _, isSel := qb.stmt.(*sqlparser.Select) if isSel { qb.clearProjections() @@ -468,8 +498,8 @@ func buildProjection(op *Projection, qb *queryBuilder) error { // if the projection is on derived table, we use the select we have // created above and transform it into a derived table if op.TableID != nil { - sel := qb.sel - qb.sel = nil + sel := qb.asSelectStatement() + qb.stmt = nil qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{ Select: sel, }, nil, nil) @@ -553,8 +583,8 @@ func buildDerived(op *Horizon, qb *queryBuilder) error { } sqlparser.RemoveKeyspace(op.Query) - stmt := qb.sel - qb.sel = nil + stmt := qb.stmt + qb.stmt = nil switch sel := stmt.(type) { case *sqlparser.Select: return buildDerivedSelect(op, qb, sel) @@ -610,7 +640,7 @@ func buildHorizon(op *Horizon, qb *queryBuilder) error { return err } - err = stripDownQuery(op.Query, qb.sel) + err = stripDownQuery(op.Query, qb.asSelectStatement()) if err != nil { return err } @@ -619,7 +649,7 @@ func buildHorizon(op *Horizon, qb *queryBuilder) error { removeKeyspaceFromSelectExpr(aliasedExpr) } return true, nil - }, qb.sel) + }, qb.stmt) return nil } diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index c24ab9f5065..01b3ab11520 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -65,3 +65,7 @@ func (d *Delete) GetOrdering() ([]ops.OrderBy, error) { func (d *Delete) ShortDescription() string { return fmt.Sprintf("%s.%s %s", d.VTable.Keyspace.Name, d.VTable.Name.String(), sqlparser.String(d.AST.Where)) } + +func (d *Delete) Statement() sqlparser.Statement { + return d.AST +} diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 3fc70ed8998..78ae6cc133e 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -117,3 +117,7 @@ func (i *Insert) Clone(inputs []ops.Operator) ops.Operator { func (i *Insert) TablesUsed() []string { return SingleQualifiedIdentifier(i.VTable.Keyspace, i.VTable.Name) } + +func (i *Insert) Statement() sqlparser.Statement { + return i.AST +} diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index 0627f07734e..f523643a84e 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -68,3 +68,7 @@ func (u *Update) TablesUsed() []string { func (u *Update) ShortDescription() string { return u.VTable.String() } + +func (u *Update) Statement() sqlparser.Statement { + return u.AST +}