From 7e64321024d9d68fca764cfabf6e3f1556dbf63a Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Mon, 5 Jun 2017 13:39:40 +0900 Subject: [PATCH 01/18] Add supports to context.Context --- benchmark_go18_test.go | 93 +++++++++++++ connection.go | 87 ++++++++++-- connection_go18.go | 234 +++++++++++++++++++++++++++++++ driver.go | 15 +- driver_go18_test.go | 308 +++++++++++++++++++++++++++++++++++++++++ driver_test.go | 8 ++ packets.go | 11 ++ packets_test.go | 3 +- rows.go | 34 ++++- statement.go | 6 +- transaction.go | 4 +- 11 files changed, 774 insertions(+), 29 deletions(-) create mode 100644 benchmark_go18_test.go create mode 100644 connection_go18.go diff --git a/benchmark_go18_test.go b/benchmark_go18_test.go new file mode 100644 index 000000000..d6a7e9d6e --- /dev/null +++ b/benchmark_go18_test.go @@ -0,0 +1,93 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "context" + "database/sql" + "fmt" + "runtime" + "testing" +) + +func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + var got string + for pb.Next() { + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Fatalf("query = %q; want one", got) + } + } + }) +} + +func BenchmarkQueryContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := stmt.ExecContext(ctx); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkExecContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} diff --git a/connection.go b/connection.go index cdce3e30f..fd97a987b 100644 --- a/connection.go +++ b/connection.go @@ -10,13 +10,23 @@ package mysql import ( "database/sql/driver" + "errors" "io" "net" "strconv" "strings" + "sync" "time" ) +//a copy of context.Context from Go 1.7 and later. +type mysqlContext interface { + Deadline() (deadline time.Time, ok bool) + Done() <-chan struct{} + Err() error + Value(key interface{}) interface{} +} + type mysqlConn struct { buf buffer netConn net.Conn @@ -31,6 +41,13 @@ type mysqlConn struct { sequence uint8 parseTime bool strict bool + watcher chan<- mysqlContext + closech chan struct{} + finished chan<- struct{} + + mu sync.Mutex // guards following fields + closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled } // Handles parameters set in DSN after the connection is established @@ -64,7 +81,7 @@ func (mc *mysqlConn) handleParams() (err error) { } func (mc *mysqlConn) Begin() (driver.Tx, error) { - if mc.netConn == nil { + if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -78,11 +95,11 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) { func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if mc.netConn != nil { + if !mc.isBroken() { err = mc.writeCommandPacket(comQuit) } - mc.cleanup() + mc.cleanup(errors.New("mysql: connection is closed")) return } @@ -91,20 +108,36 @@ func (mc *mysqlConn) Close() (err error) { // function after successfully authentication, call Close instead. This function // is called before auth or on auth failure because MySQL will have already // closed the network connection. -func (mc *mysqlConn) cleanup() { +func (mc *mysqlConn) cleanup(err error) { + if err == nil { + panic("nil error") + } + mc.mu.Lock() + defer mc.mu.Unlock() + + if mc.closed != nil { + return + } + // Makes cleanup idempotent - if mc.netConn != nil { - if err := mc.netConn.Close(); err != nil { - errLog.Print(err) - } - mc.netConn = nil + mc.closed = err + close(mc.closech) + if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) } - mc.cfg = nil - mc.buf.nc = nil +} + +func (mc *mysqlConn) isBroken() bool { + mc.mu.Lock() + defer mc.mu.Unlock() + return mc.closed != nil } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.netConn == nil { + if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -258,7 +291,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.netConn == nil { + if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -315,7 +348,7 @@ func (mc *mysqlConn) exec(query string) error { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if mc.netConn == nil { + if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -387,3 +420,29 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } return nil, err } + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.mu.Lock() + mc.canceledErr = err + mc.mu.Unlock() + mc.cleanup(errors.New("mysql: query canceled")) +} + +// canceled returns non-nil if the connection was closed due to context cancelation. +func (mc *mysqlConn) canceled() error { + mc.mu.Lock() + defer mc.mu.Unlock() + return mc.canceledErr +} + +// finish is called when the query has succeeded. +func (mc *mysqlConn) finish() { + if mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + case <-mc.closech: + } +} diff --git a/connection_go18.go b/connection_go18.go new file mode 100644 index 000000000..61ca84cac --- /dev/null +++ b/connection_go18.go @@ -0,0 +1,234 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" +) + +type setfinish interface { + setFinish(f func()) +} + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) error { + if mc.isBroken() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err := mc.watchCancel(ctx); err != nil { + return err + } + defer mc.finish() + + if err := mc.writeCommandPacket(comPing); err != nil { + return err + } + if _, err := mc.readResultOK(); err != nil { + return err + } + + return nil +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + return nil, errors.New("mysql: isolation levels not supported") + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + var err error + var tx driver.Tx + if opts.ReadOnly { + tx, err = mc.beginReadOnly() + } else { + tx, err = mc.Begin() + } + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + tx.Rollback() + return nil, ctx.Err() + } + return tx, err +} + +func (mc *mysqlConn) beginReadOnly() (driver.Tx, error) { + if mc.isBroken() { + errLog.Print(ErrInvalidConn) + return nil, driver.ErrBadConn + } + // https://dev.mysql.com/doc/refman/5.7/en/innodb-performance-ro-txn.html + err := mc.exec("START TRANSACTION READ ONLY") + if err != nil { + return nil, err + } + + return &mysqlTx{mc}, nil +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.Query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + if set, ok := rows.(setfinish); ok { + set.setFinish(mc.finish) + } else { + mc.finish() + } + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.Query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + if set, ok := rows.(setfinish); ok { + set.setFinish(stmt.mc.finish) + } else { + stmt.mc.finish() + } + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + select { + default: + case <-ctx.Done(): + return ctx.Err() + } + if mc.watcher == nil { + return nil + } + + mc.watcher <- ctx + + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan mysqlContext, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx mysqlContext + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} diff --git a/driver.go b/driver.go index e51d98a3c..ce223dacc 100644 --- a/driver.go +++ b/driver.go @@ -52,6 +52,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), } mc.cfg, err = ParseDSN(dsn) if err != nil { @@ -60,6 +61,14 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.parseTime = mc.cfg.ParseTime mc.strict = mc.cfg.Strict + // Call startWatcher for context support (From Go 1.8) + type starter interface { + startWatcher() + } + if s, ok := interface{}(mc).(starter); ok { + s.startWatcher() + } + // Connect to Server if dial, ok := dials[mc.cfg.Net]; ok { mc.netConn, err = dial(mc.cfg.Addr) @@ -90,13 +99,13 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // Reading Handshake Initialization Packet cipher, err := mc.readInitPacket() if err != nil { - mc.cleanup() + mc.cleanup(err) return nil, err } // Send Client Authentication Packet if err = mc.writeAuthPacket(cipher); err != nil { - mc.cleanup() + mc.cleanup(err) return nil, err } @@ -105,7 +114,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. - mc.cleanup() + mc.cleanup(err) return nil, err } diff --git a/driver_go18_test.go b/driver_go18_test.go index 5a5fa10ff..e7050043d 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -11,10 +11,13 @@ package mysql import ( + "context" "database/sql" + "database/sql/driver" "fmt" "reflect" "testing" + "time" ) func TestMultiResultSet(t *testing.T) { @@ -196,3 +199,308 @@ func TestSkipResults(t *testing.T) { } }) } + +func TestPingContext(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := dbt.db.PingContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +var ( + _ driver.ConnBeginTx = &mysqlConn{} + _ driver.ConnPrepareContext = &mysqlConn{} + _ driver.ExecerContext = &mysqlConn{} + _ driver.Pinger = &mysqlConn{} + _ driver.QueryerContext = &mysqlConn{} +) + +func TestContextCancelExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Errorf("expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + dbt.Error("expected error") + } else if err.Error() != "context canceled" { + dbt.Fatalf("unexpected error: %s", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Errorf("expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQueryRow(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test") + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + // the first row will be succeed. + var v int + if !rows.Next() { + dbt.Fatalf("unexpected end") + } + if err := rows.Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + + cancel() + // make sure the driver recieve cancel request. + time.Sleep(100 * time.Millisecond) + + if rows.Next() { + dbt.Errorf("expected end, but not") + } + if err := rows.Err(); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelPrepare(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +var ( + _ driver.StmtExecContext = &mysqlStmt{} + _ driver.StmtQueryContext = &mysqlStmt{} +) + +func TestContextCancelStmtExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.ExecContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelStmtQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.QueryContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Errorf("expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelBegin(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + tx, err := dbt.db.BeginTx(ctx, nil) + if err != nil { + dbt.Fatal(err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Transaction is canceled, so expect an error. + switch err := tx.Commit(); err { + case sql.ErrTxDone: + // because the transaction has already been rollbacked. + // the database/sql package watches ctx + // and rollbacks when ctx is canceled. + case context.Canceled: + // the database/sql package rollbacks on another goroutine, + // so the transaction may not be rollbacked depending on goroutine scheduling. + default: + dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) + } + + // Context is canceled, so cannot begin a transaction. + if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextBeginReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: true, + }) + if _, ok := err.(*MySQLError); ok { + dbt.Skip("It seems that your MySQL does not support READ ONLY transactions") + return + } else if err != nil { + dbt.Fatal(err) + } + + // INSERT queries fail in a READ ONLY transaction. + _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if _, ok := err.(*MySQLError); !ok { + dbt.Errorf("expected MySQLError, got %v", err) + } + + // SELECT queries can be executed. + var v int + row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + if err := tx.Commit(); err != nil { + dbt.Fatal(err) + } + }) +} diff --git a/driver_test.go b/driver_test.go index 6ca5434a9..206e07cc9 100644 --- a/driver_test.go +++ b/driver_test.go @@ -1991,3 +1991,11 @@ func TestRejectReadOnly(t *testing.T) { dbt.mustExec("DROP TABLE test") }) } + +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} diff --git a/packets.go b/packets.go index 303405a17..2b9caf94b 100644 --- a/packets.go +++ b/packets.go @@ -30,6 +30,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet header data, err := mc.buf.readNext(4) if err != nil { + if cerr := mc.canceled(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() return nil, driver.ErrBadConn @@ -63,6 +66,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + if cerr := mc.canceled(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() return nil, driver.ErrBadConn @@ -125,8 +131,13 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) + mc.cleanup(ErrMalformPkt) errLog.Print(ErrMalformPkt) } else { + if cerr := mc.canceled(); cerr != nil { + return cerr + } + mc.cleanup(err) errLog.Print(err) } return driver.ErrBadConn diff --git a/packets_test.go b/packets_test.go index b1d64f5c7..31c892d85 100644 --- a/packets_test.go +++ b/packets_test.go @@ -244,7 +244,8 @@ func TestReadPacketSplit(t *testing.T) { func TestReadPacketFail(t *testing.T) { conn := new(mockConn) mc := &mysqlConn{ - buf: newBuffer(conn), + buf: newBuffer(conn), + closech: make(chan struct{}), } // illegal empty (stand-alone) packet diff --git a/rows.go b/rows.go index 13905e216..0437eb0d3 100644 --- a/rows.go +++ b/rows.go @@ -28,8 +28,9 @@ type resultSet struct { } type mysqlRows struct { - mc *mysqlConn - rs resultSet + mc *mysqlConn + rs resultSet + finish func() } type binaryRows struct { @@ -64,12 +65,24 @@ func (rows *mysqlRows) Columns() []string { return columns } +func (rows *mysqlRows) setFinish(f func()) { + rows.finish = f +} + func (rows *mysqlRows) Close() (err error) { + if f := rows.finish; f != nil { + f() + rows.finish = nil + } + mc := rows.mc if mc == nil { return nil } - if mc.netConn == nil { + if mc.isBroken() { + if err := mc.canceled(); err != nil { + return err + } return ErrInvalidConn } @@ -98,7 +111,10 @@ func (rows *mysqlRows) nextResultSet() (int, error) { if rows.mc == nil { return 0, io.EOF } - if rows.mc.netConn == nil { + if rows.mc.isBroken() { + if err := rows.mc.canceled(); err != nil { + return 0, err + } return 0, ErrInvalidConn } @@ -145,7 +161,10 @@ func (rows *binaryRows) NextResultSet() error { func (rows *binaryRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { + if mc.isBroken() { + if err := mc.canceled(); err != nil { + return err + } return ErrInvalidConn } @@ -167,7 +186,10 @@ func (rows *textRows) NextResultSet() (err error) { func (rows *textRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { + if mc.isBroken() { + if err := mc.canceled(); err != nil { + return err + } return ErrInvalidConn } diff --git a/statement.go b/statement.go index e5071276a..1f95a3021 100644 --- a/statement.go +++ b/statement.go @@ -23,7 +23,7 @@ type mysqlStmt struct { } func (stmt *mysqlStmt) Close() error { - if stmt.mc == nil || stmt.mc.netConn == nil { + if stmt.mc == nil || stmt.mc.isBroken() { // driver.Stmt.Close can be called more than once, thus this function // has to be idempotent. // See also Issue #450 and golang/go#16019. @@ -45,7 +45,7 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.netConn == nil { + if stmt.mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -89,7 +89,7 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - if stmt.mc.netConn == nil { + if stmt.mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } diff --git a/transaction.go b/transaction.go index 33c749b35..5d88c0399 100644 --- a/transaction.go +++ b/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.isBroken() { return ErrInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.isBroken() { return ErrInvalidConn } err = tx.mc.exec("ROLLBACK") From 54ef181d47ea11a955f9bf46ee18226a68e18d98 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Mon, 5 Jun 2017 16:49:51 +0900 Subject: [PATCH 02/18] add authors related context.Context support --- AUTHORS | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/AUTHORS b/AUTHORS index 0a22d46e7..5f27aa457 100644 --- a/AUTHORS +++ b/AUTHORS @@ -14,6 +14,7 @@ Aaron Hopkins Arne Hormann Asta Xie +Bulat Gaifullin Carlos Nieto Chris Moos Daniel Nichter @@ -21,11 +22,13 @@ Daniël van Eeden Dave Protasowski DisposaBoy Egor Smolyakov +Evan Shaw Frederick Mayle Gustavo Kristic Hanno Braun Henri Yandell Hirotaka Yamamoto +ICHINOSE Shogo INADA Naoki Jacek Szwec James Harr @@ -45,6 +48,7 @@ Luke Scott Michael Woolnough Nicola Peduzzi Olivier Mengué +oscarzhao Paul Bonser Peter Schultz Rebecca Chin From 66fa1379298ac26f645c08b897084fe4b4a3fd75 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Mon, 5 Jun 2017 16:53:42 +0900 Subject: [PATCH 03/18] fix comment of mysqlContext. - s/from/for/ - and start the comment with a space please --- connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connection.go b/connection.go index fd97a987b..15dc6a249 100644 --- a/connection.go +++ b/connection.go @@ -19,7 +19,7 @@ import ( "time" ) -//a copy of context.Context from Go 1.7 and later. +// a copy of context.Context for Go 1.7 and later. type mysqlContext interface { Deadline() (deadline time.Time, ok bool) Done() <-chan struct{} From 06b17e606b7022ff488fb583c9ca735a822baec2 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Mon, 5 Jun 2017 16:56:58 +0900 Subject: [PATCH 04/18] closed is now just bool flag. --- connection.go | 18 +++++++----------- driver.go | 6 +++--- packets.go | 4 ++-- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/connection.go b/connection.go index 15dc6a249..247ddc326 100644 --- a/connection.go +++ b/connection.go @@ -10,7 +10,6 @@ package mysql import ( "database/sql/driver" - "errors" "io" "net" "strconv" @@ -46,7 +45,7 @@ type mysqlConn struct { finished chan<- struct{} mu sync.Mutex // guards following fields - closed error // set non-nil when conn is closed, before closech is closed + closed bool // set true when conn is closed, before closech is closed canceledErr error // set non-nil if conn is canceled } @@ -99,7 +98,7 @@ func (mc *mysqlConn) Close() (err error) { err = mc.writeCommandPacket(comQuit) } - mc.cleanup(errors.New("mysql: connection is closed")) + mc.cleanup() return } @@ -108,19 +107,16 @@ func (mc *mysqlConn) Close() (err error) { // function after successfully authentication, call Close instead. This function // is called before auth or on auth failure because MySQL will have already // closed the network connection. -func (mc *mysqlConn) cleanup(err error) { - if err == nil { - panic("nil error") - } +func (mc *mysqlConn) cleanup() { mc.mu.Lock() defer mc.mu.Unlock() - if mc.closed != nil { + if mc.closed { return } // Makes cleanup idempotent - mc.closed = err + mc.closed = true close(mc.closech) if mc.netConn == nil { return @@ -133,7 +129,7 @@ func (mc *mysqlConn) cleanup(err error) { func (mc *mysqlConn) isBroken() bool { mc.mu.Lock() defer mc.mu.Unlock() - return mc.closed != nil + return mc.closed } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { @@ -426,7 +422,7 @@ func (mc *mysqlConn) cancel(err error) { mc.mu.Lock() mc.canceledErr = err mc.mu.Unlock() - mc.cleanup(errors.New("mysql: query canceled")) + mc.cleanup() } // canceled returns non-nil if the connection was closed due to context cancelation. diff --git a/driver.go b/driver.go index ce223dacc..04c609df1 100644 --- a/driver.go +++ b/driver.go @@ -99,13 +99,13 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // Reading Handshake Initialization Packet cipher, err := mc.readInitPacket() if err != nil { - mc.cleanup(err) + mc.cleanup() return nil, err } // Send Client Authentication Packet if err = mc.writeAuthPacket(cipher); err != nil { - mc.cleanup(err) + mc.cleanup() return nil, err } @@ -114,7 +114,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. - mc.cleanup(err) + mc.cleanup() return nil, err } diff --git a/packets.go b/packets.go index 2b9caf94b..0bc120c0e 100644 --- a/packets.go +++ b/packets.go @@ -131,13 +131,13 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) - mc.cleanup(ErrMalformPkt) + mc.cleanup() errLog.Print(ErrMalformPkt) } else { if cerr := mc.canceled(); cerr != nil { return cerr } - mc.cleanup(err) + mc.cleanup() errLog.Print(err) } return driver.ErrBadConn From 170555024d94cb373da5d91ed1c8d962b472d39f Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 09:29:27 +0900 Subject: [PATCH 05/18] drop read-only transactions support --- connection_go18.go | 25 ++++--------------------- driver_go18_test.go | 38 -------------------------------------- 2 files changed, 4 insertions(+), 59 deletions(-) diff --git a/connection_go18.go b/connection_go18.go index 61ca84cac..aef085ba7 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -48,18 +48,15 @@ func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { return nil, errors.New("mysql: isolation levels not supported") } + if opts.ReadOnly { + return nil, errors.New("mysql: read-only transactions not supported") + } if err := mc.watchCancel(ctx); err != nil { return nil, err } - var err error - var tx driver.Tx - if opts.ReadOnly { - tx, err = mc.beginReadOnly() - } else { - tx, err = mc.Begin() - } + tx, err := mc.Begin() mc.finish() if err != nil { return nil, err @@ -74,20 +71,6 @@ func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver return tx, err } -func (mc *mysqlConn) beginReadOnly() (driver.Tx, error) { - if mc.isBroken() { - errLog.Print(ErrInvalidConn) - return nil, driver.ErrBadConn - } - // https://dev.mysql.com/doc/refman/5.7/en/innodb-performance-ro-txn.html - err := mc.exec("START TRANSACTION READ ONLY") - if err != nil { - return nil, err - } - - return &mysqlTx{mc}, nil -} - func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { dargs, err := namedValueToValue(args) if err != nil { diff --git a/driver_go18_test.go b/driver_go18_test.go index e7050043d..46e0be41b 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -466,41 +466,3 @@ func TestContextCancelBegin(t *testing.T) { } }) } - -func TestContextBeginReadOnly(t *testing.T) { - runTests(t, dsn, func(dbt *DBTest) { - dbt.mustExec("CREATE TABLE test (v INTEGER)") - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ - ReadOnly: true, - }) - if _, ok := err.(*MySQLError); ok { - dbt.Skip("It seems that your MySQL does not support READ ONLY transactions") - return - } else if err != nil { - dbt.Fatal(err) - } - - // INSERT queries fail in a READ ONLY transaction. - _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)") - if _, ok := err.(*MySQLError); !ok { - dbt.Errorf("expected MySQLError, got %v", err) - } - - // SELECT queries can be executed. - var v int - row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") - if err := row.Scan(&v); err != nil { - dbt.Fatal(err) - } - if v != 0 { - dbt.Errorf("expected val to be 0, got %d", v) - } - - if err := tx.Commit(); err != nil { - dbt.Fatal(err) - } - }) -} From 41940ff8b972de7548d9faa1286a10c04380288c Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 09:33:31 +0900 Subject: [PATCH 06/18] remove unused methods from mysqlContext interface. --- connection.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/connection.go b/connection.go index 247ddc326..519949c9c 100644 --- a/connection.go +++ b/connection.go @@ -20,10 +20,12 @@ import ( // a copy of context.Context for Go 1.7 and later. type mysqlContext interface { - Deadline() (deadline time.Time, ok bool) Done() <-chan struct{} Err() error - Value(key interface{}) interface{} + + // They are defined in context.Context, but go-mysql-driver does not use them. + // Deadline() (deadline time.Time, ok bool) + // Value(key interface{}) interface{} } type mysqlConn struct { From fd8a559aca53ae673b39609cde9fb8cfdc02e85c Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 09:38:24 +0900 Subject: [PATCH 07/18] moved checking canceled logic into method of connection. --- connection.go | 10 ++++++++++ rows.go | 14 ++++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/connection.go b/connection.go index 519949c9c..aabbb67c0 100644 --- a/connection.go +++ b/connection.go @@ -134,6 +134,16 @@ func (mc *mysqlConn) isBroken() bool { return mc.closed } +func (mc *mysqlConn) error() error { + if mc.isBroken() { + if err := mc.canceled(); err != nil { + return err + } + return ErrInvalidConn + } + return nil +} + func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { if mc.isBroken() { errLog.Print(ErrInvalidConn) diff --git a/rows.go b/rows.go index 0437eb0d3..0204fd71a 100644 --- a/rows.go +++ b/rows.go @@ -161,11 +161,8 @@ func (rows *binaryRows) NextResultSet() error { func (rows *binaryRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.isBroken() { - if err := mc.canceled(); err != nil { - return err - } - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream @@ -186,11 +183,8 @@ func (rows *textRows) NextResultSet() (err error) { func (rows *textRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.isBroken() { - if err := mc.canceled(); err != nil { - return err - } - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream From a464739ac5ded4d5bfe3b221e554eb4adcdeaf0d Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 09:44:51 +0900 Subject: [PATCH 08/18] add a section about context.Context to the README. --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index 76a35d32e..103908734 100644 --- a/README.md +++ b/README.md @@ -443,6 +443,9 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. +## Context Support +Go 1.8 added some `database/sql` methods that accept a `context.Context` parameter for better control over timeout and cancellation. +See more details on [context support to database/sql package](https://golang.org/doc/go1.8#database_sql, "sql"). ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. From 1fdad70c332c21a079a976739fcf0c5c0e1b24b2 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 10:00:59 +0900 Subject: [PATCH 09/18] use atomic variable for closed. --- connection.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/connection.go b/connection.go index aabbb67c0..398747351 100644 --- a/connection.go +++ b/connection.go @@ -15,6 +15,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" ) @@ -46,8 +47,11 @@ type mysqlConn struct { closech chan struct{} finished chan<- struct{} + // set non-zero when conn is closed, before closech is closed. + // accessed atomically. + closed int32 + mu sync.Mutex // guards following fields - closed bool // set true when conn is closed, before closech is closed canceledErr error // set non-nil if conn is canceled } @@ -110,15 +114,11 @@ func (mc *mysqlConn) Close() (err error) { // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { - mc.mu.Lock() - defer mc.mu.Unlock() - - if mc.closed { + if atomic.SwapInt32(&mc.closed, 1) != 0 { return } // Makes cleanup idempotent - mc.closed = true close(mc.closech) if mc.netConn == nil { return @@ -129,9 +129,7 @@ func (mc *mysqlConn) cleanup() { } func (mc *mysqlConn) isBroken() bool { - mc.mu.Lock() - defer mc.mu.Unlock() - return mc.closed + return atomic.LoadInt32(&mc.closed) != 0 } func (mc *mysqlConn) error() error { From f96feaa79cfaabc6dffb5803c2a12188df14d914 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 14:34:15 +0900 Subject: [PATCH 10/18] short circuit for context.Background() --- connection.go | 12 ++++++++---- connection_go18.go | 11 +++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/connection.go b/connection.go index 398747351..38b297b95 100644 --- a/connection.go +++ b/connection.go @@ -43,9 +43,12 @@ type mysqlConn struct { sequence uint8 parseTime bool strict bool - watcher chan<- mysqlContext - closech chan struct{} - finished chan<- struct{} + + // for context support (From Go 1.8) + watching bool + watcher chan<- mysqlContext + closech chan struct{} + finished chan<- struct{} // set non-zero when conn is closed, before closech is closed. // accessed atomically. @@ -444,11 +447,12 @@ func (mc *mysqlConn) canceled() error { // finish is called when the query has succeeded. func (mc *mysqlConn) finish() { - if mc.finished == nil { + if !mc.watching || mc.finished == nil { return } select { case mc.finished <- struct{}{}: + mc.watching = false case <-mc.closech: } } diff --git a/connection_go18.go b/connection_go18.go index aef085ba7..d275051b2 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -166,6 +166,17 @@ func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue } func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + err := errors.New("mysql: illegal watching state") + errLog.Print(err) + mc.cleanup() + return err + } + if ctx.Done() == nil { + return nil + } + + mc.watching = true select { default: case <-ctx.Done(): From 4ce2087b2a98f5857525e949e6dc57752e2783ac Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 14:57:23 +0900 Subject: [PATCH 11/18] fix illegal watching state --- connection_go18.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/connection_go18.go b/connection_go18.go index d275051b2..b29c96edf 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -167,10 +167,10 @@ func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue func (mc *mysqlConn) watchCancel(ctx context.Context) error { if mc.watching { - err := errors.New("mysql: illegal watching state") - errLog.Print(err) + // Reach here if canceled, + // so the connection is already invalid mc.cleanup() - return err + return nil } if ctx.Done() == nil { return nil From 208cb44a2526bedc04db3b27beaa201247c3ef80 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 11:17:39 +0900 Subject: [PATCH 12/18] set rows.finish directly. --- connection.go | 4 ++++ connection_go18.go | 20 ++++---------------- rows.go | 4 ---- statement.go | 4 ++++ 4 files changed, 12 insertions(+), 20 deletions(-) diff --git a/connection.go b/connection.go index 38b297b95..89a4f464a 100644 --- a/connection.go +++ b/connection.go @@ -357,6 +357,10 @@ func (mc *mysqlConn) exec(query string) error { } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { if mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn diff --git a/connection_go18.go b/connection_go18.go index b29c96edf..908670344 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -17,10 +17,6 @@ import ( "errors" ) -type setfinish interface { - setFinish(f func()) -} - // Ping implements driver.Pinger interface func (mc *mysqlConn) Ping(ctx context.Context) error { if mc.isBroken() { @@ -81,16 +77,12 @@ func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driv return nil, err } - rows, err := mc.Query(query, dargs) + rows, err := mc.query(query, dargs) if err != nil { mc.finish() return nil, err } - if set, ok := rows.(setfinish); ok { - set.setFinish(mc.finish) - } else { - mc.finish() - } + rows.finish = mc.finish return rows, err } @@ -138,16 +130,12 @@ func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValu return nil, err } - rows, err := stmt.Query(dargs) + rows, err := stmt.query(dargs) if err != nil { stmt.mc.finish() return nil, err } - if set, ok := rows.(setfinish); ok { - set.setFinish(stmt.mc.finish) - } else { - stmt.mc.finish() - } + rows.finish = stmt.mc.finish return rows, err } diff --git a/rows.go b/rows.go index 0204fd71a..f7266915a 100644 --- a/rows.go +++ b/rows.go @@ -65,10 +65,6 @@ func (rows *mysqlRows) Columns() []string { return columns } -func (rows *mysqlRows) setFinish(f func()) { - rows.finish = f -} - func (rows *mysqlRows) Close() (err error) { if f := rows.finish; f != nil { f() diff --git a/statement.go b/statement.go index 1f95a3021..5855f943f 100644 --- a/statement.go +++ b/statement.go @@ -89,6 +89,10 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { + return stmt.query(args) +} + +func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { if stmt.mc.isBroken() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn From 31a72667281cf440bed9aeb9c06d348edf0c51e6 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 15:06:56 +0900 Subject: [PATCH 13/18] move namedValueToValue to utils_go18.go --- connection_go18.go | 11 ----------- utils_go18.go | 17 ++++++++++++++++- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/connection_go18.go b/connection_go18.go index 908670344..330d76894 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -203,14 +203,3 @@ func (mc *mysqlConn) startWatcher() { } }() } - -func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { - dargs := make([]driver.Value, len(named)) - for n, param := range named { - if len(param.Name) > 0 { - return nil, errors.New("mysql: driver does not support the use of Named Parameters") - } - dargs[n] = param.Value - } - return dargs, nil -} diff --git a/utils_go18.go b/utils_go18.go index 2aa9d0f18..a7e690c8c 100644 --- a/utils_go18.go +++ b/utils_go18.go @@ -10,8 +10,23 @@ package mysql -import "crypto/tls" +import ( + "crypto/tls" + "database/sql/driver" + "errors" +) func cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} From 09fbdfaad3aa739487a0818e4e51f37b16b308a5 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 15:08:58 +0900 Subject: [PATCH 14/18] move static interface implementation checks to the top of the file --- driver_go18_test.go | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/driver_go18_test.go b/driver_go18_test.go index 46e0be41b..69d0a2b7d 100644 --- a/driver_go18_test.go +++ b/driver_go18_test.go @@ -20,6 +20,21 @@ import ( "time" ) +// static interface implementation checks of mysqlConn +var ( + _ driver.ConnBeginTx = &mysqlConn{} + _ driver.ConnPrepareContext = &mysqlConn{} + _ driver.ExecerContext = &mysqlConn{} + _ driver.Pinger = &mysqlConn{} + _ driver.QueryerContext = &mysqlConn{} +) + +// static interface implementation checks of mysqlStmt +var ( + _ driver.StmtExecContext = &mysqlStmt{} + _ driver.StmtQueryContext = &mysqlStmt{} +) + func TestMultiResultSet(t *testing.T) { type result struct { values [][]int @@ -210,14 +225,6 @@ func TestPingContext(t *testing.T) { }) } -var ( - _ driver.ConnBeginTx = &mysqlConn{} - _ driver.ConnPrepareContext = &mysqlConn{} - _ driver.ExecerContext = &mysqlConn{} - _ driver.Pinger = &mysqlConn{} - _ driver.QueryerContext = &mysqlConn{} -) - func TestContextCancelExec(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") @@ -351,11 +358,6 @@ func TestContextCancelPrepare(t *testing.T) { }) } -var ( - _ driver.StmtExecContext = &mysqlStmt{} - _ driver.StmtQueryContext = &mysqlStmt{} -) - func TestContextCancelStmtExec(t *testing.T) { runTests(t, dsn, func(dbt *DBTest) { dbt.mustExec("CREATE TABLE test (v INTEGER)") From c4f9ae6917688895e3abc5e525ac8c623b1ed944 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 15:43:43 +0900 Subject: [PATCH 15/18] add the new section about `context.Context` to the table of contents, and fix the section. --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 103908734..d39b29c90 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support) * [time.Time support](#timetime-support) * [Unicode support](#unicode-support) + * [context.Context Support](#contextcontext-support) * [Testing / Development](#testing--development) * [License](#license) @@ -443,9 +444,9 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support. -## Context Support -Go 1.8 added some `database/sql` methods that accept a `context.Context` parameter for better control over timeout and cancellation. -See more details on [context support to database/sql package](https://golang.org/doc/go1.8#database_sql, "sql"). +## `context.Context` Support +Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts. +See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details. ## Testing / Development To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details. From 87ba95a204af232c4319ce79ade4909365e9fe9e Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Tue, 6 Jun 2017 16:33:32 +0900 Subject: [PATCH 16/18] mark unsupported features with TODO comments --- connection_go18.go | 2 ++ utils_go18.go | 1 + 2 files changed, 3 insertions(+) diff --git a/connection_go18.go b/connection_go18.go index 330d76894..384603a9e 100644 --- a/connection_go18.go +++ b/connection_go18.go @@ -42,9 +42,11 @@ func (mc *mysqlConn) Ping(ctx context.Context) error { // BeginTx implements driver.ConnBeginTx interface func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + // TODO: support isolation levels return nil, errors.New("mysql: isolation levels not supported") } if opts.ReadOnly { + // TODO: support read-only transactions return nil, errors.New("mysql: read-only transactions not supported") } diff --git a/utils_go18.go b/utils_go18.go index a7e690c8c..eaeac4f84 100644 --- a/utils_go18.go +++ b/utils_go18.go @@ -24,6 +24,7 @@ func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { dargs := make([]driver.Value, len(named)) for n, param := range named { if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 return nil, errors.New("mysql: driver does not support the use of Named Parameters") } dargs[n] = param.Value From 85f33dd1736a73a3995240a6d08db529b0396679 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Fri, 9 Jun 2017 08:21:20 +0900 Subject: [PATCH 17/18] rename watcher to starter. --- driver.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/driver.go b/driver.go index 04c609df1..f11e14462 100644 --- a/driver.go +++ b/driver.go @@ -22,6 +22,11 @@ import ( "net" ) +// watcher interface is used for context support (From Go 1.8) +type watcher interface { + startWatcher() +} + // MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. type MySQLDriver struct{} @@ -62,10 +67,7 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.strict = mc.cfg.Strict // Call startWatcher for context support (From Go 1.8) - type starter interface { - startWatcher() - } - if s, ok := interface{}(mc).(starter); ok { + if s, ok := interface{}(mc).(watcher); ok { s.startWatcher() } From 80f3f6fbec3e6a81ab5361beb0e33e1024404db8 Mon Sep 17 00:00:00 2001 From: Ichinose Shogo Date: Fri, 9 Jun 2017 08:24:22 +0900 Subject: [PATCH 18/18] use mc.error() instead of duplicated logics. --- rows.go | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/rows.go b/rows.go index f7266915a..c7f5ee26c 100644 --- a/rows.go +++ b/rows.go @@ -75,11 +75,8 @@ func (rows *mysqlRows) Close() (err error) { if mc == nil { return nil } - if mc.isBroken() { - if err := mc.canceled(); err != nil { - return err - } - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Remove unread packets from stream @@ -107,11 +104,8 @@ func (rows *mysqlRows) nextResultSet() (int, error) { if rows.mc == nil { return 0, io.EOF } - if rows.mc.isBroken() { - if err := rows.mc.canceled(); err != nil { - return 0, err - } - return 0, ErrInvalidConn + if err := rows.mc.error(); err != nil { + return 0, err } // Remove unread packets from stream