From 138b36d8c8740094743e81713e912a439e1f6296 Mon Sep 17 00:00:00 2001 From: Jon Eisen Date: Tue, 12 Jun 2012 19:57:31 -0400 Subject: [PATCH 1/2] go fix --- convert.go | 25 +++++++++--------- error.go | 3 +++ handler.go | 31 ++++++++++------------ mysql.go | 62 +++++++++++++++++++++---------------------- mysql_test.go | 73 +++++++++++++++++++++++++-------------------------- packet.go | 47 ++++++++++++++++----------------- password.go | 6 ++--- reader.go | 7 +++-- result.go | 4 +-- statement.go | 33 +++++++++++------------ writer.go | 5 ++-- 11 files changed, 145 insertions(+), 151 deletions(-) diff --git a/convert.go b/convert.go index 0e2455c..08b1b59 100644 --- a/convert.go +++ b/convert.go @@ -6,8 +6,8 @@ package mysql import ( + "io" "math" - "os" "strconv" ) @@ -121,6 +121,7 @@ func ui32tob(n uint32) (b []byte) { } return } + // bytes to int64 func btoi64(b []byte) int64 { return int64(btoui64(b)) @@ -169,7 +170,7 @@ func f64tob(f float64) []byte { } // bytes to length -func btolcb(b []byte) (num uint64, n int, err os.Error) { +func btolcb(b []byte) (num uint64, n int, err error) { switch { // 0-250 = value of first byte case b[0] <= 250: @@ -193,14 +194,14 @@ func btolcb(b []byte) (num uint64, n int, err os.Error) { } // Check there are enough bytes if len(b) < n { - err = os.EOF - return + err = io.EOF + return num, n, err } // Get uint64 t := make([]byte, 8) copy(t, b[1:n]) num = btoui64(t) - return + return num, n, nil } // length to bytes @@ -229,7 +230,7 @@ func atoui64(i interface{}) (n uint64) { return t case string: // Convert to int64 first for signing bit - in, err := strconv.Atoi64(t) + in, err := strconv.ParseInt(t, 10, 64) if err != nil { panic("Invalid string for integer conversion") } @@ -248,8 +249,8 @@ func atof64(i interface{}) (f float64) { case float64: return t case string: - var err os.Error - f, err = strconv.Atof64(t) + var err error + f, err = strconv.ParseFloat(t, 64) if err != nil { panic("Invalid string for floating point conversion") } @@ -263,13 +264,13 @@ func atof64(i interface{}) (f float64) { func atos(i interface{}) (s string) { switch t := i.(type) { case int64: - s = strconv.Itoa64(t) + s = strconv.FormatInt(t, 10) case uint64: - s = strconv.Uitoa64(t) + s = strconv.FormatUint(t, 10) case float32: - s = strconv.Ftoa32(t, 'f', -1) + s = strconv.FormatFloat(float64(t), 'f', -1, 32) case float64: - s = strconv.Ftoa64(t, 'f', -1) + s = strconv.FormatFloat(t, 'f', -1, 64) case []byte: s = string(t) case Date: diff --git a/error.go b/error.go index 7efcc68..c1378c2 100644 --- a/error.go +++ b/error.go @@ -133,3 +133,6 @@ type ServerError struct { func (e *ServerError) String() string { return fmt.Sprintf("#%d %s", e.Errno, e.Error) } +func (e *ServerError) Error() string { + return e.String() +} diff --git a/handler.go b/handler.go index 269393f..b61ba79 100644 --- a/handler.go +++ b/handler.go @@ -5,13 +5,10 @@ // license that can be found in the LICENSE file. package mysql -import ( - "os" - "strconv" -) +import "strconv" // OK packet handler -func handleOK(p *packetOK, c *Client, a, i *uint64, w *uint16) (err os.Error) { +func handleOK(p *packetOK, c *Client, a, i *uint64, w *uint16) (err error) { // Log OK result c.log(1, "[%d] Received OK packet", p.sequence) // Check sequence @@ -32,7 +29,7 @@ func handleOK(p *packetOK, c *Client, a, i *uint64, w *uint16) (err os.Error) { } // Error packet handler -func handleError(p *packetError, c *Client) (err os.Error) { +func handleError(p *packetError, c *Client) (err error) { // Log error result c.log(1, "[%d] Received error packet", p.sequence) // Check sequence @@ -50,7 +47,7 @@ func handleError(p *packetError, c *Client) (err os.Error) { } // EOF packet handler -func handleEOF(p *packetEOF, c *Client) (err os.Error) { +func handleEOF(p *packetEOF, c *Client) (err error) { // Log EOF result c.log(1, "[%d] Received EOF packet", p.sequence) // Check sequence @@ -70,7 +67,7 @@ func handleEOF(p *packetEOF, c *Client) (err os.Error) { } // Result set packet handler -func handleResultSet(p *packetResultSet, c *Client, r *Result) (err os.Error) { +func handleResultSet(p *packetResultSet, c *Client, r *Result) (err error) { // Log error result c.log(1, "[%d] Received result set packet", p.sequence) // Check sequence @@ -84,7 +81,7 @@ func handleResultSet(p *packetResultSet, c *Client, r *Result) (err os.Error) { } // Field packet handler -func handleField(p *packetField, c *Client, r *Result) (err os.Error) { +func handleField(p *packetField, c *Client, r *Result) (err error) { // Log field result c.log(1, "[%d] Received field packet", p.sequence) // Check sequence @@ -110,7 +107,7 @@ func handleField(p *packetField, c *Client, r *Result) (err os.Error) { } // Row packet hander -func handleRow(p *packetRowData, c *Client, r *Result) (err os.Error) { +func handleRow(p *packetRowData, c *Client, r *Result) (err error) { // Log field result c.log(1, "[%d] Received row packet", p.sequence) // Check sequence @@ -128,23 +125,23 @@ func handleRow(p *packetRowData, c *Client, r *Result) (err os.Error) { // Iterate fields to get types for i, f := range r.fields { // Check null - if len(p.row[i].([]byte)) ==0 { + if len(p.row[i].([]byte)) == 0 { field = nil } else { switch f.Type { // Signed/unsigned ints case FIELD_TYPE_TINY, FIELD_TYPE_SHORT, FIELD_TYPE_YEAR, FIELD_TYPE_INT24, FIELD_TYPE_LONG, FIELD_TYPE_LONGLONG: if f.Flags&FLAG_UNSIGNED > 0 { - field, err = strconv.Atoui64(string(p.row[i].([]byte))) + field, err = strconv.ParseUint(string(p.row[i].([]byte)), 10, 64) } else { - field, err = strconv.Atoi64(string(p.row[i].([]byte))) + field, err = strconv.ParseInt(string(p.row[i].([]byte)), 10, 64) } if err != nil { return } // Floats and doubles case FIELD_TYPE_FLOAT, FIELD_TYPE_DOUBLE: - field, err = strconv.Atof64(string(p.row[i].([]byte))) + field, err = strconv.ParseFloat(string(p.row[i].([]byte)), 64) if err != nil { return } @@ -173,7 +170,7 @@ func handleRow(p *packetRowData, c *Client, r *Result) (err os.Error) { } // Prepare OK packet handler -func handlePrepareOK(p *packetPrepareOK, c *Client, s *Statement) (err os.Error) { +func handlePrepareOK(p *packetPrepareOK, c *Client, s *Statement) (err error) { // Log result c.log(1, "[%d] Received prepare OK packet", p.sequence) // Check sequence @@ -190,7 +187,7 @@ func handlePrepareOK(p *packetPrepareOK, c *Client, s *Statement) (err os.Error) } // Parameter packet handler -func handleParam(p *packetParameter, c *Client) (err os.Error) { +func handleParam(p *packetParameter, c *Client) (err error) { // Log result c.log(1, "[%d] Received parameter packet", p.sequence) // Check sequence @@ -203,7 +200,7 @@ func handleParam(p *packetParameter, c *Client) (err os.Error) { } // Binary row packet handler -func handleBinaryRow(p *packetRowBinary, c *Client, r *Result) (err os.Error) { +func handleBinaryRow(p *packetRowBinary, c *Client, r *Result) (err error) { // Log binary row result c.log(1, "[%d] Received binary row packet", p.sequence) // Check sequence diff --git a/mysql.go b/mysql.go index 3a64999..e1a7250 100644 --- a/mysql.go +++ b/mysql.go @@ -11,8 +11,8 @@ import ( "fmt" "io" "log" - "os" "net" + "os" "strings" "sync" "time" @@ -100,7 +100,7 @@ func NewClient(protocol ...uint8) (c *Client) { } // Connect to server via TCP -func DialTCP(raddr, user, passwd string, dbname ...string) (c *Client, err os.Error) { +func DialTCP(raddr, user, passwd string, dbname ...string) (c *Client, err error) { c = NewClient(DEFAULT_PROTOCOL) // Add port if not set if strings.Index(raddr, ":") == -1 { @@ -112,7 +112,7 @@ func DialTCP(raddr, user, passwd string, dbname ...string) (c *Client, err os.Er } // Connect to server via unix socket -func DialUnix(raddr, user, passwd string, dbname ...string) (c *Client, err os.Error) { +func DialUnix(raddr, user, passwd string, dbname ...string) (c *Client, err error) { c = NewClient(DEFAULT_PROTOCOL) // Use default socket if socket is empty if raddr == "" { @@ -124,7 +124,7 @@ func DialUnix(raddr, user, passwd string, dbname ...string) (c *Client, err os.E } // Connect to the server -func (c *Client) Connect(network, raddr, user, passwd string, dbname ...string) (err os.Error) { +func (c *Client) Connect(network, raddr, user, passwd string, dbname ...string) (err error) { // Log connect c.log(1, "=== Begin connect ===") // Check not already connected @@ -152,7 +152,7 @@ func (c *Client) Connect(network, raddr, user, passwd string, dbname ...string) } // Close connection to server -func (c *Client) Close() (err os.Error) { +func (c *Client) Close() (err error) { // Log close c.log(1, "=== Begin close ===") // Check connection @@ -173,7 +173,7 @@ func (c *Client) Close() (err os.Error) { } // Change the current database -func (c *Client) ChangeDb(dbname string) (err os.Error) { +func (c *Client) ChangeDb(dbname string) (err error) { // Auto reconnect defer func() { if err != nil && c.checkNet(err) && c.Reconnect { @@ -205,7 +205,7 @@ func (c *Client) ChangeDb(dbname string) (err os.Error) { } // Send a query/queries to the server -func (c *Client) Query(sql string) (err os.Error) { +func (c *Client) Query(sql string) (err error) { // Auto reconnect defer func() { if err != nil && c.checkNet(err) && c.Reconnect { @@ -242,7 +242,7 @@ func (c *Client) Query(sql string) (err os.Error) { } // Fetch all rows for a result and store it, returning the result set -func (c *Client) StoreResult() (result *Result, err os.Error) { +func (c *Client) StoreResult() (result *Result, err error) { // Auto reconnect defer func() { err = c.simpleReconnect(err) @@ -269,7 +269,7 @@ func (c *Client) StoreResult() (result *Result, err os.Error) { } // Use a result set, does not store rows -func (c *Client) UseResult() (result *Result, err os.Error) { +func (c *Client) UseResult() (result *Result, err error) { // Auto reconnect defer func() { err = c.simpleReconnect(err) @@ -290,7 +290,7 @@ func (c *Client) UseResult() (result *Result, err os.Error) { } // Free the current result -func (c *Client) FreeResult() (err os.Error) { +func (c *Client) FreeResult() (err error) { // Auto reconnect defer func() { err = c.simpleReconnect(err) @@ -329,7 +329,7 @@ func (c *Client) MoreResults() bool { } // Move to the next available result -func (c *Client) NextResult() (more bool, err os.Error) { +func (c *Client) NextResult() (more bool, err error) { // Auto reconnect defer func() { err = c.simpleReconnect(err) @@ -354,7 +354,7 @@ func (c *Client) NextResult() (more bool, err os.Error) { } // Set autocommit -func (c *Client) SetAutoCommit(state bool) (err os.Error) { +func (c *Client) SetAutoCommit(state bool) (err error) { // Log set autocommit c.log(1, "=== Begin set autocommit ===") // Use set autocommit query @@ -368,7 +368,7 @@ func (c *Client) SetAutoCommit(state bool) (err os.Error) { } // Start a transaction -func (c *Client) Start() (err os.Error) { +func (c *Client) Start() (err error) { // Log start transaction c.log(1, "=== Begin start transaction ===") // Use start transaction query @@ -376,7 +376,7 @@ func (c *Client) Start() (err os.Error) { } // Commit a transaction -func (c *Client) Commit() (err os.Error) { +func (c *Client) Commit() (err error) { // Log commit c.log(1, "=== Begin commit ===") // Use commit query @@ -384,7 +384,7 @@ func (c *Client) Commit() (err os.Error) { } // Rollback a transaction -func (c *Client) Rollback() (err os.Error) { +func (c *Client) Rollback() (err error) { // Log rollback c.log(1, "=== Begin rollback ===") // Use rollback query @@ -412,7 +412,7 @@ func (c *Client) Escape(s string) (esc string) { } // Initialise a new statment -func (c *Client) InitStmt() (stmt *Statement, err os.Error) { +func (c *Client) InitStmt() (stmt *Statement, err error) { // Check connection if !c.checkConn() { return nil, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} @@ -424,7 +424,7 @@ func (c *Client) InitStmt() (stmt *Statement, err os.Error) { } // Initialise and prepare a new statement -func (c *Client) Prepare(sql string) (stmt *Statement, err os.Error) { +func (c *Client) Prepare(sql string) (stmt *Statement, err error) { // Initialise a new statement stmt, err = c.InitStmt() if err != nil { @@ -530,7 +530,7 @@ func (c *Client) checkResult() bool { } // Check if a network error occurred -func (c *Client) checkNet(err os.Error) bool { +func (c *Client) checkNet(err error) bool { if cErr, ok := err.(*ClientError); ok { if cErr.Errno == CR_SERVER_GONE_ERROR || cErr.Errno == CR_SERVER_LOST { return true @@ -540,7 +540,7 @@ func (c *Client) checkNet(err os.Error) bool { } // Performs the actual connect -func (c *Client) connect() (err os.Error) { +func (c *Client) connect() (err error) { // Connect to server err = c.dial() if err != nil { @@ -582,7 +582,7 @@ func (c *Client) connect() (err os.Error) { } // Connect to server -func (c *Client) dial() (err os.Error) { +func (c *Client) dial() (err error) { // Log connect c.log(1, "Connecting to server via %s to %s", c.network, c.raddr) // Connect to server @@ -597,7 +597,7 @@ func (c *Client) dial() (err os.Error) { } // Log error if cErr, ok := err.(*ClientError); ok { - c.log(1, string(cErr.Error)) + c.log(1, string(cErr.Err)) } return } @@ -612,7 +612,7 @@ func (c *Client) dial() (err os.Error) { } // Read initial packet from server -func (c *Client) init() (err os.Error) { +func (c *Client) init() (err error) { // Log read packet c.log(1, "Reading handshake initialization packet from server") // Read packet @@ -653,7 +653,7 @@ func (c *Client) init() (err os.Error) { } // Send auth packet to the server -func (c *Client) auth() (err os.Error) { +func (c *Client) auth() (err error) { // Log write packet c.log(1, "Sending authentication packet to server") // Construct packet @@ -699,7 +699,7 @@ func (c *Client) auth() (err os.Error) { } // Simple non-recovered reconnect -func (c *Client) simpleReconnect(err os.Error) os.Error { +func (c *Client) simpleReconnect(err error) error { if err != nil && c.checkNet(err) && c.Reconnect { c.log(1, "!!! Lost connection to server !!!") c.connected = false @@ -712,7 +712,7 @@ func (c *Client) simpleReconnect(err os.Error) os.Error { } // Perform reconnect if a network error occurs -func (c *Client) reconnect() (err os.Error) { +func (c *Client) reconnect() (err error) { // Log auto reconnect c.log(1, "=== Begin auto reconnect attempt ===") // Reset the client @@ -730,7 +730,7 @@ func (c *Client) reconnect() (err os.Error) { } // Send a command to the server -func (c *Client) command(command command, args ...interface{}) (err os.Error) { +func (c *Client) command(command command, args ...interface{}) (err error) { // Log write packet c.log(1, "Sending command packet to server") // Simple validation, arg count @@ -783,7 +783,7 @@ func (c *Client) command(command command, args ...interface{}) (err os.Error) { } // Get field packets for a result -func (c *Client) getFields() (err os.Error) { +func (c *Client) getFields() (err error) { // Check for a valid result if c.result == nil { return &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} @@ -803,7 +803,7 @@ func (c *Client) getFields() (err os.Error) { } // Get next row for a result -func (c *Client) getRow() (eof bool, err os.Error) { +func (c *Client) getRow() (eof bool, err error) { // Check for a valid result if c.result == nil { return false, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} @@ -815,7 +815,7 @@ func (c *Client) getRow() (eof bool, err os.Error) { } // Get all rows for the result -func (c *Client) getAllRows() (err os.Error) { +func (c *Client) getAllRows() (err error) { for { eof, err := c.getRow() if err != nil { @@ -829,7 +829,7 @@ func (c *Client) getAllRows() (err os.Error) { } // Get result -func (c *Client) getResult(types packetType) (eof bool, err os.Error) { +func (c *Client) getResult(types packetType) (eof bool, err error) { // Log read result c.log(1, "Reading result packet from server") // Get result packet @@ -860,7 +860,7 @@ func (c *Client) getResult(types packetType) (eof bool, err os.Error) { } // Sequence check -func (c *Client) checkSequence(sequence uint8) (err os.Error) { +func (c *Client) checkSequence(sequence uint8) (err error) { if sequence != c.sequence { c.log(1, "Sequence doesn't match - expected %d but got %d, commands out of sync", c.sequence, sequence) return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} diff --git a/mysql_test.go b/mysql_test.go index 99c78f3..58aa479 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -7,8 +7,7 @@ package mysql import ( "fmt" - "os" - "rand" + "math/rand" "strconv" "testing" ) @@ -50,7 +49,7 @@ const ( var ( db *Client - err os.Error + err error ) type SimpleRow struct { @@ -144,14 +143,14 @@ func TestSimple(t *testing.T) { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Create table") err = db.Query(CREATE_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Insert 1000 records") rowMap := make(map[uint64][]string) for i := 0; i < 1000; i++ { @@ -164,21 +163,21 @@ func TestSimple(t *testing.T) { row := []string{fmt.Sprintf("%d", num), str1, str2} rowMap[db.LastInsertId] = row } - + t.Logf("Select inserted data") err = db.Query(SELECT_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Use result") res, err := db.UseResult() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Validate inserted data") for { row := res.FetchRow() @@ -186,20 +185,20 @@ func TestSimple(t *testing.T) { break } id := row[0].(uint64) - num, str1, str2 := strconv.Itoa64(row[1].(int64)), row[2].(string), string(row[3].([]byte)) + num, str1, str2 := strconv.FormatInt(row[1].(int64), 10), row[2].(string), string(row[3].([]byte)) if rowMap[id][0] != num || rowMap[id][1] != str1 || rowMap[id][2] != str2 { t.Logf("String from database doesn't match local string") t.Fail() } } - + t.Logf("Free result") err = res.Free() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Update some records") for i := uint64(0); i < 1000; i += 5 { rowMap[i+1][2] = randString(256) @@ -213,21 +212,21 @@ func TestSimple(t *testing.T) { t.Fail() } } - + t.Logf("Select updated data") err = db.Query(SELECT_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Store result") res, err = db.StoreResult() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Validate updated data") for { row := res.FetchRow() @@ -235,14 +234,14 @@ func TestSimple(t *testing.T) { break } id := row[0].(uint64) - num, str1, str2 := strconv.Itoa64(row[1].(int64)), row[2].(string), string(row[3].([]byte)) + num, str1, str2 := strconv.FormatInt(row[1].(int64), 10), row[2].(string), string(row[3].([]byte)) if rowMap[id][0] != num || rowMap[id][1] != str1 || rowMap[id][2] != str2 { t.Logf("%#v %#v", rowMap[id], row) t.Logf("String from database doesn't match local string") t.Fail() } } - + t.Logf("Free result") err = res.Free() if err != nil { @@ -256,7 +255,7 @@ func TestSimple(t *testing.T) { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Close connection") err = db.Close() if err != nil { @@ -273,35 +272,35 @@ func TestSimpleStatement(t *testing.T) { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Init statement") stmt, err := db.InitStmt() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Prepare create table") err = stmt.Prepare(CREATE_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Execute create table") err = stmt.Execute() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Prepare insert") err = stmt.Prepare(INSERT_SIMPLE_STMT) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Insert 1000 records") rowMap := make(map[uint64][]string) for i := 0; i < 1000; i++ { @@ -319,25 +318,25 @@ func TestSimpleStatement(t *testing.T) { row := []string{fmt.Sprintf("%d", num), str1, str2} rowMap[stmt.LastInsertId] = row } - + t.Logf("Prepare select") err = stmt.Prepare(SELECT_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Execute select") err = stmt.Execute() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Bind result") row := SimpleRow{} stmt.BindResult(&row.Id, &row.Number, &row.String, &row.Text, &row.Date) - + t.Logf("Validate inserted data") for { eof, err := stmt.Fetch() @@ -353,21 +352,21 @@ func TestSimpleStatement(t *testing.T) { t.Fail() } } - + t.Logf("Reset statement") err = stmt.Reset() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Prepare update") err = stmt.Prepare(UPDATE_SIMPLE_STMT) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Update some records") for i := uint64(0); i < 1000; i += 5 { rowMap[i+1][2] = randString(256) @@ -382,21 +381,21 @@ func TestSimpleStatement(t *testing.T) { t.Fail() } } - + t.Logf("Prepare select updated") err = stmt.Prepare(SELECT_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Execute select updated") err = stmt.Execute() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Validate updated data") for { eof, err := stmt.Fetch() @@ -412,35 +411,35 @@ func TestSimpleStatement(t *testing.T) { t.Fail() } } - + t.Logf("Free result") err = stmt.FreeResult() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Prepare drop") err = stmt.Prepare(DROP_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Execute drop") err = stmt.Execute() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Close statement") err = stmt.Close() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Close connection") err = db.Close() if err != nil { diff --git a/packet.go b/packet.go index 73978a8..64a1d8a 100644 --- a/packet.go +++ b/packet.go @@ -7,7 +7,7 @@ package mysql import ( "bytes" - "os" + "io" ) // Packet type identifier @@ -33,12 +33,12 @@ const ( // Readable packet interface type packetReadable interface { - read(data []byte) (err os.Error) + read(data []byte) (err error) } // Writable packet interface type packetWritable interface { - write() (data []byte, err os.Error) + write() (data []byte, err error) } // Generic packet interface (read/writable) @@ -54,19 +54,19 @@ type packetBase struct { } // Read a slice from the data -func (p *packetBase) readSlice(data []byte, delim byte) (slice []byte, err os.Error) { +func (p *packetBase) readSlice(data []byte, delim byte) (slice []byte, err error) { pos := bytes.IndexByte(data, delim) if pos > -1 { slice = data[:pos] } else { slice = data - err = os.EOF + err = io.EOF } return } // Read length coded string -func (p *packetBase) readLengthCodedString(data []byte) (s string, n int, err os.Error) { +func (p *packetBase) readLengthCodedString(data []byte) (s string, n int, err error) { // Read bytes and convert to string b, n, err := p.readLengthCodedBytes(data) if err != nil { @@ -76,7 +76,7 @@ func (p *packetBase) readLengthCodedString(data []byte) (s string, n int, err os return } -func (p *packetBase) readLengthCodedBytes(data []byte) (b []byte, n int, err os.Error) { +func (p *packetBase) readLengthCodedBytes(data []byte) (b []byte, n int, err error) { // Get string length num, n, err := btolcb(data) if err != nil { @@ -84,7 +84,7 @@ func (p *packetBase) readLengthCodedBytes(data []byte) (b []byte, n int, err os. } // Check data length if len(data) < n+int(num) { - err = os.EOF + err = io.EOF return } // Get bytes @@ -101,7 +101,6 @@ func (p *packetBase) addHeader(data []byte) (pkt []byte) { return } - // Init packet type packetInit struct { packetBase @@ -115,7 +114,7 @@ type packetInit struct { } // Init packet reader -func (p *packetInit) read(data []byte) (err os.Error) { +func (p *packetInit) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -169,7 +168,7 @@ type packetAuth struct { } // Auth packet writer -func (p *packetAuth) write() (data []byte, err os.Error) { +func (p *packetAuth) write() (data []byte, err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -238,7 +237,7 @@ type packetOK struct { } // OK packet reader -func (p *packetOK) read(data []byte) (err os.Error) { +func (p *packetOK) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -285,7 +284,7 @@ type packetError struct { } // Error packet reader -func (p *packetError) read(data []byte) (err os.Error) { +func (p *packetError) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -318,7 +317,7 @@ type packetEOF struct { } // EOF packet reader -func (p *packetEOF) read(data []byte) (err os.Error) { +func (p *packetEOF) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -347,7 +346,7 @@ type packetPassword struct { } // Password packet writer -func (p *packetPassword) write() (data []byte, err os.Error) { +func (p *packetPassword) write() (data []byte, err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -371,7 +370,7 @@ type packetCommand struct { } // Command packet writer -func (p *packetCommand) write() (data []byte, err os.Error) { +func (p *packetCommand) write() (data []byte, err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -450,7 +449,7 @@ type packetResultSet struct { } // Result set packet reader -func (p *packetResultSet) read(data []byte) (err os.Error) { +func (p *packetResultSet) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -493,7 +492,7 @@ type packetField struct { } // Field packet reader -func (p *packetField) read(data []byte) (err os.Error) { +func (p *packetField) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -599,7 +598,7 @@ type packetRowData struct { } // Row data packet reader -func (p *packetRowData) read(data []byte) (err os.Error) { +func (p *packetRowData) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -636,7 +635,7 @@ type packetPrepareOK struct { } // Prepare ok packet reader -func (p *packetPrepareOK) read(data []byte) (err os.Error) { +func (p *packetPrepareOK) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -669,7 +668,7 @@ type packetParameter struct { } // Parameter packet reader -func (p *packetParameter) read(data []byte) (err os.Error) { +func (p *packetParameter) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -690,7 +689,7 @@ type packetLongData struct { } // Long data packet writer -func (p *packetLongData) write() (data []byte, err os.Error) { +func (p *packetLongData) write() (data []byte, err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -724,7 +723,7 @@ type packetExecute struct { } // Execute packet writer -func (p *packetExecute) write() (data []byte, err os.Error) { +func (p *packetExecute) write() (data []byte, err error) { // Recover errors defer func() { if e := recover(); e != nil { @@ -769,7 +768,7 @@ type packetRowBinary struct { } // Row binary packet reader -func (p *packetRowBinary) read(data []byte) (err os.Error) { +func (p *packetRowBinary) read(data []byte) (err error) { // Recover errors defer func() { if e := recover(); e != nil { diff --git a/password.go b/password.go index 93b36f9..93f64e9 100644 --- a/password.go +++ b/password.go @@ -92,17 +92,17 @@ func scramble41(message, password []byte) (result []byte) { // SHA1 encode crypt := sha1.New() crypt.Write(password) - stg1Hash := crypt.Sum() + stg1Hash := crypt.Sum(nil) // token = SHA1(SHA1(stage1_hash), scramble) XOR stage1_hash // SHA1 encode again crypt.Reset() crypt.Write(stg1Hash) - stg2Hash := crypt.Sum() + stg2Hash := crypt.Sum(nil) // SHA1 2nd hash and scramble crypt.Reset() crypt.Write(message) crypt.Write(stg2Hash) - stg3Hash := crypt.Sum() + stg3Hash := crypt.Sum(nil) // XOR with first hash result = make([]byte, 20) for i := range result { diff --git a/reader.go b/reader.go index de888b8..c7648c0 100644 --- a/reader.go +++ b/reader.go @@ -8,7 +8,6 @@ package mysql import ( "io" "net" - "os" ) // Packet reader struct @@ -26,12 +25,12 @@ func newReader(conn io.ReadWriteCloser) *reader { } // Read the next packet -func (r *reader) readPacket(types packetType) (p packetReadable, err os.Error) { +func (r *reader) readPacket(types packetType) (p packetReadable, err error) { // Deferred error processing defer func() { if err != nil { // EOF errors - if err == os.EOF || err == io.ErrUnexpectedEOF { + if err == io.EOF || err == io.ErrUnexpectedEOF { err = &ClientError{CR_SERVER_LOST, CR_SERVER_LOST_STR} } // OpError @@ -127,7 +126,7 @@ func (r *reader) readPacket(types packetType) (p packetReadable, err os.Error) { } // Read n bytes long number -func (r *reader) readNumber(n uint8) (num uint64, err os.Error) { +func (r *reader) readNumber(n uint8) (num uint64, err error) { // Read bytes into array buf := make([]byte, n) nr, err := io.ReadFull(r.conn, buf) diff --git a/result.go b/result.go index ed8a758..2f80528 100644 --- a/result.go +++ b/result.go @@ -5,8 +5,6 @@ // license that can be found in the LICENSE file. package mysql -import "os" - // Result struct type Result struct { // Pointer to the client @@ -122,7 +120,7 @@ func (r *Result) FetchRows() []Row { } // Free the result -func (r *Result) Free() (err os.Error) { +func (r *Result) Free() (err error) { err = r.c.FreeResult() return } diff --git a/statement.go b/statement.go index d155f78..a791cbe 100644 --- a/statement.go +++ b/statement.go @@ -6,7 +6,6 @@ package mysql import ( - "os" "reflect" "strconv" ) @@ -42,7 +41,7 @@ type Statement struct { } // Prepare new statement -func (s *Statement) Prepare(sql string) (err os.Error) { +func (s *Statement) Prepare(sql string) (err error) { // Auto reconnect defer func() { if err != nil && s.c.checkNet(err) && s.c.Reconnect { @@ -105,7 +104,7 @@ func (s *Statement) ParamCount() uint16 { } // Bind params -func (s *Statement) BindParams(params ...interface{}) (err os.Error) { +func (s *Statement) BindParams(params ...interface{}) (err error) { // Check prepared if !s.prepared { return &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR} @@ -208,7 +207,7 @@ func (s *Statement) BindParams(params ...interface{}) (err os.Error) { } // Send long data -func (s *Statement) SendLongData(num int, data []byte) (err os.Error) { +func (s *Statement) SendLongData(num int, data []byte) (err error) { // Auto reconnect defer func() { err = s.c.simpleReconnect(err) @@ -264,7 +263,7 @@ func (s *Statement) SendLongData(num int, data []byte) (err os.Error) { } // Execute -func (s *Statement) Execute() (err os.Error) { +func (s *Statement) Execute() (err error) { // Auto reconnect defer func() { if err != nil && s.c.checkNet(err) && s.c.Reconnect { @@ -362,7 +361,7 @@ func (s *Statement) FetchColumns() []*Field { } // Bind result -func (s *Statement) BindResult(params ...interface{}) (err os.Error) { +func (s *Statement) BindResult(params ...interface{}) (err error) { s.resultParams = params return } @@ -377,7 +376,7 @@ func (s *Statement) RowCount() uint64 { } // Fetch next row -func (s *Statement) Fetch() (eof bool, err os.Error) { +func (s *Statement) Fetch() (eof bool, err error) { // Auto reconnect defer func() { err = s.c.simpleReconnect(err) @@ -473,7 +472,7 @@ func (s *Statement) Fetch() (eof bool, err os.Error) { } // Store result -func (s *Statement) StoreResult() (err os.Error) { +func (s *Statement) StoreResult() (err error) { // Auto reconnect defer func() { err = s.c.simpleReconnect(err) @@ -500,7 +499,7 @@ func (s *Statement) StoreResult() (err os.Error) { } // Free result -func (s *Statement) FreeResult() (err os.Error) { +func (s *Statement) FreeResult() (err error) { // Auto reconnect defer func() { err = s.c.simpleReconnect(err) @@ -526,7 +525,7 @@ func (s *Statement) MoreResults() bool { } // Next result -func (s *Statement) NextResult() (more bool, err os.Error) { +func (s *Statement) NextResult() (more bool, err error) { // Auto reconnect defer func() { err = s.c.simpleReconnect(err) @@ -558,7 +557,7 @@ func (s *Statement) NextResult() (more bool, err os.Error) { } // Reset statement -func (s *Statement) Reset() (err os.Error) { +func (s *Statement) Reset() (err error) { // Auto reconnect defer func() { err = s.c.simpleReconnect(err) @@ -591,7 +590,7 @@ func (s *Statement) Reset() (err os.Error) { } // Close statement -func (s *Statement) Close() (err os.Error) { +func (s *Statement) Close() (err error) { // Auto reconnect defer func() { err = s.c.simpleReconnect(err) @@ -648,7 +647,7 @@ func (s *Statement) getNullBitMap() (nbm []byte) { } // Get all result fields -func (s *Statement) getFields() (err os.Error) { +func (s *Statement) getFields() (err error) { // Loop till EOF for { s.c.sequence++ @@ -664,7 +663,7 @@ func (s *Statement) getFields() (err os.Error) { } // Get next row for a result -func (s *Statement) getRow() (eof bool, err os.Error) { +func (s *Statement) getRow() (eof bool, err error) { // Check for a valid result if s.result == nil { return false, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} @@ -676,7 +675,7 @@ func (s *Statement) getRow() (eof bool, err os.Error) { } // Get all rows for the result -func (s *Statement) getAllRows() (err os.Error) { +func (s *Statement) getAllRows() (err error) { for { eof, err := s.getRow() if err != nil { @@ -690,7 +689,7 @@ func (s *Statement) getAllRows() (err os.Error) { } // Get result -func (s *Statement) getResult(types packetType) (eof bool, err os.Error) { +func (s *Statement) getResult(types packetType) (eof bool, err error) { // Log read result s.c.log(1, "Reading result packet from server") // Get result packet @@ -725,7 +724,7 @@ func (s *Statement) getResult(types packetType) (eof bool, err os.Error) { } // Free any result sets waiting to be read -func (s *Statement) freeAll(next bool) (err os.Error) { +func (s *Statement) freeAll(next bool) (err error) { // Check for unread rows if !s.result.allRead { // Read all rows diff --git a/writer.go b/writer.go index 9a62c81..94ed778 100644 --- a/writer.go +++ b/writer.go @@ -8,7 +8,6 @@ package mysql import ( "io" "net" - "os" ) // Packet writer struct @@ -24,12 +23,12 @@ func newWriter(conn io.ReadWriteCloser) *writer { } // Write packet to the server -func (w *writer) writePacket(p packetWritable) (err os.Error) { +func (w *writer) writePacket(p packetWritable) (err error) { // Deferred error processing defer func() { if err != nil { // EOF errors - if err == os.EOF || err == io.ErrUnexpectedEOF { + if err == io.EOF || err == io.ErrUnexpectedEOF { err = &ClientError{CR_SERVER_LOST, CR_SERVER_LOST_STR} } // OpError From 8e8f75cb746b7f2b4f571d8e92093af557854020 Mon Sep 17 00:00:00 2001 From: Jon Eisen Date: Tue, 12 Jun 2012 22:38:26 -0400 Subject: [PATCH 2/2] All test pass. gofmt --- error.go | 17 ++++++++++------- handler.go | 8 ++++---- mysql.go | 8 ++++---- packet.go | 2 +- statement.go | 8 ++++---- 5 files changed, 23 insertions(+), 20 deletions(-) diff --git a/error.go b/error.go index c1378c2..29955d9 100644 --- a/error.go +++ b/error.go @@ -114,25 +114,28 @@ const ( // Client error struct type ClientError struct { - Errno Errno - Error Error + Errno Errno + ErrorS Error } // Convert to string func (e *ClientError) String() string { - return fmt.Sprintf("#%d %s", e.Errno, e.Error) + return fmt.Sprintf("#%d %s", e.Errno, e.ErrorS) +} +func (e *ClientError) Error() string { + return e.String() } // Server error struct type ServerError struct { - Errno Errno - Error Error + Errno Errno + ErrorS Error } // Convert to string func (e *ServerError) String() string { - return fmt.Sprintf("#%d %s", e.Errno, e.Error) + return fmt.Sprintf("#%d %s", e.Errno, e.ErrorS) } func (e *ServerError) Error() string { - return e.String() + return e.String() } diff --git a/handler.go b/handler.go index b61ba79..ce926dc 100644 --- a/handler.go +++ b/handler.go @@ -275,7 +275,7 @@ func handleBinaryRow(p *packetRowBinary, c *Client, r *Result) (err error) { FIELD_TYPE_VAR_STRING, FIELD_TYPE_STRING, FIELD_TYPE_GEOMETRY: num, n, err := btolcb(p.data[pos:]) if err != nil { - return + return err } field = p.data[pos+uint64(n) : pos+uint64(n)+num] pos += uint64(n) + num @@ -283,7 +283,7 @@ func handleBinaryRow(p *packetRowBinary, c *Client, r *Result) (err error) { case FIELD_TYPE_DATE: num, n, err := btolcb(p.data[pos:]) if err != nil { - return + return err } // New date d := Date{} @@ -305,7 +305,7 @@ func handleBinaryRow(p *packetRowBinary, c *Client, r *Result) (err error) { case FIELD_TYPE_TIME: num, n, err := btolcb(p.data[pos:]) if err != nil { - return + return err } // New time t := Time{} @@ -327,7 +327,7 @@ func handleBinaryRow(p *packetRowBinary, c *Client, r *Result) (err error) { case FIELD_TYPE_TIMESTAMP, FIELD_TYPE_DATETIME: num, n, err := btolcb(p.data[pos:]) if err != nil { - return + return err } // New datetime d := DateTime{} diff --git a/mysql.go b/mysql.go index e1a7250..02bea2e 100644 --- a/mysql.go +++ b/mysql.go @@ -597,7 +597,7 @@ func (c *Client) dial() (err error) { } // Log error if cErr, ok := err.(*ClientError); ok { - c.log(1, string(cErr.Err)) + c.log(1, cErr.Error()) } return } @@ -793,13 +793,13 @@ func (c *Client) getFields() (err error) { c.sequence++ eof, err := c.getResult(PACKET_FIELD | PACKET_EOF) if err != nil { - return + return err } if eof { break } } - return + return nil } // Get next row for a result @@ -819,7 +819,7 @@ func (c *Client) getAllRows() (err error) { for { eof, err := c.getRow() if err != nil { - return + return err } if eof { break diff --git a/packet.go b/packet.go index 64a1d8a..faaf9b1 100644 --- a/packet.go +++ b/packet.go @@ -612,7 +612,7 @@ func (p *packetRowData) read(data []byte) (err error) { // Read string b, n, err := p.readLengthCodedBytes(data[pos:]) if err != nil { - return + return err } // Add to slice p.row = append(p.row, b) diff --git a/statement.go b/statement.go index a791cbe..fe3e05f 100644 --- a/statement.go +++ b/statement.go @@ -78,7 +78,7 @@ func (s *Statement) Prepare(sql string) (err error) { s.c.sequence++ eof, err := s.getResult(PACKET_PARAM | PACKET_EOF) if err != nil { - return + return err } if eof { break @@ -653,13 +653,13 @@ func (s *Statement) getFields() (err error) { s.c.sequence++ eof, err := s.getResult(PACKET_FIELD | PACKET_EOF) if err != nil { - return + return err } if eof { break } } - return + return nil } // Get next row for a result @@ -679,7 +679,7 @@ func (s *Statement) getAllRows() (err error) { for { eof, err := s.getRow() if err != nil { - return + return err } if eof { break