Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add atomic wrappers for bool and error #612

Merged
merged 3 commits into from
Jun 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
sudo: false
language: go
go:
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
* Optional placeholder interpolation

## Requirements
* Go 1.2 or higher
* Go 1.4 or higher
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)

---------------------------------------
Expand Down Expand Up @@ -279,7 +279,7 @@ Default: false

`rejectreadOnly=true` causes the driver to reject read-only connections. This
is for a possible race condition during an automatic failover, where the mysql
client gets connected to a read-only replica after the failover.
client gets connected to a read-only replica after the failover.

Note that this should be a fairly rare case, as an automatic failover normally
happens when the primary is down, and the race condition shouldn't happen
Expand Down
48 changes: 14 additions & 34 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,15 @@ import (
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)

// a copy of context.Context for Go 1.7 and later.
// a copy of context.Context for Go 1.7 and earlier
type mysqlContext interface {
Done() <-chan struct{}
Err() error

// They are defined in context.Context, but go-mysql-driver does not use them.
// defined in context.Context, but not used in this driver:
// Deadline() (deadline time.Time, ok bool)
// Value(key interface{}) interface{}
}
Expand All @@ -44,18 +42,13 @@ type mysqlConn struct {
parseTime bool
strict bool

// for context support (From Go 1.8)
// for context support (Go 1.8+)
watching bool
watcher chan<- mysqlContext
closech chan struct{}
finished chan<- struct{}

// set non-zero when conn is closed, before closech is closed.
// accessed atomically.
closed int32

mu sync.Mutex // guards following fields
canceledErr error // set non-nil if conn is canceled
canceled atomicError // set non-nil if conn is canceled
closed atomicBool // set when conn is closed, before closech is closed
}

// Handles parameters set in DSN after the connection is established
Expand Down Expand Up @@ -89,7 +82,7 @@ func (mc *mysqlConn) handleParams() (err error) {
}

