From 94e3ba7cf5ca87fd7df73bdf9f9cec809ea80d4d Mon Sep 17 00:00:00 2001 From: oscarzhao Date: Thu, 16 Mar 2017 02:09:50 +0800 Subject: [PATCH 1/6] Add context support For backward compatibility with Go versions < 1.8, the file ctx_backport.go contains a redefinition of context.Context as well as a few other things. The result is that the vast majority of the code can be shared between Go versions. --- README.md | 5 ++++ connection.go | 38 ++++++++++++++++++++-------- connection_ctx.go | 48 +++++++++++++++++++++++++++++++++++ ctx_backport.go | 61 ++++++++++++++++++++++++++++++++++++++++++++ ctx_go18.go | 27 ++++++++++++++++++++ driver.go | 9 ++++--- driver_test.go | 8 ++++++ infile.go | 6 ++--- packets.go | 64 +++++++++++++++++++++++++++++------------------ statement.go | 18 ++++++++++--- statement_ctx.go | 26 +++++++++++++++++++ transaction.go | 6 +++-- 12 files changed, 270 insertions(+), 46 deletions(-) create mode 100644 connection_ctx.go create mode 100644 ctx_backport.go create mode 100644 ctx_go18.go create mode 100644 statement_ctx.go diff --git a/README.md b/README.md index a060e3cfd..ea08db9c4 100644 --- a/README.md +++ b/README.md @@ -418,6 +418,11 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. +## Context Support +Since go1.8, context is introduced to `database/sql` for better control on timeout and cancellation. +New interfaces such as `driver.QueryerContext`, `driver.ExecerContext` are introduced. See more details on [context support to database/sql package](https://golang.org/doc/go1.8#database_sql, "sql"). + +In Go-MySQL-Driver, we implemented these interfaces for structs `mysqlConn`, `mysqlStmt` and `mysqlTx`. ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. diff --git a/connection.go b/connection.go index cdce3e30f..2fd94c65b 100644 --- a/connection.go +++ b/connection.go @@ -42,7 +42,7 @@ func (mc *mysqlConn) handleParams() (err error) { charsets := strings.Split(val, ",") for i := range charsets { // ignore errors here - a charset may not exist - err = mc.exec("SET NAMES " + charsets[i]) + err = mc.exec(backgroundCtx(), "SET NAMES "+charsets[i]) if err == nil { break } @@ -53,7 +53,7 @@ func (mc *mysqlConn) handleParams() (err error) { // System Vars default: - err = mc.exec("SET " + param + "=" + val + "") + err = mc.exec(backgroundCtx(), "SET "+param+"="+val+"") if err != nil { return } @@ -63,12 +63,17 @@ func (mc *mysqlConn) handleParams() (err error) { return } +// Begin implements driver.Conn interface func (mc *mysqlConn) Begin() (driver.Tx, error) { + return mc.beginTx(backgroundCtx(), txOptions{}) +} + +func (mc *mysqlConn) beginTx(ctx mysqlContext, opts txOptions) (driver.Tx, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec("START TRANSACTION") + err := mc.exec(ctx, "START TRANSACTION") if err == nil { return &mysqlTx{mc}, err } @@ -79,7 +84,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent if mc.netConn != nil { - err = mc.writeCommandPacket(comQuit) + err = mc.writeCommandPacket(backgroundCtx(), comQuit) } mc.cleanup() @@ -104,12 +109,16 @@ func (mc *mysqlConn) cleanup() { } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { + return mc.prepareContext(backgroundCtx(), query) +} + +func (mc *mysqlConn) prepareContext(ctx mysqlContext, query string) (driver.Stmt, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := mc.writeCommandPacketStr(comStmtPrepare, query) + err := mc.writeCommandPacketStr(ctx, comStmtPrepare, query) if err != nil { return nil, err } @@ -258,6 +267,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { + return mc.ExecContext(backgroundCtx(), query, args) +} + +func (mc *mysqlConn) ExecContext(ctx mysqlContext, query string, args []driver.Value) (driver.Result, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -276,7 +289,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err mc.affectedRows = 0 mc.insertId = 0 - err := mc.exec(query) + err := mc.exec(ctx, query) if err == nil { return &mysqlResult{ affectedRows: int64(mc.affectedRows), @@ -287,9 +300,9 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err } // Internal function to execute commands -func (mc *mysqlConn) exec(query string) error { +func (mc *mysqlConn) exec(ctx mysqlContext, query string) error { // Send command - if err := mc.writeCommandPacketStr(comQuery, query); err != nil { + if err := mc.writeCommandPacketStr(ctx, comQuery, query); err != nil { return err } @@ -314,7 +327,12 @@ func (mc *mysqlConn) exec(query string) error { return mc.discardResults() } +// Query implements driver.Queryer interface func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return mc.queryContext(backgroundCtx(), query, args) +} + +func (mc *mysqlConn) queryContext(ctx mysqlContext, query string, args []driver.Value) (driver.Rows, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn @@ -331,7 +349,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro query = prepared } // Send command - err := mc.writeCommandPacketStr(comQuery, query) + err := mc.writeCommandPacketStr(ctx, comQuery, query) if err == nil { // Read Result var resLen int @@ -362,7 +380,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro // The returned byte slice is only valid until the next read func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { // Send command - if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil { + if err := mc.writeCommandPacketStr(backgroundCtx(), comQuery, "SELECT @@"+name); err != nil { return nil, err } diff --git a/connection_ctx.go b/connection_ctx.go new file mode 100644 index 000000000..866d61ee6 --- /dev/null +++ b/connection_ctx.go @@ -0,0 +1,48 @@ +// +build go1.8 + +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" +) + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) error { + if mc.netConn == nil { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + if err := mc.writeCommandPacket(ctx, comPing); err != nil { + errLog.Print(err) + return err + } + + if _, err := mc.readResultOK(); err != nil { + errLog.Print(err) + return err + } + return nil +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + return mc.beginTx(ctx, txOptions(opts)) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + return mc.prepareContext(ctx, query) +} + +// QueryContext implements driver.QueryerContext interface +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) { + return mc.queryContext(ctx, query, args) +} diff --git a/ctx_backport.go b/ctx_backport.go new file mode 100644 index 000000000..86826155c --- /dev/null +++ b/ctx_backport.go @@ -0,0 +1,61 @@ +// +build !go1.8 + +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "time" +) + +// txOptions is defined for compatibility with Go 1.8's driver.TxOptions struct. +type txOptions struct { + Isolation int + ReadOnly bool +} + +// mysqlContext is a copy of context.Context from Go 1.7 and later. +type mysqlContext interface { + Deadline() (deadline time.Time, ok bool) + + Done() <-chan struct{} + + Err() error + + Value(key interface{}) interface{} +} + +// emptyCtx is copied from Go 1.7's context package. +type emptyCtx int + +func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*emptyCtx) Done() <-chan struct{} { + return nil +} + +func (*emptyCtx) Err() error { + return nil +} + +func (*emptyCtx) Value(key interface{}) interface{} { + return nil +} + +func (e *emptyCtx) String() string { + return "context.Background" +} + +var background = new(emptyCtx) + +func backgroundCtx() mysqlContext { + return background +} diff --git a/ctx_go18.go b/ctx_go18.go new file mode 100644 index 000000000..60394dba6 --- /dev/null +++ b/ctx_go18.go @@ -0,0 +1,27 @@ +// +build go1.8 + +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" +) + +// The definitions below are for compatibility with older Go versions. +// See ctx_backport.go for the definitions used in older Go versions. + +type txOptions driver.TxOptions + +type mysqlContext context.Context + +func backgroundCtx() mysqlContext { + return context.Background() +} diff --git a/driver.go b/driver.go index e51d98a3c..5031206ce 100644 --- a/driver.go +++ b/driver.go @@ -14,6 +14,7 @@ // db, err := sql.Open("mysql", "user:password@/dbname") // // See https://github.com/go-sql-driver/mysql#usage for details + package mysql import ( @@ -95,7 +96,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { + if err = mc.writeAuthPacket(backgroundCtx(), cipher); err != nil { mc.cleanup() return nil, err } @@ -157,7 +158,7 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { cipher = oldCipher } - if err = mc.writeOldAuthPacket(cipher); err != nil { + if err = mc.writeOldAuthPacket(backgroundCtx(), cipher); err != nil { return err } _, err = mc.readResultOK() @@ -165,12 +166,12 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { // Retry with clear text password for // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - if err = mc.writeClearAuthPacket(); err != nil { + if err = mc.writeClearAuthPacket(backgroundCtx()); err != nil { return err } _, err = mc.readResultOK() } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { - if err = mc.writeNativeAuthPacket(cipher); err != nil { + if err = mc.writeNativeAuthPacket(backgroundCtx(), cipher); err != nil { return err } _, err = mc.readResultOK() diff --git a/driver_test.go b/driver_test.go index 6cd9675d9..2e33a4edd 100644 --- a/driver_test.go +++ b/driver_test.go @@ -182,6 +182,14 @@ func TestEmptyQuery(t *testing.T) { }) } +func (dbt *DBTest) TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} + func TestCRUD(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { // Create Table diff --git a/infile.go b/infile.go index 547357cfa..06b3237b7 100644 --- a/infile.go +++ b/infile.go @@ -93,7 +93,7 @@ func deferredClose(err *error, closer io.Closer) { } } -func (mc *mysqlConn) handleInFileRequest(name string) (err error) { +func (mc *mysqlConn) handleInFileRequest(ctx mysqlContext, name string) (err error) { var rdr io.Reader var data []byte packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP @@ -153,7 +153,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { for err == nil { n, err = rdr.Read(data[4:]) if n > 0 { - if ioErr := mc.writePacket(data[:4+n]); ioErr != nil { + if ioErr := mc.writePacket(ctx, data[:4+n]); ioErr != nil { return ioErr } } @@ -167,7 +167,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { if data == nil { data = make([]byte, 4) } - if ioErr := mc.writePacket(data[:4]); ioErr != nil { + if ioErr := mc.writePacket(ctx, data[:4]); ioErr != nil { return ioErr } diff --git a/packets.go b/packets.go index cb21397a2..45461821f 100644 --- a/packets.go +++ b/packets.go @@ -83,7 +83,15 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // Write packet buffer 'data' -func (mc *mysqlConn) writePacket(data []byte) error { +func (mc *mysqlConn) writePacket(ctx mysqlContext, data []byte) error { + if ctx == nil { + panic("context cannot be nil") + } + ctxDeadline, isCtxDeadlineSet := ctx.Deadline() + if isCtxDeadlineSet && !ctxDeadline.After(time.Now()) { + return errors.New("timeout") + } + pktLen := len(data) - 4 if pktLen > mc.maxAllowedPacket { @@ -106,8 +114,16 @@ func (mc *mysqlConn) writePacket(data []byte) error { data[3] = mc.sequence // Write packet + var timeNow = time.Now() + var deadline = timeNow if mc.writeTimeout > 0 { - if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil { + deadline = timeNow.Add(mc.writeTimeout) + if isCtxDeadlineSet && deadline.After(ctxDeadline) { + deadline = ctxDeadline + } + } + if deadline.After(timeNow) { + if err := mc.netConn.SetWriteDeadline(deadline); err != nil { return err } } @@ -223,7 +239,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeAuthPacket(ctx mysqlContext, cipher []byte) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -292,7 +308,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest if mc.cfg.tls != nil { // Send TLS / SSL request packet - if err := mc.writePacket(data[:(4+4+1+23)+4]); err != nil { + if err := mc.writePacket(ctx, data[:(4+4+1+23)+4]); err != nil { return err } @@ -334,12 +350,12 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[pos] = 0x00 // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeOldAuthPacket(ctx mysqlContext, cipher []byte) error { // User password scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) @@ -356,12 +372,12 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { copy(data[4:], scrambleBuff) data[4+pktLen-1] = 0x00 - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Client clear text authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { +func (mc *mysqlConn) writeClearAuthPacket(ctx mysqlContext) error { // Calculate the packet length and add a tailing 0 pktLen := len(mc.cfg.Passwd) + 1 data := mc.buf.takeSmallBuffer(4 + pktLen) @@ -375,12 +391,12 @@ func (mc *mysqlConn) writeClearAuthPacket() error { copy(data[4:], mc.cfg.Passwd) data[4+pktLen-1] = 0x00 - return mc.writePacket(data) + return mc.writePacket(ctx, data) } // Native password authentication method // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeNativeAuthPacket(ctx mysqlContext, cipher []byte) error { scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) // Calculate the packet length and add a tailing 0 @@ -395,14 +411,14 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { // Add the scramble copy(data[4:], scrambleBuff) - return mc.writePacket(data) + return mc.writePacket(ctx, data) } /****************************************************************************** * Command Packets * ******************************************************************************/ -func (mc *mysqlConn) writeCommandPacket(command byte) error { +func (mc *mysqlConn) writeCommandPacket(ctx mysqlContext, command byte) error { // Reset Packet Sequence mc.sequence = 0 @@ -417,10 +433,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data[4] = command // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } -func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { +func (mc *mysqlConn) writeCommandPacketStr(ctx mysqlContext, command byte, arg string) error { // Reset Packet Sequence mc.sequence = 0 @@ -439,10 +455,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { copy(data[5:], arg) // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } -func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { +func (mc *mysqlConn) writeCommandPacketUint32(ctx mysqlContext, command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 @@ -463,7 +479,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data[8] = byte(arg >> 24) // Send CMD packet - return mc.writePacket(data) + return mc.writePacket(ctx, data) } /****************************************************************************** @@ -525,7 +541,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { return 0, mc.handleErrorPacket(data) case iLocalInFile: - return 0, mc.handleInFileRequest(string(data[1:])) + return 0, mc.handleInFileRequest(backgroundCtx(), string(data[1:])) } // column count @@ -823,7 +839,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { } // http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html -func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { +func (stmt *mysqlStmt) writeCommandLongData(ctx mysqlContext, paramID int, arg []byte) error { maxLen := stmt.mc.maxAllowedPacket - 1 pktLen := maxLen @@ -860,7 +876,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { data[10] = byte(paramID >> 8) // Send CMD packet - err := stmt.mc.writePacket(data[:4+pktLen]) + err := stmt.mc.writePacket(ctx, data[:4+pktLen]) if err == nil { data = data[pktLen-dataOffset:] continue @@ -876,7 +892,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // Execute Prepared Statement // http://dev.mysql.com/doc/internals/en/com-stmt-execute.html -func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { +func (stmt *mysqlStmt) writeExecutePacket(ctx mysqlContext, args []driver.Value) error { if len(args) != stmt.paramCount { return fmt.Errorf( "argument count mismatch (got: %d; has: %d)", @@ -1021,7 +1037,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(i, v); err != nil { + if err := stmt.writeCommandLongData(ctx, i, v); err != nil { return err } } @@ -1043,7 +1059,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) paramValues = append(paramValues, v...) } else { - if err := stmt.writeCommandLongData(i, []byte(v)); err != nil { + if err := stmt.writeCommandLongData(ctx, i, []byte(v)); err != nil { return err } } @@ -1080,7 +1096,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = data[:pos] } - return mc.writePacket(data) + return mc.writePacket(ctx, data) } func (mc *mysqlConn) discardResults() error { diff --git a/statement.go b/statement.go index b88771674..d6285e5fd 100644 --- a/statement.go +++ b/statement.go @@ -23,6 +23,7 @@ type mysqlStmt struct { columns [][]mysqlField // cached from the first query } +// Close implements driver.Conn interface func (stmt *mysqlStmt) Close() error { if stmt.mc == nil || stmt.mc.netConn == nil { // driver.Stmt.Close can be called more than once, thus this function @@ -32,11 +33,12 @@ func (stmt *mysqlStmt) Close() error { return driver.ErrBadConn } - err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id) + err := stmt.mc.writeCommandPacketUint32(backgroundCtx(), comStmtClose, stmt.id) stmt.mc = nil return err } +// NumInput implements driver.Stmt interface func (stmt *mysqlStmt) NumInput() int { return stmt.paramCount } @@ -45,13 +47,18 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { return converter{} } +// Exec implements driver.Stmt interface func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { + return stmt.execContext(backgroundCtx(), args) +} + +func (stmt *mysqlStmt) execContext(ctx mysqlContext, args []driver.Value) (driver.Result, error) { if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(args) + err := stmt.writeExecutePacket(ctx, args) if err != nil { return nil, err } @@ -89,13 +96,18 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { }, nil } +// Query implements driver.Stmt interface func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.queryContext(backgroundCtx(), args) +} + +func (stmt *mysqlStmt) queryContext(ctx mysqlContext, args []driver.Value) (driver.Rows, error) { if stmt.mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command - err := stmt.writeExecutePacket(args) + err := stmt.writeExecutePacket(ctx, args) if err != nil { return nil, err } diff --git a/statement_ctx.go b/statement_ctx.go new file mode 100644 index 000000000..8b2d1204d --- /dev/null +++ b/statement_ctx.go @@ -0,0 +1,26 @@ +// +build go1.8 + +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" +) + +// ExecContent implements driver.StmtExecContext interface +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.Value) (driver.Result, error) { + return stmt.execContext(ctx, args) +} + +// QueryContext implements driver.StmtQueryContext interface +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.Value) (driver.Rows, error) { + return stmt.queryContext(ctx, args) +} diff --git a/transaction.go b/transaction.go index 33c749b35..55002c5f9 100644 --- a/transaction.go +++ b/transaction.go @@ -12,20 +12,22 @@ type mysqlTx struct { mc *mysqlConn } +// Commit implements driver.Tx interface func (tx *mysqlTx) Commit() (err error) { if tx.mc == nil || tx.mc.netConn == nil { return ErrInvalidConn } - err = tx.mc.exec("COMMIT") + err = tx.mc.exec(backgroundCtx(), "COMMIT") tx.mc = nil return } +// Rollback implements driver.Tx interface func (tx *mysqlTx) Rollback() (err error) { if tx.mc == nil || tx.mc.netConn == nil { return ErrInvalidConn } - err = tx.mc.exec("ROLLBACK") + err = tx.mc.exec(backgroundCtx(), "ROLLBACK") tx.mc = nil return } From 669fc71ffc29cdb8b344b9c9e0b4f383f8fd176a Mon Sep 17 00:00:00 2001 From: Evan Shaw Date: Tue, 9 May 2017 08:21:30 +1200 Subject: [PATCH 2/6] Use context.DeadlineExceeded error --- ctx_backport.go | 9 +++++++++ ctx_go18.go | 2 ++ packets.go | 5 +---- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ctx_backport.go b/ctx_backport.go index 86826155c..dba59acb3 100644 --- a/ctx_backport.go +++ b/ctx_backport.go @@ -59,3 +59,12 @@ var background = new(emptyCtx) func backgroundCtx() mysqlContext { return background } + +var deadlineExceeded = deadlineExceededError{} + +// deadlineExceededError is copied from Go 1.7's context package. +type deadlineExceededError struct{} + +func (deadlineExceededError) Error() string { return "context deadline exceeded" } +func (deadlineExceededError) Timeout() bool { return true } +func (deadlineExceededError) Temporary() bool { return true } diff --git a/ctx_go18.go b/ctx_go18.go index 60394dba6..0b8dcb9b7 100644 --- a/ctx_go18.go +++ b/ctx_go18.go @@ -25,3 +25,5 @@ type mysqlContext context.Context func backgroundCtx() mysqlContext { return context.Background() } + +var deadlineExceeded = context.DeadlineExceeded diff --git a/packets.go b/packets.go index 45461821f..da6d9c15d 100644 --- a/packets.go +++ b/packets.go @@ -84,12 +84,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // Write packet buffer 'data' func (mc *mysqlConn) writePacket(ctx mysqlContext, data []byte) error { - if ctx == nil { - panic("context cannot be nil") - } ctxDeadline, isCtxDeadlineSet := ctx.Deadline() if isCtxDeadlineSet && !ctxDeadline.After(time.Now()) { - return errors.New("timeout") + return deadlineExceeded } pktLen := len(data) - 4 From 2abb4ae488413dc6c9fbe602e71f4acd8b279948 Mon Sep 17 00:00:00 2001 From: Evan Shaw Date: Tue, 9 May 2017 08:42:58 +1200 Subject: [PATCH 3/6] Add context support for reads --- buffer.go | 16 ++++++++++----- connection.go | 30 ++++++++++++++-------------- connection_ctx.go | 2 +- driver.go | 12 +++++------ infile.go | 4 ++-- packets.go | 51 ++++++++++++++++++++++++++--------------------- packets_test.go | 18 ++++++++--------- rows.go | 18 ++++++++--------- statement.go | 14 ++++++------- 9 files changed, 88 insertions(+), 77 deletions(-) diff --git a/buffer.go b/buffer.go index 2001feacd..e23806fbb 100644 --- a/buffer.go +++ b/buffer.go @@ -38,7 +38,7 @@ func newBuffer(nc net.Conn) buffer { } // fill reads into the buffer until at least _need_ bytes are in it -func (b *buffer) fill(need int) error { +func (b *buffer) fill(ctx mysqlContext, need int) error { n := b.length // move existing data to the beginning @@ -59,8 +59,14 @@ func (b *buffer) fill(need int) error { b.idx = 0 for { - if b.timeout > 0 { - if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil { + var deadline time.Time + if ctxDeadline, ok := ctx.Deadline(); ok { + deadline = ctxDeadline + } else if b.timeout > 0 { + deadline = time.Now().Add(b.timeout) + } + if !deadline.IsZero() { + if err := b.nc.SetReadDeadline(deadline); err != nil { return err } } @@ -91,10 +97,10 @@ func (b *buffer) fill(need int) error { // returns next N bytes from buffer. // The returned slice is only guaranteed to be valid until the next read -func (b *buffer) readNext(need int) ([]byte, error) { +func (b *buffer) readNext(ctx mysqlContext, need int) ([]byte, error) { if b.length < need { // refill - if err := b.fill(need); err != nil { + if err := b.fill(ctx, need); err != nil { return nil, err } } diff --git a/connection.go b/connection.go index 2fd94c65b..f756ce8c3 100644 --- a/connection.go +++ b/connection.go @@ -128,16 +128,16 @@ func (mc *mysqlConn) prepareContext(ctx mysqlContext, query string) (driver.Stmt } // Read Result - columnCount, err := stmt.readPrepareResultPacket() + columnCount, err := stmt.readPrepareResultPacket(ctx) if err == nil { if stmt.paramCount > 0 { - if err = mc.readUntilEOF(); err != nil { + if err = mc.readUntilEOF(ctx); err != nil { return nil, err } } if columnCount > 0 { - err = mc.readUntilEOF() + err = mc.readUntilEOF(ctx) } } @@ -307,24 +307,24 @@ func (mc *mysqlConn) exec(ctx mysqlContext, query string) error { } // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket(ctx) if err != nil { return err } if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return err } } - return mc.discardResults() + return mc.discardResults(ctx) } // Query implements driver.Queryer interface @@ -353,7 +353,7 @@ func (mc *mysqlConn) queryContext(ctx mysqlContext, query string, args []driver. if err == nil { // Read Result var resLen int - resLen, err = mc.readResultSetHeaderPacket() + resLen, err = mc.readResultSetHeaderPacket(ctx) if err == nil { rows := new(textRows) rows.mc = mc @@ -369,7 +369,7 @@ func (mc *mysqlConn) queryContext(ctx mysqlContext, query string, args []driver. } } // Columns - rows.rs.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(ctx, resLen) return rows, err } } @@ -378,14 +378,14 @@ func (mc *mysqlConn) queryContext(ctx mysqlContext, query string, args []driver. // Gets the value of the given MySQL System Variable // The returned byte slice is only valid until the next read -func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { +func (mc *mysqlConn) getSystemVar(ctx mysqlContext, name string) ([]byte, error) { // Send command - if err := mc.writeCommandPacketStr(backgroundCtx(), comQuery, "SELECT @@"+name); err != nil { + if err := mc.writeCommandPacketStr(ctx, comQuery, "SELECT @@"+name); err != nil { return nil, err } // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket(ctx) if err == nil { rows := new(textRows) rows.mc = mc @@ -393,14 +393,14 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if resLen > 0 { // Columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return nil, err } } dest := make([]driver.Value, resLen) - if err = rows.readRow(dest); err == nil { - return dest[0].([]byte), mc.readUntilEOF() + if err = rows.readRow(ctx, dest); err == nil { + return dest[0].([]byte), mc.readUntilEOF(ctx) } } return nil, err diff --git a/connection_ctx.go b/connection_ctx.go index 866d61ee6..a013ff9ea 100644 --- a/connection_ctx.go +++ b/connection_ctx.go @@ -26,7 +26,7 @@ func (mc *mysqlConn) Ping(ctx context.Context) error { return err } - if _, err := mc.readResultOK(); err != nil { + if _, err := mc.readResultOK(ctx); err != nil { errLog.Print(err) return err } diff --git a/driver.go b/driver.go index 5031206ce..cf72c1a51 100644 --- a/driver.go +++ b/driver.go @@ -89,7 +89,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet - cipher, err := mc.readInitPacket() + cipher, err := mc.readInitPacket(backgroundCtx()) if err != nil { mc.cleanup() return nil, err @@ -114,7 +114,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket } else { // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") + maxap, err := mc.getSystemVar(backgroundCtx(), "max_allowed_packet") if err != nil { mc.Close() return nil, err @@ -137,7 +137,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { // Read Result Packet - cipher, err := mc.readResultOK() + cipher, err := mc.readResultOK(backgroundCtx()) if err == nil { return nil // auth successful } @@ -161,7 +161,7 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { if err = mc.writeOldAuthPacket(backgroundCtx(), cipher); err != nil { return err } - _, err = mc.readResultOK() + _, err = mc.readResultOK(backgroundCtx()) } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { // Retry with clear text password for // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html @@ -169,12 +169,12 @@ func handleAuthResult(mc *mysqlConn, oldCipher []byte) error { if err = mc.writeClearAuthPacket(backgroundCtx()); err != nil { return err } - _, err = mc.readResultOK() + _, err = mc.readResultOK(backgroundCtx()) } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { if err = mc.writeNativeAuthPacket(backgroundCtx(), cipher); err != nil { return err } - _, err = mc.readResultOK() + _, err = mc.readResultOK(backgroundCtx()) } return err } diff --git a/infile.go b/infile.go index 06b3237b7..6f94cb738 100644 --- a/infile.go +++ b/infile.go @@ -173,10 +173,10 @@ func (mc *mysqlConn) handleInFileRequest(ctx mysqlContext, name string) (err err // read OK packet if err == nil { - _, err = mc.readResultOK() + _, err = mc.readResultOK(ctx) return err } - mc.readPacket() + mc.readPacket(ctx) return err } diff --git a/packets.go b/packets.go index da6d9c15d..108930483 100644 --- a/packets.go +++ b/packets.go @@ -24,11 +24,16 @@ import ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html // Read packet to buffer 'data' -func (mc *mysqlConn) readPacket() ([]byte, error) { +func (mc *mysqlConn) readPacket(ctx mysqlContext) ([]byte, error) { + ctxDeadline, isCtxDeadlineSet := ctx.Deadline() + if isCtxDeadlineSet && !ctxDeadline.After(time.Now()) { + return nil, deadlineExceeded + } + var prevData []byte for { // read packet header - data, err := mc.buf.readNext(4) + data, err := mc.buf.readNext(ctx, 4) if err != nil { errLog.Print(err) mc.Close() @@ -61,7 +66,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } // read packet body [pktLen bytes] - data, err = mc.buf.readNext(pktLen) + data, err = mc.buf.readNext(ctx, pktLen) if err != nil { errLog.Print(err) mc.Close() @@ -152,8 +157,8 @@ func (mc *mysqlConn) writePacket(ctx mysqlContext, data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() ([]byte, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readInitPacket(ctx mysqlContext) ([]byte, error) { + data, err := mc.readPacket(ctx) if err != nil { return nil, err } @@ -484,8 +489,8 @@ func (mc *mysqlConn) writeCommandPacketUint32(ctx mysqlContext, command byte, ar ******************************************************************************/ // Returns error if Packet is not an 'Result OK'-Packet -func (mc *mysqlConn) readResultOK() ([]byte, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readResultOK(ctx mysqlContext) ([]byte, error) { + data, err := mc.readPacket(ctx) if err == nil { // packet indicator switch data[0] { @@ -526,8 +531,8 @@ func (mc *mysqlConn) readResultOK() ([]byte, error) { // Result Set Header Packet // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset -func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readResultSetHeaderPacket(ctx mysqlContext) (int, error) { + data, err := mc.readPacket(ctx) if err == nil { switch data[0] { @@ -616,11 +621,11 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41 -func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { +func (mc *mysqlConn) readColumns(ctx mysqlContext, count int) ([]mysqlField, error) { columns := make([]mysqlField, count) for i := 0; ; i++ { - data, err := mc.readPacket() + data, err := mc.readPacket(ctx) if err != nil { return nil, err } @@ -709,14 +714,14 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // Read Packets as Field Packets until EOF-Packet or an Error appears // http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow -func (rows *textRows) readRow(dest []driver.Value) error { +func (rows *textRows) readRow(ctx mysqlContext, dest []driver.Value) error { mc := rows.mc if rows.rs.done { return io.EOF } - data, err := mc.readPacket() + data, err := mc.readPacket(ctx) if err != nil { return err } @@ -777,9 +782,9 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // Reads Packets until EOF-Packet or an Error appears. Returns count of Packets read -func (mc *mysqlConn) readUntilEOF() error { +func (mc *mysqlConn) readUntilEOF(ctx mysqlContext) error { for { - data, err := mc.readPacket() + data, err := mc.readPacket(ctx) if err != nil { return err } @@ -802,8 +807,8 @@ func (mc *mysqlConn) readUntilEOF() error { // Prepare Result Packets // http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html -func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { - data, err := stmt.mc.readPacket() +func (stmt *mysqlStmt) readPrepareResultPacket(ctx mysqlContext) (uint16, error) { + data, err := stmt.mc.readPacket(ctx) if err == nil { // packet indicator [1 byte] if data[0] != iOK { @@ -1096,19 +1101,19 @@ func (stmt *mysqlStmt) writeExecutePacket(ctx mysqlContext, args []driver.Value) return mc.writePacket(ctx, data) } -func (mc *mysqlConn) discardResults() error { +func (mc *mysqlConn) discardResults(ctx mysqlContext) error { for mc.status&statusMoreResultsExists != 0 { - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket(ctx) if err != nil { return err } if resLen > 0 { // columns - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return err } // rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return err } } @@ -1117,8 +1122,8 @@ func (mc *mysqlConn) discardResults() error { } // http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html -func (rows *binaryRows) readRow(dest []driver.Value) error { - data, err := rows.mc.readPacket() +func (rows *binaryRows) readRow(ctx mysqlContext, dest []driver.Value) error { + data, err := rows.mc.readPacket(ctx) if err != nil { return err } diff --git a/packets_test.go b/packets_test.go index b1d64f5c7..1263dcf68 100644 --- a/packets_test.go +++ b/packets_test.go @@ -96,7 +96,7 @@ func TestReadPacketSingleByte(t *testing.T) { conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 - packet, err := mc.readPacket() + packet, err := mc.readPacket(backgroundCtx()) if err != nil { t.Fatal(err) } @@ -118,7 +118,7 @@ func TestReadPacketWrongSequenceID(t *testing.T) { conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} conn.maxReads = 1 mc.sequence = 1 - _, err := mc.readPacket() + _, err := mc.readPacket(backgroundCtx()) if err != ErrPktSync { t.Errorf("expected ErrPktSync, got %v", err) } @@ -130,7 +130,7 @@ func TestReadPacketWrongSequenceID(t *testing.T) { // too high sequence id conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} - _, err = mc.readPacket() + _, err = mc.readPacket(backgroundCtx()) if err != ErrPktSyncMul { t.Errorf("expected ErrPktSyncMul, got %v", err) } @@ -166,7 +166,7 @@ func TestReadPacketSplit(t *testing.T) { conn.data = data conn.maxReads = 3 - packet, err := mc.readPacket() + packet, err := mc.readPacket(backgroundCtx()) if err != nil { t.Fatal(err) } @@ -200,7 +200,7 @@ func TestReadPacketSplit(t *testing.T) { conn.reads = 0 conn.maxReads = 5 mc.sequence = 0 - packet, err = mc.readPacket() + packet, err = mc.readPacket(backgroundCtx()) if err != nil { t.Fatal(err) } @@ -226,7 +226,7 @@ func TestReadPacketSplit(t *testing.T) { conn.reads = 0 conn.maxReads = 4 mc.sequence = 0 - packet, err = mc.readPacket() + packet, err = mc.readPacket(backgroundCtx()) if err != nil { t.Fatal(err) } @@ -250,7 +250,7 @@ func TestReadPacketFail(t *testing.T) { // illegal empty (stand-alone) packet conn.data = []byte{0x00, 0x00, 0x00, 0x00} conn.maxReads = 1 - _, err := mc.readPacket() + _, err := mc.readPacket(backgroundCtx()) if err != driver.ErrBadConn { t.Errorf("expected ErrBadConn, got %v", err) } @@ -262,7 +262,7 @@ func TestReadPacketFail(t *testing.T) { // fail to read header conn.closed = true - _, err = mc.readPacket() + _, err = mc.readPacket(backgroundCtx()) if err != driver.ErrBadConn { t.Errorf("expected ErrBadConn, got %v", err) } @@ -275,7 +275,7 @@ func TestReadPacketFail(t *testing.T) { // fail to read body conn.maxReads = 1 - _, err = mc.readPacket() + _, err = mc.readPacket(backgroundCtx()) if err != driver.ErrBadConn { t.Errorf("expected ErrBadConn, got %v", err) } diff --git a/rows.go b/rows.go index 900f548ae..2a6ffae78 100644 --- a/rows.go +++ b/rows.go @@ -74,10 +74,10 @@ func (rows *mysqlRows) Close() (err error) { // Remove unread packets from stream if !rows.rs.done { - err = mc.readUntilEOF() + err = mc.readUntilEOF(backgroundCtx()) } if err == nil { - if err = mc.discardResults(); err != nil { + if err = mc.discardResults(backgroundCtx()); err != nil { return err } } @@ -103,7 +103,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { // Remove unread packets from stream if !rows.rs.done { - if err := rows.mc.readUntilEOF(); err != nil { + if err := rows.mc.readUntilEOF(backgroundCtx()); err != nil { return 0, err } rows.rs.done = true @@ -114,7 +114,7 @@ func (rows *mysqlRows) nextResultSet() (int, error) { return 0, io.EOF } rows.rs = resultSet{} - return rows.mc.readResultSetHeaderPacket() + return rows.mc.readResultSetHeaderPacket(backgroundCtx()) } func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { @@ -140,11 +140,11 @@ func (rows *binaryRows) NextResultSet() (err error) { // get columns, if not cached, read them and cache them. if rows.i >= len(*rows.stmtCols) { - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(backgroundCtx(), resLen) *rows.stmtCols = append(*rows.stmtCols, rows.rs.columns) } else { rows.rs.columns = (*rows.stmtCols)[rows.i] - if err := rows.mc.readUntilEOF(); err != nil { + if err := rows.mc.readUntilEOF(backgroundCtx()); err != nil { return err } } @@ -160,7 +160,7 @@ func (rows *binaryRows) Next(dest []driver.Value) error { } // Fetch next row from stream - return rows.readRow(dest) + return rows.readRow(backgroundCtx(), dest) } return io.EOF } @@ -171,7 +171,7 @@ func (rows *textRows) NextResultSet() (err error) { return err } - rows.rs.columns, err = rows.mc.readColumns(resLen) + rows.rs.columns, err = rows.mc.readColumns(backgroundCtx(), resLen) return err } @@ -182,7 +182,7 @@ func (rows *textRows) Next(dest []driver.Value) error { } // Fetch next row from stream - return rows.readRow(dest) + return rows.readRow(backgroundCtx(), dest) } return io.EOF } diff --git a/statement.go b/statement.go index d6285e5fd..1f05c6c89 100644 --- a/statement.go +++ b/statement.go @@ -69,24 +69,24 @@ func (stmt *mysqlStmt) execContext(ctx mysqlContext, args []driver.Value) (drive mc.insertId = 0 // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket(ctx) if err != nil { return nil, err } if resLen > 0 { // Columns - if err = mc.readUntilEOF(); err != nil { + if err = mc.readUntilEOF(ctx); err != nil { return nil, err } // Rows - if err := mc.readUntilEOF(); err != nil { + if err := mc.readUntilEOF(ctx); err != nil { return nil, err } } - if err := mc.discardResults(); err != nil { + if err := mc.discardResults(ctx); err != nil { return nil, err } @@ -115,7 +115,7 @@ func (stmt *mysqlStmt) queryContext(ctx mysqlContext, args []driver.Value) (driv mc := stmt.mc // Read Result - resLen, err := mc.readResultSetHeaderPacket() + resLen, err := mc.readResultSetHeaderPacket(ctx) if err != nil { return nil, err } @@ -129,11 +129,11 @@ func (stmt *mysqlStmt) queryContext(ctx mysqlContext, args []driver.Value) (driv // Columns // If not cached, read them and cache them if len(stmt.columns) == 0 { - rows.rs.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(ctx, resLen) stmt.columns = append(stmt.columns, rows.rs.columns) } else { rows.rs.columns = stmt.columns[0] - err = mc.readUntilEOF() + err = mc.readUntilEOF(ctx) } } else { rows.rs.done = true From 691ec98f74ff25424017959038202780ee5e847d Mon Sep 17 00:00:00 2001 From: Evan Shaw Date: Tue, 9 May 2017 08:44:17 +1200 Subject: [PATCH 4/6] Rework write deadline logic --- packets.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/packets.go b/packets.go index 108930483..64703a131 100644 --- a/packets.go +++ b/packets.go @@ -116,15 +116,13 @@ func (mc *mysqlConn) writePacket(ctx mysqlContext, data []byte) error { data[3] = mc.sequence // Write packet - var timeNow = time.Now() - var deadline = timeNow - if mc.writeTimeout > 0 { - deadline = timeNow.Add(mc.writeTimeout) - if isCtxDeadlineSet && deadline.After(ctxDeadline) { - deadline = ctxDeadline - } + var deadline time.Time + if ctxDeadline, ok := ctx.Deadline(); ok { + deadline = ctxDeadline + } else if mc.writeTimeout > 0 { + deadline = time.Now().Add(mc.writeTimeout) } - if deadline.After(timeNow) { + if !deadline.IsZero() { if err := mc.netConn.SetWriteDeadline(deadline); err != nil { return err } From 6d7e8045352814ac0fe0de0f24b5545726f18cac Mon Sep 17 00:00:00 2001 From: Evan Shaw Date: Thu, 11 May 2017 08:39:10 +1200 Subject: [PATCH 5/6] README improvements --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ea08db9c4..5fe55198c 100644 --- a/README.md +++ b/README.md @@ -419,10 +419,9 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. ## Context Support -Since go1.8, context is introduced to `database/sql` for better control on timeout and cancellation. -New interfaces such as `driver.QueryerContext`, `driver.ExecerContext` are introduced. See more details on [context support to database/sql package](https://golang.org/doc/go1.8#database_sql, "sql"). - -In Go-MySQL-Driver, we implemented these interfaces for structs `mysqlConn`, `mysqlStmt` and `mysqlTx`. +Go 1.8 added some `database/sql` methods that accept a `context.Context` parameter for better control over timeout and cancellation. +See more details on [context support to database/sql package](https://golang.org/doc/go1.8#database_sql, "sql"). +Go-MySQL-Driver supports context deadlines, but not cancellation. ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. From ce38c4f00c087c0452ce17012e8f69065792d36b Mon Sep 17 00:00:00 2001 From: Evan Shaw Date: Sat, 13 May 2017 21:29:46 +1200 Subject: [PATCH 6/6] Add compile-time interface tests and fix method signatures --- connection.go | 4 ++-- connection_ctx.go | 29 +++++++++++++++++++++++++++-- connection_ctx_test.go | 23 +++++++++++++++++++++++ statement_ctx.go | 16 ++++++++++++---- statement_ctx_test.go | 20 ++++++++++++++++++++ 5 files changed, 84 insertions(+), 8 deletions(-) create mode 100644 connection_ctx_test.go create mode 100644 statement_ctx_test.go diff --git a/connection.go b/connection.go index f756ce8c3..5379abccd 100644 --- a/connection.go +++ b/connection.go @@ -267,10 +267,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - return mc.ExecContext(backgroundCtx(), query, args) + return mc.execContext(backgroundCtx(), query, args) } -func (mc *mysqlConn) ExecContext(ctx mysqlContext, query string, args []driver.Value) (driver.Result, error) { +func (mc *mysqlConn) execContext(ctx mysqlContext, query string, args []driver.Value) (driver.Result, error) { if mc.netConn == nil { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn diff --git a/connection_ctx.go b/connection_ctx.go index a013ff9ea..98d8b02f2 100644 --- a/connection_ctx.go +++ b/connection_ctx.go @@ -13,6 +13,7 @@ package mysql import ( "context" "database/sql/driver" + "errors" ) // Ping implements driver.Pinger interface @@ -43,6 +44,30 @@ func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.S } // QueryContext implements driver.QueryerContext interface -func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) { - return mc.queryContext(ctx, query, args) +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + values, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return mc.queryContext(ctx, query, values) +} + +// ExecContext implements driver.ExecerContext interface +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + values, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return mc.execContext(ctx, query, values) +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("mysql: Named Parameters are not supported") + } + dargs[n] = param.Value + } + return dargs, nil } diff --git a/connection_ctx_test.go b/connection_ctx_test.go new file mode 100644 index 000000000..465f3fcef --- /dev/null +++ b/connection_ctx_test.go @@ -0,0 +1,23 @@ +// +build go1.8 + +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/.package mysql + +package mysql + +import ( + "database/sql/driver" +) + +var ( + _ driver.ConnBeginTx = &mysqlConn{} + _ driver.ConnPrepareContext = &mysqlConn{} + _ driver.ExecerContext = &mysqlConn{} + _ driver.Pinger = &mysqlConn{} + _ driver.QueryerContext = &mysqlConn{} +) diff --git a/statement_ctx.go b/statement_ctx.go index 8b2d1204d..89c600cd6 100644 --- a/statement_ctx.go +++ b/statement_ctx.go @@ -16,11 +16,19 @@ import ( ) // ExecContent implements driver.StmtExecContext interface -func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.Value) (driver.Result, error) { - return stmt.execContext(ctx, args) +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + values, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return stmt.execContext(ctx, values) } // QueryContext implements driver.StmtQueryContext interface -func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.Value) (driver.Rows, error) { - return stmt.queryContext(ctx, args) +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + values, err := namedValueToValue(args) + if err != nil { + return nil, err + } + return stmt.queryContext(ctx, values) } diff --git a/statement_ctx_test.go b/statement_ctx_test.go new file mode 100644 index 000000000..60eeeef18 --- /dev/null +++ b/statement_ctx_test.go @@ -0,0 +1,20 @@ +// +build go1.8 + +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/.package mysql + +package mysql + +import ( + "database/sql/driver" +) + +var ( + _ driver.StmtExecContext = &mysqlStmt{} + _ driver.StmtQueryContext = &mysqlStmt{} +)