Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context support #586

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM

See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support.

## Context Support
Go 1.8 added some `database/sql` methods that accept a `context.Context` parameter for better control over timeout and cancellation.
See more details on [context support to database/sql package](https://golang.org/doc/go1.8#database_sql, "sql").
Go-MySQL-Driver supports context deadlines, but not cancellation.

## Testing / Development
To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details.
Expand Down
16 changes: 11 additions & 5 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func newBuffer(nc net.Conn) buffer {
}

// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
func (b *buffer) fill(ctx mysqlContext, need int) error {
n := b.length

// move existing data to the beginning
Expand All @@ -59,8 +59,14 @@ func (b *buffer) fill(need int) error {
b.idx = 0

for {
if b.timeout > 0 {
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
var deadline time.Time
if ctxDeadline, ok := ctx.Deadline(); ok {
deadline = ctxDeadline
} else if b.timeout > 0 {
deadline = time.Now().Add(b.timeout)
}
if !deadline.IsZero() {
if err := b.nc.SetReadDeadline(deadline); err != nil {
return err
}
}
Expand Down Expand Up @@ -91,10 +97,10 @@ func (b *buffer) fill(need int) error {

// returns next N bytes from buffer.
// The returned slice is only guaranteed to be valid until the next read
func (b *buffer) readNext(need int) ([]byte, error) {
func (b *buffer) readNext(ctx mysqlContext, need int) ([]byte, error) {
if b.length < need {
// refill
if err := b.fill(need); err != nil {
if err := b.fill(ctx, need); err != nil {
return nil, err
}
}
Expand Down
66 changes: 42 additions & 24 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (mc *mysqlConn) handleParams() (err error) {
charsets := strings.Split(val, ",")
for i := range charsets {
// ignore errors here - a charset may not exist
err = mc.exec("SET NAMES " + charsets[i])
err = mc.exec(backgroundCtx(), "SET NAMES "+charsets[i])
if err == nil {
break
}
Expand All @@ -53,7 +53,7 @@ func (mc *mysqlConn) handleParams() (err error) {

// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
err = mc.exec(backgroundCtx(), "SET "+param+"="+val+"")
if err != nil {
return
}
Expand All @@ -63,12 +63,17 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}

// Begin implements driver.Conn interface
func (mc *mysqlConn) Begin() (driver.Tx, error) {
return mc.beginTx(backgroundCtx(), txOptions{})
}

func (mc *mysqlConn) beginTx(ctx mysqlContext, opts txOptions) (driver.Tx, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
err := mc.exec("START TRANSACTION")
err := mc.exec(ctx, "START TRANSACTION")
if err == nil {
return &mysqlTx{mc}, err
}
Expand All @@ -79,7 +84,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {
func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if mc.netConn != nil {
err = mc.writeCommandPacket(comQuit)
err = mc.writeCommandPacket(backgroundCtx(), comQuit)
}

mc.cleanup()
Expand All @@ -104,12 +109,16 @@ func (mc *mysqlConn) cleanup() {
}

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
return mc.prepareContext(backgroundCtx(), query)
}

func (mc *mysqlConn) prepareContext(ctx mysqlContext, query string) (driver.Stmt, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
err := mc.writeCommandPacketStr(ctx, comStmtPrepare, query)
if err != nil {
return nil, err
}
Expand All @@ -119,16 +128,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
}

// Read Result
columnCount, err := stmt.readPrepareResultPacket()
columnCount, err := stmt.readPrepareResultPacket(ctx)
if err == nil {
if stmt.paramCount > 0 {
if err = mc.readUntilEOF(); err != nil {
if err = mc.readUntilEOF(ctx); err != nil {
return nil, err
}
}

if columnCount > 0 {
err = mc.readUntilEOF()
err = mc.readUntilEOF(ctx)
}
}

Expand Down Expand Up @@ -258,6 +267,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
}

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
return mc.execContext(backgroundCtx(), query, args)
}

func (mc *mysqlConn) execContext(ctx mysqlContext, query string, args []driver.Value) (driver.Result, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand All @@ -276,7 +289,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
mc.affectedRows = 0
mc.insertId = 0

err := mc.exec(query)
err := mc.exec(ctx, query)
if err == nil {
return &mysqlResult{
affectedRows: int64(mc.affectedRows),
Expand All @@ -287,34 +300,39 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
}

// Internal function to execute commands
func (mc *mysqlConn) exec(query string) error {
func (mc *mysqlConn) exec(ctx mysqlContext, query string) error {
// Send command
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
if err := mc.writeCommandPacketStr(ctx, comQuery, query); err != nil {
return err
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := mc.readResultSetHeaderPacket(ctx)
if err != nil {
return err
}

if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
if err := mc.readUntilEOF(ctx); err != nil {
return err
}

// rows
if err := mc.readUntilEOF(); err != nil {
if err := mc.readUntilEOF(ctx); err != nil {
return err
}
}

return mc.discardResults()
return mc.discardResults(ctx)
}

// Query implements driver.Queryer interface
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
return mc.queryContext(backgroundCtx(), query, args)
}

func (mc *mysqlConn) queryContext(ctx mysqlContext, query string, args []driver.Value) (driver.Rows, error) {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
Expand All @@ -331,11 +349,11 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
query = prepared
}
// Send command
err := mc.writeCommandPacketStr(comQuery, query)
err := mc.writeCommandPacketStr(ctx, comQuery, query)
if err == nil {
// Read Result
var resLen int
resLen, err = mc.readResultSetHeaderPacket()
resLen, err = mc.readResultSetHeaderPacket(ctx)
if err == nil {
rows := new(textRows)
rows.mc = mc
Expand All @@ -351,7 +369,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
}
}
// Columns
rows.rs.columns, err = mc.readColumns(resLen)
rows.rs.columns, err = mc.readColumns(ctx, resLen)
return rows, err
}
}
Expand All @@ -360,29 +378,29 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro

// Gets the value of the given MySQL System Variable
// The returned byte slice is only valid until the next read
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
func (mc *mysqlConn) getSystemVar(ctx mysqlContext, name string) ([]byte, error) {
// Send command
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
if err := mc.writeCommandPacketStr(ctx, comQuery, "SELECT @@"+name); err != nil {
return nil, err
}

// Read Result
resLen, err := mc.readResultSetHeaderPacket()
resLen, err := mc.readResultSetHeaderPacket(ctx)
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}

if resLen > 0 {
// Columns
if err := mc.readUntilEOF(); err != nil {
if err := mc.readUntilEOF(ctx); err != nil {
return nil, err
}
}

dest := make([]driver.Value, resLen)
if err = rows.readRow(dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF()
if err = rows.readRow(ctx, dest); err == nil {
return dest[0].([]byte), mc.readUntilEOF(ctx)
}
}
return nil, err
Expand Down
73 changes: 73 additions & 0 deletions connection_ctx.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// +build go1.8

// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.

package mysql

import (
"context"
"database/sql/driver"
"errors"
)

// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) error {
if mc.netConn == nil {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
if err := mc.writeCommandPacket(ctx, comPing); err != nil {
errLog.Print(err)
return err
}

if _, err := mc.readResultOK(ctx); err != nil {
errLog.Print(err)
return err
}
return nil
}

// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return mc.beginTx(ctx, txOptions(opts))
}

func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
return mc.prepareContext(ctx, query)
}

// QueryContext implements driver.QueryerContext interface
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
values, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return mc.queryContext(ctx, query, values)
}

// ExecContext implements driver.ExecerContext interface
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
values, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return mc.execContext(ctx, query, values)
}

func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("mysql: Named Parameters are not supported")
}
dargs[n] = param.Value
}
return dargs, nil
}
23 changes: 23 additions & 0 deletions connection_ctx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// +build go1.8

// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.package mysql

package mysql

import (
"database/sql/driver"
)

var (
_ driver.ConnBeginTx = &mysqlConn{}
_ driver.ConnPrepareContext = &mysqlConn{}
_ driver.ExecerContext = &mysqlConn{}
_ driver.Pinger = &mysqlConn{}
_ driver.QueryerContext = &mysqlConn{}
)
Loading