func (mc *mysqlConn) Begin() (driver.Tx, error) {
if mc.isBroken() {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
Expand All @@ -103,7 +96,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {

func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if !mc.isBroken() {
if !mc.closed.IsSet() {
err = mc.writeCommandPacket(comQuit)
}

Expand All @@ -117,7 +110,7 @@ func (mc *mysqlConn) Close() (err error) {
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
if atomic.SwapInt32(&mc.closed, 1) != 0 {
if !mc.closed.TrySet(true) {
return
}

Expand All @@ -131,13 +124,9 @@ func (mc *mysqlConn) cleanup() {
}
}

func (mc *mysqlConn) isBroken() bool {
return atomic.LoadInt32(&mc.closed) != 0
}

func (mc *mysqlConn) error() error {
if mc.isBroken() {
if err := mc.canceled(); err != nil {
if mc.closed.IsSet() {
if err := mc.canceled.Value(); err != nil {
return err
}
return ErrInvalidConn
Expand All @@ -146,7 +135,7 @@ func (mc *mysqlConn) error() error {
}

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.isBroken() {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
Expand Down Expand Up @@ -300,7 +289,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
}

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.isBroken() {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
Expand Down Expand Up @@ -361,7 +350,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
}

func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
if mc.isBroken() {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
Expand Down Expand Up @@ -436,19 +425,10 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {

// finish is called when the query has canceled.
func (mc *mysqlConn) cancel(err error) {
mc.mu.Lock()
mc.canceledErr = err
mc.mu.Unlock()
mc.canceled.Set(err)
mc.cleanup()
}

// canceled returns non-nil if the connection was closed due to context cancelation.
func (mc *mysqlConn) canceled() error {
mc.mu.Lock()
defer mc.mu.Unlock()
return mc.canceledErr
}

// finish is called when the query has succeeded.
func (mc *mysqlConn) finish() {
if !mc.watching || mc.finished == nil {
Expand Down
2 changes: 1 addition & 1 deletion connection_go18.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (

// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) error {
if mc.isBroken() {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
Expand Down
6 changes: 3 additions & 3 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// read packet header
data, err := mc.buf.readNext(4)
if err != nil {
if cerr := mc.canceled(); cerr != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
errLog.Print(err)
Expand Down Expand Up @@ -66,7 +66,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
// read packet body [pktLen bytes]
data, err = mc.buf.readNext(pktLen)
if err != nil {
if cerr := mc.canceled(); cerr != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
errLog.Print(err)
Expand Down Expand Up @@ -134,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
mc.cleanup()
errLog.Print(ErrMalformPkt)
} else {
if cerr := mc.canceled(); cerr != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
mc.cleanup()
Expand Down
6 changes: 3 additions & 3 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type mysqlStmt struct {
}

func (stmt *mysqlStmt) Close() error {
if stmt.mc == nil || stmt.mc.isBroken() {
if stmt.mc == nil || stmt.mc.closed.IsSet() {
// driver.Stmt.Close can be called more than once, thus this function
// has to be idempotent.
// See also Issue #450 and golang/go#16019.
Expand All @@ -45,7 +45,7 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
}

func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.isBroken() {
if stmt.mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
Expand Down Expand Up @@ -93,7 +93,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
}

func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.isBroken() {
if stmt.mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return nil, driver.ErrBadConn
}
Expand Down
4 changes: 2 additions & 2 deletions transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type mysqlTx struct {
}

func (tx *mysqlTx) Commit() (err error) {
if tx.mc == nil || tx.mc.isBroken() {
if tx.mc == nil || tx.mc.closed.IsSet() {
return ErrInvalidConn
}
err = tx.mc.exec("COMMIT")
Expand All @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
}

func (tx *mysqlTx) Rollback() (err error) {
if tx.mc == nil || tx.mc.isBroken() {
if tx.mc == nil || tx.mc.closed.IsSet() {
return ErrInvalidConn
}
err = tx.mc.exec("ROLLBACK")
Expand Down
64 changes: 64 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"fmt"
"io"
"strings"
"sync/atomic"
"time"
)

Expand Down Expand Up @@ -740,3 +741,66 @@ func escapeStringQuotes(buf []byte, v string) []byte {

return buf[:pos]
}

/******************************************************************************
* Sync utils *
******************************************************************************/
// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
// for details.
type noCopy struct{}

// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock() {}

// atomicBool is a wrapper around uint32 for usage as a boolean value with
// atomic access.
type atomicBool struct {
_noCopy noCopy
value uint32
}

// IsSet returns wether the current boolean value is true
func (ab *atomicBool) IsSet() bool {
return atomic.LoadUint32(&ab.value) > 0
}

// Set sets the value of the bool regardless of the previous value
func (ab *atomicBool) Set(value bool) {
if value {
atomic.StoreUint32(&ab.value, 1)
} else {
atomic.StoreUint32(&ab.value, 0)
}
}

// TrySet sets the value of the bool and returns wether the value changed
func (ab *atomicBool) TrySet(value bool) bool {
if value {
return atomic.SwapUint32(&ab.value, 1) == 0
}
return atomic.SwapUint32(&ab.value, 0) > 0
}

// atomicBool is a wrapper for atomically accessed error values
type atomicError struct {
_noCopy noCopy
value atomic.Value
}

// Set sets the error value regardless of the previous value.
// The value must not be nil
func (ae *atomicError) Set(value error) {
ae.value.Store(value)
}

// Value returns the current error value
func (ae *atomicError) Value() error {
if v := ae.value.Load(); v != nil {
// this will panic if the value doesn't implement the error interface
return v.(error)
}
return nil
}
80 changes: 80 additions & 0 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,83 @@ func TestEscapeQuotes(t *testing.T) {
expect("foo''bar", "foo'bar") // affected
expect("foo\"bar", "foo\"bar") // not affected
}

func TestAtomicBool(t *testing.T) {
var ab atomicBool
if ab.IsSet() {
t.Fatal("Expected value to be false")
}

ab.Set(true)
if ab.value != 1 {
t.Fatal("Set(true) did not set value to 1")
}
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}

ab.Set(true)
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}

ab.Set(false)
if ab.value != 0 {
t.Fatal("Set(false) did not set value to 0")
}
if ab.IsSet() {
t.Fatal("Expected value to be false")
}

ab.Set(false)
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
if ab.TrySet(false) {
t.Fatal("Expected TrySet(false) to fail")
}
if !ab.TrySet(true) {
t.Fatal("Expected TrySet(true) to succeed")
}
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}

ab.Set(true)
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
if ab.TrySet(true) {
t.Fatal("Expected TrySet(true) to fail")
}
if !ab.TrySet(false) {
t.Fatal("Expected TrySet(false) to succeed")
}
if ab.IsSet() {
t.Fatal("Expected value to be false")
}

ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
}

func TestAtomicError(t *testing.T) {
var ae atomicError
if ae.Value() != nil {
t.Fatal("Expected value to be nil")
}

ae.Set(ErrMalformPkt)
if v := ae.Value(); v != ErrMalformPkt {
if v == nil {
t.Fatal("Value is still nil")
}
t.Fatal("Error did not match")
}
ae.Set(ErrPktSync)
if ae.Value() == ErrMalformPkt {
t.Fatal("Error still matches old error")
}
if v := ae.Value(); v != ErrPktSync {
t.Fatal("Error did not match")
}
}