Skip to content

Commit

Permalink
Merge branch 'master' into fix-race-on-cancel
Browse files Browse the repository at this point in the history
  • Loading branch information
methane authored Mar 16, 2024
2 parents a832658 + 1a64773 commit b568757
Show file tree
Hide file tree
Showing 11 changed files with 96 additions and 91 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ It's possible to access the last inserted ID and number of affected rows for mul

```go
conn, _ := db.Conn(ctx)
conn.Raw(func(conn interface{}) error {
conn.Raw(func(conn any) error {
ex := conn.(driver.Execer)
res, err := ex.Exec(`
UPDATE point SET x = 1 WHERE y = 2;
Expand Down
2 changes: 1 addition & 1 deletion auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
return authEd25519(authData, mc.cfg.Passwd)

default:
mc.cfg.Logger.Print("unknown auth plugin:", plugin)
mc.log("unknown auth plugin:", plugin)
return nil, ErrUnknownPlugin
}
}
Expand Down
23 changes: 14 additions & 9 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ type mysqlConn struct {
closed atomic.Bool // set when conn is closed, before closech is closed
}

// Helper function to call per-connection logger.
func (mc *mysqlConn) log(v ...any) {
mc.cfg.Logger.Print(v...)
}

// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
var cmdSet strings.Builder
Expand Down Expand Up @@ -110,7 +115,7 @@ func (mc *mysqlConn) Begin() (driver.Tx, error) {

func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
if mc.closed.Load() {
mc.cfg.Logger.Print(ErrInvalidConn)
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
var q string
Expand Down Expand Up @@ -153,7 +158,7 @@ func (mc *mysqlConn) cleanup() {
return
}
if err := nc.Close(); err != nil {
mc.cfg.Logger.Print(err)
mc.log(err)
}
// This function can be called from multiple goroutines.
// So we can not mc.clearResult() here.
Expand All @@ -172,14 +177,14 @@ func (mc *mysqlConn) error() error {

func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
if mc.closed.Load() {
mc.cfg.Logger.Print(ErrInvalidConn)
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
err := mc.writeCommandPacketStr(comStmtPrepare, query)
if err != nil {
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
mc.cfg.Logger.Print(err)
mc.log(err)
return nil, driver.ErrBadConn
}

Expand Down Expand Up @@ -213,7 +218,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
// can not take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
mc.log(err)
return "", ErrInvalidConn
}
buf = buf[:0]
Expand Down Expand Up @@ -305,7 +310,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin

func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
if mc.closed.Load() {
mc.cfg.Logger.Print(ErrInvalidConn)
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
Expand Down Expand Up @@ -365,7 +370,7 @@ func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error)
handleOk := mc.clearResult()

if mc.closed.Load() {
mc.cfg.Logger.Print(ErrInvalidConn)
mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
if len(args) != 0 {
Expand Down Expand Up @@ -460,7 +465,7 @@ func (mc *mysqlConn) finish() {
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.Load() {
mc.cfg.Logger.Print(ErrInvalidConn)
mc.log(ErrInvalidConn)
return driver.ErrBadConn
}

Expand Down Expand Up @@ -669,7 +674,7 @@ func (mc *mysqlConn) ResetSession(ctx context.Context) error {
err = connCheck(conn)
}
if err != nil {
mc.cfg.Logger.Print("closing bad idle connection: ", err)
mc.log("closing bad idle connection: ", err)
return driver.ErrBadConn
}
}
Expand Down
116 changes: 58 additions & 58 deletions driver_test.go

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ var defaultLogger = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|lo

// Logger is used to log critical error messages.
type Logger interface {
Print(v ...interface{})
Print(v ...any)
}

// NopLogger is a nop implementation of the Logger interface.
type NopLogger struct{}

// Print implements Logger interface.
func (nl *NopLogger) Print(_ ...interface{}) {}
func (nl *NopLogger) Print(_ ...any) {}

// SetLogger is used to set the default logger for critical errors.
// The initial logger is os.Stderr.
Expand Down
2 changes: 1 addition & 1 deletion fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ var (
scanTypeString = reflect.TypeOf("")
scanTypeNullString = reflect.TypeOf(sql.NullString{})
scanTypeBytes = reflect.TypeOf([]byte{})
scanTypeUnknown = reflect.TypeOf(new(interface{}))
scanTypeUnknown = reflect.TypeOf(new(any))
)

type mysqlField struct {
Expand Down
2 changes: 1 addition & 1 deletion nulltime.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type NullTime sql.NullTime
// Scan implements the Scanner interface.
// The value type must be time.Time or string / []byte (formatted time-string),
// otherwise Scan fails.
func (nt *NullTime) Scan(value interface{}) (err error) {
func (nt *NullTime) Scan(value any) (err error) {
if value == nil {
nt.Time, nt.Valid = time.Time{}, false
return
Expand Down
2 changes: 1 addition & 1 deletion nulltime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ var (

func TestScanNullTime(t *testing.T) {
var scanTests = []struct {
in interface{}
in any
error bool
valid bool
time time.Time
Expand Down
24 changes: 12 additions & 12 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.cfg.Logger.Print(err)
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
}
Expand All @@ -57,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
mc.cfg.Logger.Print(ErrMalformPkt)
mc.log(ErrMalformPkt)
mc.Close()
return nil, ErrInvalidConn
}
Expand All @@ -71,7 +71,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if cerr := mc.canceled.Value(); cerr != nil {
return nil, cerr
}
mc.cfg.Logger.Print(err)
mc.log(err)
mc.Close()
return nil, ErrInvalidConn
}
Expand Down Expand Up @@ -134,7 +134,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
// Handle error
if err == nil { // n != len(data)
mc.cleanup()
mc.cfg.Logger.Print(ErrMalformPkt)
mc.log(ErrMalformPkt)
} else {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
Expand All @@ -144,7 +144,7 @@ func (mc *mysqlConn) writePacket(data []byte) error {
return errBadConnNoWrite
}
mc.cleanup()
mc.cfg.Logger.Print(err)
mc.log(err)
}
return ErrInvalidConn
}
Expand Down Expand Up @@ -302,7 +302,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
mc.log(err)
return errBadConnNoWrite
}

Expand Down Expand Up @@ -392,7 +392,7 @@ func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
mc.log(err)
return errBadConnNoWrite
}

Expand All @@ -412,7 +412,7 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
mc.log(err)
return errBadConnNoWrite
}

Expand All @@ -431,7 +431,7 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
mc.log(err)
return errBadConnNoWrite
}

Expand All @@ -452,7 +452,7 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
mc.log(err)
return errBadConnNoWrite
}

Expand Down Expand Up @@ -994,7 +994,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
}
if err != nil {
// cannot take the buffer. Something must be wrong with the connection
mc.cfg.Logger.Print(err)
mc.log(err)
return errBadConnNoWrite
}

Expand Down Expand Up @@ -1193,7 +1193,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil {
mc.cfg.Logger.Print(err)
mc.log(err)
return errBadConnNoWrite
}
}
Expand Down
6 changes: 3 additions & 3 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {

func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
if stmt.mc.closed.Load() {
stmt.mc.cfg.Logger.Print(ErrInvalidConn)
stmt.mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
Expand Down Expand Up @@ -95,7 +95,7 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {

func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
if stmt.mc.closed.Load() {
stmt.mc.cfg.Logger.Print(ErrInvalidConn)
stmt.mc.log(ErrInvalidConn)
return nil, driver.ErrBadConn
}
// Send command
Expand Down Expand Up @@ -141,7 +141,7 @@ type converter struct{}
// implementation does not. This function should be kept in sync with
// database/sql/driver defaultConverter.ConvertValue() except for that
// deliberate difference.
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
func (c converter) ConvertValue(v any) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
Expand Down
4 changes: 2 additions & 2 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestConvertPointer(t *testing.T) {
}

func TestConvertSignedIntegers(t *testing.T) {
values := []interface{}{
values := []any{
int8(-42),
int16(-42),
int32(-42),
Expand Down Expand Up @@ -106,7 +106,7 @@ func (u myUint64) Value() (driver.Value, error) {
}

func TestConvertUnsignedIntegers(t *testing.T) {
values := []interface{}{
values := []any{
uint8(42),
uint16(42),
uint32(42),
Expand Down

0 comments on commit b568757

Please sign in to comment.