diff --git a/pgproto3/bind.go b/pgproto3/bind.go index b32cd81ca..ad6ac48bf 100644 --- a/pgproto3/bind.go +++ b/pgproto3/bind.go @@ -5,7 +5,9 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" "fmt" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -116,11 +118,17 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) { dst = append(dst, src.PreparedStatement...) dst = append(dst, 0) + if len(src.ParameterFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many parameter format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes))) for _, fc := range src.ParameterFormatCodes { dst = pgio.AppendInt16(dst, fc) } + if len(src.Parameters) > math.MaxUint16 { + return nil, errors.New("too many parameters") + } dst = pgio.AppendUint16(dst, uint16(len(src.Parameters))) for _, p := range src.Parameters { if p == nil { @@ -132,6 +140,9 @@ func (src *Bind) Encode(dst []byte) ([]byte, error) { dst = append(dst, p...) } + if len(src.ResultFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many result format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes))) for _, fc := range src.ResultFormatCodes { dst = pgio.AppendInt16(dst, fc) diff --git a/pgproto3/copy_both_response.go b/pgproto3/copy_both_response.go index dbbd8e15c..99e1afea4 100644 --- a/pgproto3/copy_both_response.go +++ b/pgproto3/copy_both_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -47,6 +48,9 @@ func (dst *CopyBothResponse) Decode(src []byte) error { func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'W') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) diff --git a/pgproto3/copy_in_response.go b/pgproto3/copy_in_response.go index 0a772afa0..06cf99ced 100644 --- a/pgproto3/copy_in_response.go +++ b/pgproto3/copy_in_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -48,6 +49,9 @@ func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'G') dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) diff --git a/pgproto3/copy_out_response.go b/pgproto3/copy_out_response.go index 40525da62..549e916c1 100644 --- a/pgproto3/copy_out_response.go +++ b/pgproto3/copy_out_response.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "encoding/json" "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -48,6 +49,9 @@ func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) { dst = append(dst, src.OverallFormat) + if len(src.ColumnFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many column format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes))) for _, fc := range src.ColumnFormatCodes { dst = pgio.AppendUint16(dst, fc) diff --git a/pgproto3/data_row.go b/pgproto3/data_row.go index cbc76dc24..fdfb0f7f6 100644 --- a/pgproto3/data_row.go +++ b/pgproto3/data_row.go @@ -4,6 +4,8 @@ import ( "encoding/binary" "encoding/hex" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -66,6 +68,9 @@ func (dst *DataRow) Decode(src []byte) error { func (src *DataRow) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'D') + if len(src.Values) > math.MaxUint16 { + return nil, errors.New("too many values") + } dst = pgio.AppendUint16(dst, uint16(len(src.Values))) for _, v := range src.Values { if v == nil { diff --git a/pgproto3/function_call.go b/pgproto3/function_call.go index 0b15fce23..7d83579ff 100644 --- a/pgproto3/function_call.go +++ b/pgproto3/function_call.go @@ -2,6 +2,8 @@ package pgproto3 import ( "encoding/binary" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -74,10 +76,18 @@ func (dst *FunctionCall) Decode(src []byte) error { func (src *FunctionCall) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'F') dst = pgio.AppendUint32(dst, src.Function) + + if len(src.ArgFormatCodes) > math.MaxUint16 { + return nil, errors.New("too many arg format codes") + } dst = pgio.AppendUint16(dst, uint16(len(src.ArgFormatCodes))) for _, argFormatCode := range src.ArgFormatCodes { dst = pgio.AppendUint16(dst, argFormatCode) } + + if len(src.Arguments) > math.MaxUint16 { + return nil, errors.New("too many arguments") + } dst = pgio.AppendUint16(dst, uint16(len(src.Arguments))) for _, argument := range src.Arguments { if argument == nil { diff --git a/pgproto3/parameter_description.go b/pgproto3/parameter_description.go index 685e04b8c..1ef27b75f 100644 --- a/pgproto3/parameter_description.go +++ b/pgproto3/parameter_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -42,6 +44,9 @@ func (dst *ParameterDescription) Decode(src []byte) error { func (src *ParameterDescription) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 't') + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) diff --git a/pgproto3/parse.go b/pgproto3/parse.go index a59154cda..6ba3486cf 100644 --- a/pgproto3/parse.go +++ b/pgproto3/parse.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -60,6 +62,9 @@ func (src *Parse) Encode(dst []byte) ([]byte, error) { dst = append(dst, src.Query...) dst = append(dst, 0) + if len(src.ParameterOIDs) > math.MaxUint16 { + return nil, errors.New("too many parameter oids") + } dst = pgio.AppendUint16(dst, uint16(len(src.ParameterOIDs))) for _, oid := range src.ParameterOIDs { dst = pgio.AppendUint32(dst, oid) diff --git a/pgproto3/row_description.go b/pgproto3/row_description.go index c68f1d466..dc2a4ddf2 100644 --- a/pgproto3/row_description.go +++ b/pgproto3/row_description.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "encoding/json" + "errors" + "math" "github.com/jackc/pgx/v5/internal/pgio" ) @@ -102,6 +104,9 @@ func (dst *RowDescription) Decode(src []byte) error { func (src *RowDescription) Encode(dst []byte) ([]byte, error) { dst, sp := beginMessage(dst, 'T') + if len(src.Fields) > math.MaxUint16 { + return nil, errors.New("too many fields") + } dst = pgio.AppendUint16(dst, uint16(len(src.Fields))) for _, fd := range src.Fields { dst = append(dst, fd.Name...)