diff --git a/server/conn.go b/server/conn.go index c9971a6876b09..a90cd3052884b 100644 --- a/server/conn.go +++ b/server/conn.go @@ -231,6 +231,63 @@ type handshakeResponse41 struct { Attrs map[string]string } +// parseOldHandshakeResponseHeader parses the old version handshake header HandshakeResponse320 +func parseOldHandshakeResponseHeader(packet *handshakeResponse41, data []byte) (parsedBytes int, err error) { + // Ensure there are enough data to read: + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse320 + log.Debugf("Try to parse hanshake response as Protocol::HandshakeResponse320 , packet data: %v", data) + if len(data) < 2+3 { + log.Errorf("Got malformed handshake response, packet data: %v", data) + return 0, mysql.ErrMalformPacket + } + offset := 0 + // capability + capability := binary.LittleEndian.Uint16(data[:2]) + packet.Capability = uint32(capability) + + // be compatible with Protocol::HandshakeResponse41 + packet.Capability = packet.Capability | mysql.ClientProtocol41 + + offset += 2 + // skip max packet size + offset += 3 + // usa default CharsetID + packet.Collation = mysql.CollationNames["utf8mb4_general_ci"] + + return offset, nil +} + +// parseOldHandshakeResponseBody parse the HandshakeResponse for Protocol::HandshakeResponse320 (except the common header part). +func parseOldHandshakeResponseBody(packet *handshakeResponse41, data []byte, offset int) (err error) { + defer func() { + // Check malformat packet cause out of range is disgusting, but don't panic! + if r := recover(); r != nil { + log.Errorf("handshake panic, packet data: %v", data) + err = mysql.ErrMalformPacket + } + }() + // user name + packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)]) + offset += len(packet.User) + 1 + + if packet.Capability&mysql.ClientConnectWithDB > 0 { + if len(data[offset:]) > 0 { + idx := bytes.IndexByte(data[offset:], 0) + packet.DBName = string(data[offset : offset+idx]) + offset = offset + idx + 1 + } + if len(data[offset:]) > 0 { + packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)] + offset += len(packet.Auth) + 1 + } + } else { + packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)] + offset += len(packet.Auth) + 1 + } + + return nil +} + // parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41. func parseHandshakeResponseHeader(packet *handshakeResponse41, data []byte) (parsedBytes int, err error) { // Ensure there are enough data to read: @@ -351,9 +408,24 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse() error { return errors.Trace(err) } + isOldVersion := false + var resp handshakeResponse41 + var pos int + + if len(data) < 2 { + log.Errorf("Got malformed handshake response, packet data: %v", data) + return mysql.ErrMalformPacket + } + + capability := uint32(binary.LittleEndian.Uint16(data[:2])) + if capability&mysql.ClientProtocol41 > 0 { + pos, err = parseHandshakeResponseHeader(&resp, data) + } else { + pos, err = parseOldHandshakeResponseHeader(&resp, data) + isOldVersion = true + } - pos, err := parseHandshakeResponseHeader(&resp, data) if err != nil { return errors.Trace(err) } @@ -368,14 +440,23 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse() error { if err != nil { return errors.Trace(err) } - pos, err = parseHandshakeResponseHeader(&resp, data) + if isOldVersion { + pos, err = parseOldHandshakeResponseHeader(&resp, data) + } else { + pos, err = parseHandshakeResponseHeader(&resp, data) + } if err != nil { return errors.Trace(err) } } // Read the remaining part of the packet. - if err = parseHandshakeResponseBody(&resp, data, pos); err != nil { + if isOldVersion { + err = parseOldHandshakeResponseBody(&resp, data, pos) + } else { + err = parseHandshakeResponseBody(&resp, data, pos) + } + if err != nil { return errors.Trace(err) } @@ -384,6 +465,7 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse() error { cc.dbname = resp.DBName cc.collation = resp.Collation cc.attrs = resp.Attrs + err = cc.openSessionAndDoAuth(resp.Auth) return errors.Trace(err) } diff --git a/server/conn_test.go b/server/conn_test.go index 037399dd7dd4a..2c01f8a7551d5 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -86,6 +86,20 @@ func (ts ConnTestSuite) TestParseHandshakeResponse(c *C) { c.Assert(err, IsNil) c.Assert(p.User, Equals, "pam") c.Assert(p.DBName, Equals, "test") + + // Test for compatibility of Protocol::HandshakeResponse320 + data = []byte{ + 0x00, 0x80, 0x00, 0x00, 0x01, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x00, + } + p = handshakeResponse41{} + offset, err = parseOldHandshakeResponseHeader(&p, data) + c.Assert(err, IsNil) + capability = mysql.ClientProtocol41 | + mysql.ClientSecureConnection + c.Assert(p.Capability&capability, Equals, capability) + err = parseOldHandshakeResponseBody(&p, data, offset) + c.Assert(err, IsNil) + c.Assert(p.User, Equals, "root") } func (ts ConnTestSuite) TestIssue1768(c *C) {