Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return ErrBusyBuffer instead of driver.ErrBadConn #611

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Individual Persons

Aaron Hopkins <go-sql-driver at die.net>
Alexander Menzhinsky <amenzhinsky at gmail.com>
Arne Hormann <arnehormann at gmail.com>
Asta Xie <xiemengjun at gmail.com>
Carlos Nieto <jose.carlos at menteslibres.net>
Expand Down
26 changes: 13 additions & 13 deletions buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,39 +109,39 @@ func (b *buffer) readNext(need int) ([]byte, error) {
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) []byte {
func (b *buffer) takeBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil
return nil, ErrUnreadTxRows
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do not return the error here. Leave the buffer unchanged.
This specific error should only be returned when we know that it is this specific error.

}

// test (cheap) general case first
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
return b.buf[:length], nil
}

if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf
return b.buf, nil
}
return make([]byte, length)
return make([]byte, length), nil
}

// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length == 0 {
return b.buf[:length]
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
if b.length > 0 {
return nil, ErrUnreadTxRows
}
return nil
return b.buf[:length], nil
}

// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() []byte {
if b.length == 0 {
return b.buf
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
if b.length > 0 {
return nil, ErrUnreadTxRows
}
return nil
return b.buf, nil
}
8 changes: 3 additions & 5 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return "", driver.ErrSkip
}

buf := mc.buf.takeCompleteBuffer()
if buf == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return "", driver.ErrBadConn
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
return "", err
}
buf = buf[:0]
argPos := 0
Expand Down
28 changes: 28 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1991,3 +1991,31 @@ func TestRejectReadOnly(t *testing.T) {
dbt.mustExec("DROP TABLE test")
})
}

func TestUnclosedRows(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
tx, err := dbt.db.Begin()
if err != nil {
dbt.Fatal(err)
}

rows, err := tx.Query("SELECT 1")
if err != nil {
dbt.Fatal(err)
}

// here's common mistake: rows are closed only
// when current func exits keeping the rows buffer
// busy for the following request.
defer rows.Close()

if !rows.Next() {
dbt.Fatal("no rows after `SELECT 1`")
}

_, err = tx.Query("SELECT 2")
if err != ErrUnreadTxRows {
dbt.Errorf("got %v, want %v", err, ErrUnreadTxRows)
}
})
}
2 changes: 1 addition & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ var (
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")
ErrUnreadTxRows = errors.New("rows buffer is busy. Try to read out or close previous rows")
)

var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
Expand Down
67 changes: 26 additions & 41 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
}

// Calculate packet length and get buffer with that size
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
return err
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example here it should be impossible that we're in a transaction with unread rows. Here the driver.ErrBadConn should still be returned as there is this is a brand new connection and there really shouldn't be any unread data in the buffer.

Copy link
Author

@amenzhinsky amenzhinsky Jun 9, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had some time to sort things out, I think that's correct behaviour to return ErrBusyBuffer because when we cannot take it then something must be wrong in the code (multiple functions are using the buffer at the same time) or in this case buffer has some unread data.

And only two functions writePacket and readPacket should return ErrBadConn since they do the io.

So we should return ErrBusyBuffer anyway not ErrBadConn, and probably add additional checks to Query and Exec to return ErrUnreadRows to be more descriptive.

What do you think?

}

// ClientFlags [32 bit]
Expand Down Expand Up @@ -345,11 +343,9 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {

// Calculate the packet length and add a tailing 0
pktLen := len(scrambleBuff) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
data, err := mc.buf.takeSmallBuffer(4 + pktLen)
if err != nil {
return err
}

// Add the scrambled password [null terminated string]
Expand All @@ -364,11 +360,9 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
func (mc *mysqlConn) writeClearAuthPacket() error {
// Calculate the packet length and add a tailing 0
pktLen := len(mc.cfg.Passwd) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
data, err := mc.buf.takeSmallBuffer(4 + pktLen)
if err != nil {
return err
}

// Add the clear password [null terminated string]
Expand All @@ -385,11 +379,9 @@ func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {

// Calculate the packet length and add a tailing 0
pktLen := len(scrambleBuff)
data := mc.buf.takeSmallBuffer(4 + pktLen)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
data, err := mc.buf.takeSmallBuffer(4 + pktLen)
if err != nil {
return err
}

// Add the scramble
Expand All @@ -406,11 +398,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0

data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
return err
}

// Add command byte
Expand All @@ -425,11 +415,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
mc.sequence = 0

pktLen := 1 + len(arg)
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
return err
}

// Add command byte
Expand All @@ -446,11 +434,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0

data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
return err
}

// Add command byte
Expand Down Expand Up @@ -907,16 +893,15 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
mc.sequence = 0

var data []byte
var err error

if len(args) == 0 {
data = mc.buf.takeBuffer(minPktLen)
data, err = mc.buf.takeBuffer(minPktLen)
} else {
data = mc.buf.takeCompleteBuffer()
data, err = mc.buf.takeCompleteBuffer()
}
if data == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(ErrBusyBuffer)
return driver.ErrBadConn
if err != nil {
return err
}

// command [1 byte]
Expand Down