diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index ad81ec60a..0bf03f335 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -1674,25 +1674,55 @@ func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte + err error } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) + if batch.err != nil { + return + } batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { - batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) - batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) - batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) + if batch.err != nil { + return + } + + batch.buf, batch.err = (&pgproto3.Execute{}).Encode(batch.buf) + if batch.err != nil { + return + } } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. This is a simpler way of executing // multiple queries in a single round trip than using pipeline mode. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { + if batch.err != nil { + return &MultiResultReader{ + closed: true, + err: batch.err, + } + } + if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, @@ -1718,7 +1748,13 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR pgConn.contextWatcher.Watch(ctx) } - batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) + batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf) + if batch.err != nil { + multiResult.closed = true + multiResult.err = batch.err + pgConn.unlock() + return multiResult + } pgConn.enterPotentialWriteReadDeadlock() defer pgConn.exitPotentialWriteReadDeadlock() diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index f04fa79ad..b77d21c17 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -3363,9 +3363,9 @@ func TestSNISupport(t *testing.T) { return } - srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil)) - srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)) - srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))) + srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))) serverSNINameChan <- sniHost }() @@ -3472,3 +3472,10 @@ func TestFatalErrorReceivedInPipelineMode(t *testing.T) { err = pipeline.Close() require.Error(t, err) } + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} diff --git a/pgproto3/authentication_cleartext_password.go b/pgproto3/authentication_cleartext_password.go index d8f98b9af..ac2962e9e 100644 --- a/pgproto3/authentication_cleartext_password.go +++ b/pgproto3/authentication_cleartext_password.go @@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_gss.go b/pgproto3/authentication_gss.go index 0d234222f..178ef31d8 100644 --- a/pgproto3/authentication_gss.go +++ b/pgproto3/authentication_gss.go @@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error { return nil } -func (a *AuthenticationGSS) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 4) +func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeGSS) - return dst + return finishMessage(dst, sp) } func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/authentication_gss_continue.go b/pgproto3/authentication_gss_continue.go index 63789dc1a..2ba3f3b3e 100644 --- a/pgproto3/authentication_gss_continue.go +++ b/pgproto3/authentication_gss_continue.go @@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error { return nil } -func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) +func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeGSSCont) dst = append(dst, a.Data...) - return dst + return finishMessage(dst, sp) } func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { diff --git a/pgproto3/authentication_md5_password.go b/pgproto3/authentication_md5_password.go index 5671c84c5..854c6404e 100644 --- a/pgproto3/authentication_md5_password.go +++ b/pgproto3/authentication_md5_password.go @@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationMD5Password) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 12) +func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeMD5Password) dst = append(dst, src.Salt[:]...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_ok.go b/pgproto3/authentication_ok.go index 88d648ae7..ec11d39f1 100644 --- a/pgproto3/authentication_ok.go +++ b/pgproto3/authentication_ok.go @@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationOk) Encode(dst []byte) []byte { - dst = append(dst, 'R') - dst = pgio.AppendInt32(dst, 8) +func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeOk) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl.go b/pgproto3/authentication_sasl.go index 59650d4cd..e66580f44 100644 --- a/pgproto3/authentication_sasl.go +++ b/pgproto3/authentication_sasl.go @@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASL) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASL) for _, s := range src.AuthMechanisms { @@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl_continue.go b/pgproto3/authentication_sasl_continue.go index 2ce70a477..70fba4a67 100644 --- a/pgproto3/authentication_sasl_continue.go +++ b/pgproto3/authentication_sasl_continue.go @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLContinue) - dst = append(dst, src.Data...) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/authentication_sasl_final.go b/pgproto3/authentication_sasl_final.go index a38a8b912..84976c2a3 100644 --- a/pgproto3/authentication_sasl_final.go +++ b/pgproto3/authentication_sasl_final.go @@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte { - dst = append(dst, 'R') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'R') dst = pgio.AppendUint32(dst, AuthTypeSASLFinal) - dst = append(dst, src.Data...) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Unmarshaler. diff --git a/pgproto3/backend.go b/pgproto3/backend.go index efa909c3a..d146c3384 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -16,7 +16,8 @@ type Backend struct { // before it is actually transmitted (i.e. before Flush). tracer *tracer - wbuf []byte + wbuf []byte + encodeError error // Frontend message flyweights bind Bind @@ -55,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend { return &Backend{cr: cr, w: w} } -// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is -// called. +// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. func (b *Backend) Send(msg BackendMessage) { + if b.encodeError != nil { + return + } + prevLen := len(b.wbuf) - b.wbuf = msg.Encode(b.wbuf) + newBuf, err := msg.Encode(b.wbuf) + if err != nil { + b.encodeError = err + return + } + b.wbuf = newBuf + if b.tracer != nil { b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg) } @@ -67,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) { // Flush writes any pending messages to the frontend (i.e. the client). func (b *Backend) Flush() error { + if err := b.encodeError; err != nil { + b.encodeError = nil + b.wbuf = b.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + n, err := b.w.Write(b.wbuf) const maxLen = 1024 diff --git a/pgproto3/backend_key_data.go b/pgproto3/backend_key_data.go index 12c608170..23f5da677 100644 --- a/pgproto3/backend_key_data.go +++ b/pgproto3/backend_key_data.go @@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BackendKeyData) Encode(dst []byte) []byte { - dst = append(dst, 'K') - dst = pgio.AppendUint32(dst, 12) +func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'K') dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/backend_test.go b/pgproto3/backend_test.go index 5655122a8..5107ef76a 100644 --- a/pgproto3/backend_test.go +++ b/pgproto3/backend_test.go @@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) { "username": "tester", }, } - dst := []byte{} - dst = want.Encode(dst) + dst, err := want.Encode([]byte{}) + require.NoError(t, err) server := &interruptReader{} server.push(dst) diff --git a/pgproto3/bind.go b/pgproto3/bind.go index fdd2d3b81..b32cd81ca 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -108,10 +108,8 @@ func (dst *Bind) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Bind) Encode(dst []byte) []byte { - dst = append(dst, 'B') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Bind) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'B') dst = append(dst, src.DestinationPortal...) dst = append(dst, 0) @@ -139,9 +137,7 @@ func (src *Bind) Encode(dst []byte) []byte { dst = pgio.AppendInt16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/bind_complete.go b/pgproto3/bind_complete.go index 3be256c89..bacf30d88 100644 --- a/pgproto3/bind_complete.go +++ b/pgproto3/bind_complete.go @@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *BindComplete) Encode(dst []byte) []byte { - return append(dst, '2', 0, 0, 0, 4) +func (src *BindComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '2', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/bind_test.go b/pgproto3/bind_test.go new file mode 100644 index 000000000..6ec0e0245 --- /dev/null +++ b/pgproto3/bind_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) { + t.Parallel() + + // Maximum allowed size. + _, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil) + require.NoError(t, err) + + // 1 byte too big + _, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil) + require.Error(t, err) +} diff --git a/pgproto3/cancel_request.go b/pgproto3/cancel_request.go index 8fcf8217a..6b52dd977 100644 --- a/pgproto3/cancel_request.go +++ b/pgproto3/cancel_request.go @@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *CancelRequest) Encode(dst []byte) []byte { +func (src *CancelRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 16) dst = pgio.AppendInt32(dst, cancelRequestCode) dst = pgio.AppendUint32(dst, src.ProcessID) dst = pgio.AppendUint32(dst, src.SecretKey) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/close.go b/pgproto3/close.go index f99b59439..0b50f27cb 100644 --- a/pgproto3/close.go +++ b/pgproto3/close.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Close struct { @@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Close) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Close) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/close_complete.go b/pgproto3/close_complete.go index 1d7b8f085..833f7a12c 100644 --- a/pgproto3/close_complete.go +++ b/pgproto3/close_complete.go @@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CloseComplete) Encode(dst []byte) []byte { - return append(dst, '3', 0, 0, 0, 4) +func (src *CloseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '3', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/command_complete.go b/pgproto3/command_complete.go index 814027ca1..eba70947d 100644 --- a/pgproto3/command_complete.go +++ b/pgproto3/command_complete.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CommandComplete struct { @@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CommandComplete) Encode(dst []byte) []byte { - dst = append(dst, 'C') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CommandComplete) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'C') dst = append(dst, src.CommandTag...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index 8840a89ec..dbbd8e15c 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -44,19 +44,15 @@ func (dst *CopyBothResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyBothResponse) Encode(dst []byte) []byte { - dst = append(dst, 'W') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'W') dst = append(dst, src.OverallFormat) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_both_response_test.go b/pgproto3/copy_both_response_test.go index 4437de1da..1c988f21d 100644 --- a/pgproto3/copy_both_response_test.go +++ b/pgproto3/copy_both_response_test.go @@ -5,6 +5,7 @@ import ( "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestEncodeDecode(t *testing.T) { @@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) { err := dstResp.Decode(srcBytes[5:]) assert.NoError(t, err, "No errors on decode") dstBytes := []byte{} - dstBytes = dstResp.Encode(dstBytes) + dstBytes, err = dstResp.Encode(dstBytes) + require.NoError(t, err) assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match") } diff --git a/pgproto3/copy_data.go b/pgproto3/copy_data.go index 59e3dd942..89ecdd4dd 100644 --- a/pgproto3/copy_data.go +++ b/pgproto3/copy_data.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CopyData struct { @@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyData) Encode(dst []byte) []byte { - dst = append(dst, 'd') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) +func (src *CopyData) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'd') dst = append(dst, src.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_done.go b/pgproto3/copy_done.go index 0e13282bf..040814dbd 100644 --- a/pgproto3/copy_done.go +++ b/pgproto3/copy_done.go @@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyDone) Encode(dst []byte) []byte { - return append(dst, 'c', 0, 0, 0, 4) +func (src *CopyDone) Encode(dst []byte) ([]byte, error) { + return append(dst, 'c', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_fail.go b/pgproto3/copy_fail.go index 0041bbb1d..72a85fd09 100644 --- a/pgproto3/copy_fail.go +++ b/pgproto3/copy_fail.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type CopyFail struct { @@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyFail) Encode(dst []byte) []byte { - dst = append(dst, 'f') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *CopyFail) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'f') dst = append(dst, src.Message...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 4584f7df2..0a772afa0 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -44,10 +44,8 @@ func (dst *CopyInResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyInResponse) Encode(dst []byte) []byte { - dst = append(dst, 'G') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'G') dst = append(dst, src.OverallFormat) dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) @@ -55,9 +53,7 @@ func (src *CopyInResponse) Encode(dst []byte) []byte { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index 3175c6a40..40525da62 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -43,10 +43,8 @@ func (dst *CopyOutResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *CopyOutResponse) Encode(dst []byte) []byte { - dst = append(dst, 'H') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'H') dst = append(dst, src.OverallFormat) @@ -55,9 +53,7 @@ func (src *CopyOutResponse) Encode(dst []byte) []byte { dst = pgio.AppendUint16(dst, fc) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index 4de779772..cbc76dc24 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -63,10 +63,8 @@ func (dst *DataRow) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *DataRow) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *DataRow) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { @@ -79,9 +77,7 @@ func (src *DataRow) Encode(dst []byte) []byte { dst = append(dst, v...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/describe.go b/pgproto3/describe.go index f131d1f48..89feff215 100644 --- a/pgproto3/describe.go +++ b/pgproto3/describe.go @@ -4,8 +4,6 @@ import ( "bytes" "encoding/json" "errors" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Describe struct { @@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Describe) Encode(dst []byte) []byte { - dst = append(dst, 'D') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Describe) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'D') dst = append(dst, src.ObjectType) dst = append(dst, src.Name...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/empty_query_response.go b/pgproto3/empty_query_response.go index 2b85e744b..cb6cca073 100644 --- a/pgproto3/empty_query_response.go +++ b/pgproto3/empty_query_response.go @@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *EmptyQueryResponse) Encode(dst []byte) []byte { - return append(dst, 'I', 0, 0, 0, 4) +func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) { + return append(dst, 'I', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/error_response.go b/pgproto3/error_response.go index 45c9a9810..6ef9bd061 100644 --- a/pgproto3/error_response.go +++ b/pgproto3/error_response.go @@ -2,7 +2,6 @@ package pgproto3 import ( "bytes" - "encoding/binary" "encoding/json" "strconv" ) @@ -111,119 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ErrorResponse) Encode(dst []byte) []byte { - return append(dst, src.marshalBinary('E')...) +func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') + dst = src.appendFields(dst) + return finishMessage(dst, sp) } -func (src *ErrorResponse) marshalBinary(typeByte byte) []byte { - var bigEndian BigEndianBuf - buf := &bytes.Buffer{} - - buf.WriteByte(typeByte) - buf.Write(bigEndian.Uint32(0)) - +func (src *ErrorResponse) appendFields(dst []byte) []byte { if src.Severity != "" { - buf.WriteByte('S') - buf.WriteString(src.Severity) - buf.WriteByte(0) + dst = append(dst, 'S') + dst = append(dst, src.Severity...) + dst = append(dst, 0) } if src.SeverityUnlocalized != "" { - buf.WriteByte('V') - buf.WriteString(src.SeverityUnlocalized) - buf.WriteByte(0) + dst = append(dst, 'V') + dst = append(dst, src.SeverityUnlocalized...) + dst = append(dst, 0) } if src.Code != "" { - buf.WriteByte('C') - buf.WriteString(src.Code) - buf.WriteByte(0) + dst = append(dst, 'C') + dst = append(dst, src.Code...) + dst = append(dst, 0) } if src.Message != "" { - buf.WriteByte('M') - buf.WriteString(src.Message) - buf.WriteByte(0) + dst = append(dst, 'M') + dst = append(dst, src.Message...) + dst = append(dst, 0) } if src.Detail != "" { - buf.WriteByte('D') - buf.WriteString(src.Detail) - buf.WriteByte(0) + dst = append(dst, 'D') + dst = append(dst, src.Detail...) + dst = append(dst, 0) } if src.Hint != "" { - buf.WriteByte('H') - buf.WriteString(src.Hint) - buf.WriteByte(0) + dst = append(dst, 'H') + dst = append(dst, src.Hint...) + dst = append(dst, 0) } if src.Position != 0 { - buf.WriteByte('P') - buf.WriteString(strconv.Itoa(int(src.Position))) - buf.WriteByte(0) + dst = append(dst, 'P') + dst = append(dst, strconv.Itoa(int(src.Position))...) + dst = append(dst, 0) } if src.InternalPosition != 0 { - buf.WriteByte('p') - buf.WriteString(strconv.Itoa(int(src.InternalPosition))) - buf.WriteByte(0) + dst = append(dst, 'p') + dst = append(dst, strconv.Itoa(int(src.InternalPosition))...) + dst = append(dst, 0) } if src.InternalQuery != "" { - buf.WriteByte('q') - buf.WriteString(src.InternalQuery) - buf.WriteByte(0) + dst = append(dst, 'q') + dst = append(dst, src.InternalQuery...) + dst = append(dst, 0) } if src.Where != "" { - buf.WriteByte('W') - buf.WriteString(src.Where) - buf.WriteByte(0) + dst = append(dst, 'W') + dst = append(dst, src.Where...) + dst = append(dst, 0) } if src.SchemaName != "" { - buf.WriteByte('s') - buf.WriteString(src.SchemaName) - buf.WriteByte(0) + dst = append(dst, 's') + dst = append(dst, src.SchemaName...) + dst = append(dst, 0) } if src.TableName != "" { - buf.WriteByte('t') - buf.WriteString(src.TableName) - buf.WriteByte(0) + dst = append(dst, 't') + dst = append(dst, src.TableName...) + dst = append(dst, 0) } if src.ColumnName != "" { - buf.WriteByte('c') - buf.WriteString(src.ColumnName) - buf.WriteByte(0) + dst = append(dst, 'c') + dst = append(dst, src.ColumnName...) + dst = append(dst, 0) } if src.DataTypeName != "" { - buf.WriteByte('d') - buf.WriteString(src.DataTypeName) - buf.WriteByte(0) + dst = append(dst, 'd') + dst = append(dst, src.DataTypeName...) + dst = append(dst, 0) } if src.ConstraintName != "" { - buf.WriteByte('n') - buf.WriteString(src.ConstraintName) - buf.WriteByte(0) + dst = append(dst, 'n') + dst = append(dst, src.ConstraintName...) + dst = append(dst, 0) } if src.File != "" { - buf.WriteByte('F') - buf.WriteString(src.File) - buf.WriteByte(0) + dst = append(dst, 'F') + dst = append(dst, src.File...) + dst = append(dst, 0) } if src.Line != 0 { - buf.WriteByte('L') - buf.WriteString(strconv.Itoa(int(src.Line))) - buf.WriteByte(0) + dst = append(dst, 'L') + dst = append(dst, strconv.Itoa(int(src.Line))...) + dst = append(dst, 0) } if src.Routine != "" { - buf.WriteByte('R') - buf.WriteString(src.Routine) - buf.WriteByte(0) + dst = append(dst, 'R') + dst = append(dst, src.Routine...) + dst = append(dst, 0) } for k, v := range src.UnknownFields { - buf.WriteByte(k) - buf.WriteString(v) - buf.WriteByte(0) + dst = append(dst, k) + dst = append(dst, v...) + dst = append(dst, 0) } - buf.WriteByte(0) - - binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1)) + dst = append(dst, 0) - return buf.Bytes() + return dst } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/example/pgfortune/server.go b/pgproto3/example/pgfortune/server.go index 14ae71f83..06a45dda0 100644 --- a/pgproto3/example/pgfortune/server.go +++ b/pgproto3/example/pgfortune/server.go @@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error { return fmt.Errorf("error generating query response: %w", err) } - buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ + buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{ { Name: []byte("fortune"), TableOID: 0, @@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error { TypeModifier: -1, Format: 0, }, - }}).Encode(nil) - buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf) - buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf) - buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + }}).Encode(nil)) + buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)) + buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) _, err = p.conn.Write(buf) if err != nil { return fmt.Errorf("error writing query response: %w", err) @@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error { switch startupMessage.(type) { case *pgproto3.StartupMessage: - buf := (&pgproto3.AuthenticationOk{}).Encode(nil) - buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf) + buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)) + buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)) _, err = p.conn.Write(buf) if err != nil { return fmt.Errorf("error sending ready for query: %w", err) @@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error { func (p *PgFortuneBackend) Close() error { return p.conn.Close() } + +func mustEncode(buf []byte, err error) []byte { + if err != nil { + panic(err) + } + return buf +} diff --git a/pgproto3/execute.go b/pgproto3/execute.go index a5fee7cb9..31bc714d1 100644 --- a/pgproto3/execute.go +++ b/pgproto3/execute.go @@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Execute) Encode(dst []byte) []byte { - dst = append(dst, 'E') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *Execute) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'E') dst = append(dst, src.Portal...) dst = append(dst, 0) - dst = pgio.AppendUint32(dst, src.MaxRows) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/flush.go b/pgproto3/flush.go index 2725f6894..e5dc1fbbd 100644 --- a/pgproto3/flush.go +++ b/pgproto3/flush.go @@ -20,8 +20,8 @@ func (dst *Flush) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Flush) Encode(dst []byte) []byte { - return append(dst, 'H', 0, 0, 0, 4) +func (src *Flush) Encode(dst []byte) ([]byte, error) { + return append(dst, 'H', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 60c34ef02..b41abbe10 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -18,7 +18,8 @@ type Frontend struct { // idle. Setting and unsetting tracer provides equivalent functionality to PQtrace and PQuntrace in libpq. tracer *tracer - wbuf []byte + wbuf []byte + encodeError error // Backend message flyweights authenticationOk AuthenticationOk @@ -64,16 +65,26 @@ func NewFrontend(r io.Reader, w io.Writer) *Frontend { return &Frontend{cr: cr, w: w} } -// Send sends a message to the backend (i.e. the server). The message is not guaranteed to be written until Flush is -// called. +// Send sends a message to the backend (i.e. the server). The message is buffered until Flush is called. Any error +// encountered will be returned from Flush. // // Send can work with any FrontendMessage. Some commonly used message types such as Bind have specialized send methods // such as SendBind. These methods should be preferred when the type of message is known up front (e.g. when building an // extended query protocol query) as they may be faster due to knowing the type of msg rather than it being hidden // behind an interface. func (f *Frontend) Send(msg FrontendMessage) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceMessage('F', int32(len(f.wbuf)-prevLen), msg) } @@ -81,6 +92,12 @@ func (f *Frontend) Send(msg FrontendMessage) { // Flush writes any pending messages to the backend (i.e. the server). func (f *Frontend) Flush() error { + if err := f.encodeError; err != nil { + f.encodeError = nil + f.wbuf = f.wbuf[:0] + return &writeError{err: err, safeToRetry: true} + } + if len(f.wbuf) == 0 { return nil } @@ -116,71 +133,141 @@ func (f *Frontend) Untrace() { f.tracer = nil } -// SendBind sends a Bind message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendBind sends a Bind message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendBind(msg *Bind) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceBind('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendParse sends a Parse message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendParse sends a Parse message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendParse(msg *Parse) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceParse('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendClose sends a Close message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendClose sends a Close message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendClose(msg *Close) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceClose('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendDescribe sends a Describe message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendDescribe sends a Describe message to the backend (i.e. the server). The message is buffered until Flush is +// called. Any error encountered will be returned from Flush. func (f *Frontend) SendDescribe(msg *Describe) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceDescribe('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendExecute sends an Execute message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendExecute sends an Execute message to the backend (i.e. the server). The message is buffered until Flush is called. +// Any error encountered will be returned from Flush. func (f *Frontend) SendExecute(msg *Execute) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.TraceQueryute('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendSync sends a Sync message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendSync sends a Sync message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendSync(msg *Sync) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceSync('F', int32(len(f.wbuf)-prevLen), msg) } } -// SendQuery sends a Query message to the backend (i.e. the server). The message is not guaranteed to be written until -// Flush is called. +// SendQuery sends a Query message to the backend (i.e. the server). The message is buffered until Flush is called. Any +// error encountered will be returned from Flush. func (f *Frontend) SendQuery(msg *Query) { + if f.encodeError != nil { + return + } + prevLen := len(f.wbuf) - f.wbuf = msg.Encode(f.wbuf) + newBuf, err := msg.Encode(f.wbuf) + if err != nil { + f.encodeError = err + return + } + f.wbuf = newBuf + if f.tracer != nil { f.tracer.traceQuery('F', int32(len(f.wbuf)-prevLen), msg) } diff --git a/pgproto3/function_call.go b/pgproto3/function_call.go index 2c4f38dfd..0b15fce23 100644 --- a/pgproto3/function_call.go +++ b/pgproto3/function_call.go @@ -71,10 +71,8 @@ func (dst *FunctionCall) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *FunctionCall) Encode(dst []byte) []byte { - dst = append(dst, 'F') - sp := len(dst) - dst = pgio.AppendUint32(dst, 0) // Unknown length, set it at the end +func (src *FunctionCall) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'F') dst = pgio.AppendUint32(dst, src.Function) dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { @@ -90,6 +88,5 @@ func (src *FunctionCall) Encode(dst []byte) []byte { } } dst = pgio.AppendUint16(dst, src.ResultFormatCode) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - return dst + return finishMessage(dst, sp) } diff --git a/pgproto3/function_call_response.go b/pgproto3/function_call_response.go index 3d3606ddb..1f2734952 100644 --- a/pgproto3/function_call_response.go +++ b/pgproto3/function_call_response.go @@ -39,10 +39,8 @@ func (dst *FunctionCallResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *FunctionCallResponse) Encode(dst []byte) []byte { - dst = append(dst, 'V') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *FunctionCallResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'V') if src.Result == nil { dst = pgio.AppendInt32(dst, -1) @@ -51,9 +49,7 @@ func (src *FunctionCallResponse) Encode(dst []byte) []byte { dst = append(dst, src.Result...) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/function_call_test.go b/pgproto3/function_call_test.go index 8c08bb240..2a70fd308 100644 --- a/pgproto3/function_call_test.go +++ b/pgproto3/function_call_test.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestFunctionCall_EncodeDecode(t *testing.T) { @@ -30,7 +32,8 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { Arguments: tt.fields.Arguments, ResultFormatCode: tt.fields.ResultFormatCode, } - encoded := src.Encode([]byte{}) + encoded, err := src.Encode([]byte{}) + require.NoError(t, err) dst := &FunctionCall{} // Check the header msgTypeCode := encoded[0] @@ -44,7 +47,7 @@ func TestFunctionCall_EncodeDecode(t *testing.T) { t.Errorf("Incorrect message length, got = %v, wanted = %v", l, len(encoded)) } // Check decoding works as expected - err := dst.Decode(encoded[5:]) + err = dst.Decode(encoded[5:]) if err != nil { if !tt.wantErr { t.Errorf("FunctionCall.Decode() error = %v, wantErr %v", err, tt.wantErr) diff --git a/pgproto3/gss_enc_request.go b/pgproto3/gss_enc_request.go index 30ffc08d2..70cb20cd5 100644 --- a/pgproto3/gss_enc_request.go +++ b/pgproto3/gss_enc_request.go @@ -31,10 +31,10 @@ func (dst *GSSEncRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *GSSEncRequest) Encode(dst []byte) []byte { +func (src *GSSEncRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, gssEncReqNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/gss_response.go b/pgproto3/gss_response.go index 64bfbd049..10d937759 100644 --- a/pgproto3/gss_response.go +++ b/pgproto3/gss_response.go @@ -2,8 +2,6 @@ package pgproto3 import ( "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type GSSResponse struct { @@ -18,11 +16,10 @@ func (g *GSSResponse) Decode(data []byte) error { return nil } -func (g *GSSResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) +func (g *GSSResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, g.Data...) - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/no_data.go b/pgproto3/no_data.go index d8f85d38a..cbcaad40c 100644 --- a/pgproto3/no_data.go +++ b/pgproto3/no_data.go @@ -20,8 +20,8 @@ func (dst *NoData) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoData) Encode(dst []byte) []byte { - return append(dst, 'n', 0, 0, 0, 4) +func (src *NoData) Encode(dst []byte) ([]byte, error) { + return append(dst, 'n', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/notice_response.go b/pgproto3/notice_response.go index 4ac28a791..497aba6dd 100644 --- a/pgproto3/notice_response.go +++ b/pgproto3/notice_response.go @@ -12,6 +12,8 @@ func (dst *NoticeResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NoticeResponse) Encode(dst []byte) []byte { - return append(dst, (*ErrorResponse)(src).marshalBinary('N')...) +func (src *NoticeResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'N') + dst = (*ErrorResponse)(src).appendFields(dst) + return finishMessage(dst, sp) } diff --git a/pgproto3/notification_response.go b/pgproto3/notification_response.go index 228e0dac3..243b6bf7c 100644 --- a/pgproto3/notification_response.go +++ b/pgproto3/notification_response.go @@ -45,20 +45,14 @@ func (dst *NotificationResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *NotificationResponse) Encode(dst []byte) []byte { - dst = append(dst, 'A') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *NotificationResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'A') dst = pgio.AppendUint32(dst, src.PID) dst = append(dst, src.Channel...) dst = append(dst, 0) dst = append(dst, src.Payload...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index 374d38a39..685e04b8c 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -39,19 +39,15 @@ func (dst *ParameterDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterDescription) Encode(dst []byte) []byte { - dst = append(dst, 't') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 't') dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parameter_status.go b/pgproto3/parameter_status.go index a303e4536..9ee0720b5 100644 --- a/pgproto3/parameter_status.go +++ b/pgproto3/parameter_status.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type ParameterStatus struct { @@ -37,19 +35,13 @@ func (dst *ParameterStatus) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParameterStatus) Encode(dst []byte) []byte { - dst = append(dst, 'S') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) - +func (src *ParameterStatus) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'S') dst = append(dst, src.Name...) dst = append(dst, 0) dst = append(dst, src.Value...) dst = append(dst, 0) - - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parse.go b/pgproto3/parse.go index b53200dca..a59154cda 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -52,10 +52,8 @@ func (dst *Parse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Parse) Encode(dst []byte) []byte { - dst = append(dst, 'P') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *Parse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'P') dst = append(dst, src.Name...) dst = append(dst, 0) @@ -67,9 +65,7 @@ func (src *Parse) Encode(dst []byte) []byte { dst = pgio.AppendUint32(dst, oid) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/parse_complete.go b/pgproto3/parse_complete.go index 92c9498b6..cff9e27d0 100644 --- a/pgproto3/parse_complete.go +++ b/pgproto3/parse_complete.go @@ -20,8 +20,8 @@ func (dst *ParseComplete) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ParseComplete) Encode(dst []byte) []byte { - return append(dst, '1', 0, 0, 0, 4) +func (src *ParseComplete) Encode(dst []byte) ([]byte, error) { + return append(dst, '1', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/password_message.go b/pgproto3/password_message.go index 41f98692b..d820d3275 100644 --- a/pgproto3/password_message.go +++ b/pgproto3/password_message.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type PasswordMessage struct { @@ -32,14 +30,11 @@ func (dst *PasswordMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PasswordMessage) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Password)+1)) - +func (src *PasswordMessage) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Password...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/pgproto3.go b/pgproto3/pgproto3.go index 8df383c2c..480abfc06 100644 --- a/pgproto3/pgproto3.go +++ b/pgproto3/pgproto3.go @@ -4,8 +4,14 @@ import ( "encoding/hex" "errors" "fmt" + + "github.com/jackc/pgx/v5/internal/pgio" ) +// maxMessageBodyLen is the maximum length of a message body in bytes. See PG_LARGE_MESSAGE_LIMIT in the PostgreSQL +// source. It is defined as (MaxAllocSize - 1). MaxAllocSize is defined as 0x3fffffff. +const maxMessageBodyLen = (0x3fffffff - 1) + // Message is the interface implemented by an object that can decode and encode // a particular PostgreSQL message. type Message interface { @@ -14,7 +20,7 @@ type Message interface { Decode(data []byte) error // Encode appends itself to dst and returns the new buffer. - Encode(dst []byte) []byte + Encode(dst []byte) ([]byte, error) } // FrontendMessage is a message sent by the frontend (i.e. the client). @@ -92,3 +98,23 @@ func getValueFromJSON(v map[string]string) ([]byte, error) { } return nil, errors.New("unknown protocol representation") } + +// beginMessage begines a new message of type t. It appends the message type and a placeholder for the message length to +// dst. It returns the new buffer and the position of the message length placeholder. +func beginMessage(dst []byte, t byte) ([]byte, int) { + dst = append(dst, t) + sp := len(dst) + dst = pgio.AppendInt32(dst, -1) + return dst, sp +} + +// finishMessage finishes a message that was started with beginMessage. It computes the message length and writes it to +// dst[sp]. If the message length is too large it returns an error. Otherwise it returns the final message buffer. +func finishMessage(dst []byte, sp int) ([]byte, error) { + messageBodyLen := len(dst[sp:]) + if messageBodyLen > maxMessageBodyLen { + return nil, errors.New("message body too large") + } + pgio.SetInt32(dst[sp:], int32(messageBodyLen)) + return dst, nil +} diff --git a/pgproto3/pgproto3_private_test.go b/pgproto3/pgproto3_private_test.go new file mode 100644 index 000000000..15da1eafb --- /dev/null +++ b/pgproto3/pgproto3_private_test.go @@ -0,0 +1,3 @@ +package pgproto3 + +const MaxMessageBodyLen = maxMessageBodyLen diff --git a/pgproto3/portal_suspended.go b/pgproto3/portal_suspended.go index 1a9e7bfb1..9e2f8cbc4 100644 --- a/pgproto3/portal_suspended.go +++ b/pgproto3/portal_suspended.go @@ -20,8 +20,8 @@ func (dst *PortalSuspended) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *PortalSuspended) Encode(dst []byte) []byte { - return append(dst, 's', 0, 0, 0, 4) +func (src *PortalSuspended) Encode(dst []byte) ([]byte, error) { + return append(dst, 's', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/query.go b/pgproto3/query.go index e963a0ece..aebdfde89 100644 --- a/pgproto3/query.go +++ b/pgproto3/query.go @@ -3,8 +3,6 @@ package pgproto3 import ( "bytes" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type Query struct { @@ -28,14 +26,11 @@ func (dst *Query) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Query) Encode(dst []byte) []byte { - dst = append(dst, 'Q') - dst = pgio.AppendInt32(dst, int32(4+len(src.String)+1)) - +func (src *Query) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'Q') dst = append(dst, src.String...) dst = append(dst, 0) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/query_test.go b/pgproto3/query_test.go new file mode 100644 index 000000000..9551fc14d --- /dev/null +++ b/pgproto3/query_test.go @@ -0,0 +1,20 @@ +package pgproto3_test + +import ( + "testing" + + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/require" +) + +func TestQueryBiggerThanMaxMessageBodyLen(t *testing.T) { + t.Parallel() + + // Maximum allowed size. 4 bytes for size and 1 byte for 0 terminated string. + _, err := (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-5))}).Encode(nil) + require.NoError(t, err) + + // 1 byte too big + _, err = (&pgproto3.Query{String: string(make([]byte, pgproto3.MaxMessageBodyLen-4))}).Encode(nil) + require.Error(t, err) +} diff --git a/pgproto3/ready_for_query.go b/pgproto3/ready_for_query.go index 67a39be39..a56af9fb2 100644 --- a/pgproto3/ready_for_query.go +++ b/pgproto3/ready_for_query.go @@ -25,8 +25,8 @@ func (dst *ReadyForQuery) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *ReadyForQuery) Encode(dst []byte) []byte { - return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus) +func (src *ReadyForQuery) Encode(dst []byte) ([]byte, error) { + return append(dst, 'Z', 0, 0, 0, 5, src.TxStatus), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index 6f6f06817..c68f1d466 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -99,10 +99,8 @@ func (dst *RowDescription) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *RowDescription) Encode(dst []byte) []byte { - dst = append(dst, 'T') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *RowDescription) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'T') dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { @@ -117,9 +115,7 @@ func (src *RowDescription) Encode(dst []byte) []byte { dst = pgio.AppendInt16(dst, fd.Format) } - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/sasl_initial_response.go b/pgproto3/sasl_initial_response.go index eeda4691a..9eb1b6a4b 100644 --- a/pgproto3/sasl_initial_response.go +++ b/pgproto3/sasl_initial_response.go @@ -39,10 +39,8 @@ func (dst *SASLInitialResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLInitialResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - sp := len(dst) - dst = pgio.AppendInt32(dst, -1) +func (src *SASLInitialResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, []byte(src.AuthMechanism)...) dst = append(dst, 0) @@ -50,9 +48,7 @@ func (src *SASLInitialResponse) Encode(dst []byte) []byte { dst = pgio.AppendInt32(dst, int32(len(src.Data))) dst = append(dst, src.Data...) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/sasl_response.go b/pgproto3/sasl_response.go index 54c3d96f3..1b604c254 100644 --- a/pgproto3/sasl_response.go +++ b/pgproto3/sasl_response.go @@ -3,8 +3,6 @@ package pgproto3 import ( "encoding/hex" "encoding/json" - - "github.com/jackc/pgx/v5/internal/pgio" ) type SASLResponse struct { @@ -22,13 +20,10 @@ func (dst *SASLResponse) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *SASLResponse) Encode(dst []byte) []byte { - dst = append(dst, 'p') - dst = pgio.AppendInt32(dst, int32(4+len(src.Data))) - +func (src *SASLResponse) Encode(dst []byte) ([]byte, error) { + dst, sp := beginMessage(dst, 'p') dst = append(dst, src.Data...) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/ssl_request.go b/pgproto3/ssl_request.go index 1b00c16b3..b0fc28476 100644 --- a/pgproto3/ssl_request.go +++ b/pgproto3/ssl_request.go @@ -31,10 +31,10 @@ func (dst *SSLRequest) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 4 byte message length. -func (src *SSLRequest) Encode(dst []byte) []byte { +func (src *SSLRequest) Encode(dst []byte) ([]byte, error) { dst = pgio.AppendInt32(dst, 8) dst = pgio.AppendInt32(dst, sslRequestNumber) - return dst + return dst, nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/startup_message.go b/pgproto3/startup_message.go index 65de4a360..3af4587d8 100644 --- a/pgproto3/startup_message.go +++ b/pgproto3/startup_message.go @@ -64,7 +64,7 @@ func (dst *StartupMessage) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *StartupMessage) Encode(dst []byte) []byte { +func (src *StartupMessage) Encode(dst []byte) ([]byte, error) { sp := len(dst) dst = pgio.AppendInt32(dst, -1) @@ -77,9 +77,7 @@ func (src *StartupMessage) Encode(dst []byte) []byte { } dst = append(dst, 0) - pgio.SetInt32(dst[sp:], int32(len(dst[sp:]))) - - return dst + return finishMessage(dst, sp) } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/sync.go b/pgproto3/sync.go index 5db8e07ac..ea4fc9594 100644 --- a/pgproto3/sync.go +++ b/pgproto3/sync.go @@ -20,8 +20,8 @@ func (dst *Sync) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Sync) Encode(dst []byte) []byte { - return append(dst, 'S', 0, 0, 0, 4) +func (src *Sync) Encode(dst []byte) ([]byte, error) { + return append(dst, 'S', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler. diff --git a/pgproto3/terminate.go b/pgproto3/terminate.go index 135191eae..35a9dc837 100644 --- a/pgproto3/terminate.go +++ b/pgproto3/terminate.go @@ -20,8 +20,8 @@ func (dst *Terminate) Decode(src []byte) error { } // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length. -func (src *Terminate) Encode(dst []byte) []byte { - return append(dst, 'X', 0, 0, 0, 4) +func (src *Terminate) Encode(dst []byte) ([]byte, error) { + return append(dst, 'X', 0, 0, 0, 4), nil } // MarshalJSON implements encoding/json.Marshaler.