diff --git a/data/test/tabletserver/exec_cases.txt b/data/test/tabletserver/exec_cases.txt index 131a1a92d26..c2dc80490f2 100644 --- a/data/test/tabletserver/exec_cases.txt +++ b/data/test/tabletserver/exec_cases.txt @@ -2151,3 +2151,7 @@ options:PassthroughDMLs # syntax error "syntax error" "syntax error at position 7 near 'syntax'" + +# named locks are unsafe with server-side connection pooling +"select get_lock('foo') from dual" +"get_lock() not allowed" diff --git a/data/test/tabletserver/stream_cases.txt b/data/test/tabletserver/stream_cases.txt index a42a8eabede..fb32bfd3dc1 100644 --- a/data/test/tabletserver/stream_cases.txt +++ b/data/test/tabletserver/stream_cases.txt @@ -52,3 +52,7 @@ # syntax error "syntax error" "syntax error at position 7 near 'syntax'" + +# named locks are unsafe with server-side connection pooling +"select get_lock('foo') from dual" +"get_lock() not allowed" diff --git a/go/vt/vttablet/tabletserver/planbuilder/plan.go b/go/vt/vttablet/tabletserver/planbuilder/plan.go index ff01f85ed4d..bbeecd963ea 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/plan.go +++ b/go/vt/vttablet/tabletserver/planbuilder/plan.go @@ -267,6 +267,12 @@ func (plan *Plan) setTable(tableName sqlparser.TableIdent, tables map[string]*sc func Build(statement sqlparser.Statement, tables map[string]*schema.Table) (*Plan, error) { var plan *Plan var err error + + err = checkForPoolingUnsafeConstructs(statement) + if err != nil { + return nil, err + } + switch stmt := statement.(type) { case *sqlparser.Union: plan, err = &Plan{ @@ -309,6 +315,11 @@ func BuildStreaming(sql string, tables map[string]*schema.Table) (*Plan, error) return nil, err } + err = checkForPoolingUnsafeConstructs(statement) + if err != nil { + return nil, err + } + plan := &Plan{ PlanID: PlanSelectStream, FullQuery: GenerateFullQuery(statement), @@ -350,3 +361,20 @@ func BuildMessageStreaming(name string, tables map[string]*schema.Table) (*Plan, }} return plan, nil } + +// checkForPoolingUnsafeConstructs returns an error if the SQL expression contains +// a call to GET_LOCK(), which is unsafe with server-side connection pooling. +// For more background, see https://github.com/vitessio/vitess/issues/3631. +func checkForPoolingUnsafeConstructs(expr sqlparser.SQLNode) error { + return sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + if f, ok := node.(*sqlparser.FuncExpr); ok { + if f.Name.Lowered() == "get_lock" { + return false, vterrors.New(vtrpcpb.Code_FAILED_PRECONDITION, "get_lock() not allowed") + } + } + + // TODO: This could be smarter about not walking down parts of the AST that can't contain + // function calls. + return true, nil + }, expr) +}