From ea32d9f4bc5ce4fc07a421845da8f5aeec520f43 Mon Sep 17 00:00:00 2001 From: Tzu-Chiao Yeh Date: Sat, 5 Sep 2020 09:58:16 +0800 Subject: [PATCH] Ensure backward compatibility with legacy EOF format --- connection.go | 25 ++++++++++++++----------- packets.go | 51 ++++++++++++++++----------------------------------- rows.go | 2 +- 3 files changed, 31 insertions(+), 47 deletions(-) diff --git a/connection.go b/connection.go index 425cc5952..90aec6439 100644 --- a/connection.go +++ b/connection.go @@ -180,16 +180,16 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { // Read Result columnCount, err := stmt.readPrepareResultPacket() - if err != nil { - return stmt, err - } - - if err := mc.readPackets(stmt.paramCount); err != nil { - return nil, err - } + if err == nil { + if stmt.paramCount > 0 { + if err = mc.readUntilEOF(); err != nil { + return nil, err + } + } - if err := mc.readPackets(int(columnCount)); err != nil { - return nil, err + if columnCount > 0 { + err = mc.readUntilEOF() + } } return stmt, err @@ -415,8 +415,11 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { rows.mc = mc rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} - if err := mc.readPackets(resLen); err != nil { - return nil, err + if resLen > 0 { + // Columns + if err := mc.readUntilEOF(); err != nil { + return nil, err + } } dest := make([]driver.Value, resLen) diff --git a/packets.go b/packets.go index 422278fe1..6cec68228 100644 --- a/packets.go +++ b/packets.go @@ -614,7 +614,7 @@ func readStatus(b []byte) statusFlag { } // Ok Packet -// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet +// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html func (mc *mysqlConn) handleOkPacket(data []byte) error { // 0x00 or 0xFE [1 byte] n := 1 @@ -640,8 +640,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // isEOFPacket will return true if the data is either a EOF-Packet or OK-Packet // acting as an EOF. -func isEOFPacket(data []byte) bool { - return data[0] == iEOF && len(data) < 9 +func (mc *mysqlConn) isEOFPacket(data []byte) bool { + // Legacy EOF packet + if data[0] == iEOF && (len(data) == 5 || len(data) == 1) && mc.flags&clientDeprecateEOF == 0 { + return true + } + return data[0] == iEOF && len(data) < 9 && mc.flags&clientDeprecateEOF != 0 } // Read Packets as Field Packets until EOF-Packet or an Error appears @@ -649,13 +653,13 @@ func isEOFPacket(data []byte) bool { func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { columns := make([]mysqlField, count) - for i := 0; i < count; i++ { + for i := 0; ; i++ { data, err := mc.readPacket() if err != nil { return nil, err } - if mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data) { + if mc.isEOFPacket(data) { if i == count { return columns, nil } @@ -741,7 +745,6 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { // defaultVal, _, err = bytesToLengthCodedBinary(data[pos:]) //} } - return columns, nil } // Read Packets as Field Packets until EOF/OK-Packet or an Error appears @@ -759,12 +762,13 @@ func (rows *textRows) readRow(dest []driver.Value) error { } // EOF Packet - if isEOFPacket(data) { + if mc.isEOFPacket(data) { if mc.flags&clientDeprecateEOF == 0 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) } else { if err := mc.handleOkPacket(data); err != nil { + rows.mc = nil return err } } @@ -830,39 +834,15 @@ func (mc *mysqlConn) readUntilEOF() error { switch { case data[0] == iERR: return mc.handleErrorPacket(data) - case isEOFPacket(data): + case mc.isEOFPacket(data): if mc.flags&clientDeprecateEOF == 0 { mc.status = readStatus(data[3:]) - } else { - return mc.handleOkPacket(data) + return nil } - return nil - } - } -} - -func (mc *mysqlConn) readPackets(num int) error { - - // we need to read EOF as well - if mc.flags&clientDeprecateEOF == 0 { - num++ - } - for i := 0; i < num; i++ { - data, err := mc.readPacket() - if err != nil { - return err - } - - switch { - case data[0] == iERR: - return mc.handleErrorPacket(data) - case mc.flags&clientDeprecateEOF == 0 && isEOFPacket(data): - mc.status = readStatus(data[3:]) - return nil + return mc.handleOkPacket(data) } } - return nil } /****************************************************************************** @@ -1223,11 +1203,12 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // packet indicator [1 byte] if data[0] != iOK { - if isEOFPacket(data) { + if rows.mc.isEOFPacket(data) { if rows.mc.flags&clientDeprecateEOF == 0 { rows.mc.status = readStatus(data[3:]) } else { if err := rows.mc.handleOkPacket(data); err != nil { + rows.mc = nil return err } } diff --git a/rows.go b/rows.go index 888bdb5f0..5567a7287 100644 --- a/rows.go +++ b/rows.go @@ -215,7 +215,7 @@ func (rows *textRows) Next(dest []driver.Value) error { if err := mc.error(); err != nil { return err } - + errLog.Print("perform next read") // Fetch next row from stream return rows.readRow(dest) }