diff --git a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go index 5f4899c4cae8..4b644ef4bfbf 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go +++ b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go @@ -18,14 +18,13 @@ import ( // BackendInterceptor is a server int/erceptor for the Postgres backend protocol. type BackendInterceptor pgInterceptor -// NewBackendInterceptor creates a BackendInterceptor. bufSize must be at least -// the size of a pgwire message header. -func NewBackendInterceptor(src io.Reader, bufSize int) (*BackendInterceptor, error) { - pgi, err := newPgInterceptor(src, bufSize) - if err != nil { - return nil, err - } - return (*BackendInterceptor)(pgi), nil +// NewBackendInterceptor creates a BackendInterceptor. If bufSize is smaller +// than 5 bytes, the defaults (8K) will be used. +// +// NOTE: For future improvement, we can use the options pattern here if there's +// a need for more than one field. +func NewBackendInterceptor(src io.Reader, bufSize int) *BackendInterceptor { + return (*BackendInterceptor)(newPgInterceptor(src, bufSize)) } // PeekMsg returns the header of the current pgwire message without advancing diff --git a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go index c06350a48c91..1403f472e707 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go @@ -35,17 +35,15 @@ func TestBackendInterceptor(t *testing.T) { return src } - t.Run("bufSize too small", func(t *testing.T) { - bi, err := interceptor.NewBackendInterceptor(nil /* src */, 1 /* bufSize */) - require.Error(t, err) - require.Nil(t, bi) + t.Run("small bufSize", func(t *testing.T) { + bi := interceptor.NewBackendInterceptor(nil /* src */, 1 /* bufSize */) + require.NotNil(t, bi) }) t.Run("PeekMsg returns the right message type", func(t *testing.T) { src := buildSrc(t, 1) - bi, err := interceptor.NewBackendInterceptor(src, 16 /* bufSize */) - require.NoError(t, err) + bi := interceptor.NewBackendInterceptor(src, 16 /* bufSize */) require.NotNil(t, bi) typ, size, err := bi.PeekMsg() @@ -56,8 +54,7 @@ func TestBackendInterceptor(t *testing.T) { t.Run("WriteMsg writes data to dst", func(t *testing.T) { dst := new(bytes.Buffer) - bi, err := interceptor.NewBackendInterceptor(nil /* src */, 10 /* bufSize */) - require.NoError(t, err) + bi := interceptor.NewBackendInterceptor(nil /* src */, 10 /* bufSize */) require.NotNil(t, bi) // This is a backend interceptor, so writing goes to the server. @@ -71,8 +68,7 @@ func TestBackendInterceptor(t *testing.T) { t.Run("ReadMsg decodes the message correctly", func(t *testing.T) { src := buildSrc(t, 1) - bi, err := interceptor.NewBackendInterceptor(src, 16 /* bufSize */) - require.NoError(t, err) + bi := interceptor.NewBackendInterceptor(src, 16 /* bufSize */) require.NotNil(t, bi) msg, err := bi.ReadMsg() @@ -86,8 +82,7 @@ func TestBackendInterceptor(t *testing.T) { src := buildSrc(t, 1) dst := new(bytes.Buffer) - bi, err := interceptor.NewBackendInterceptor(src, 16 /* bufSize */) - require.NoError(t, err) + bi := interceptor.NewBackendInterceptor(src, 16 /* bufSize */) require.NotNil(t, bi) n, err := bi.ForwardMsg(dst) diff --git a/pkg/ccl/sqlproxyccl/interceptor/base.go b/pkg/ccl/sqlproxyccl/interceptor/base.go index 04b92c3069c4..417b383a662e 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/base.go +++ b/pkg/ccl/sqlproxyccl/interceptor/base.go @@ -22,6 +22,12 @@ import ( // length itself). const pgHeaderSizeBytes = 5 +// defaultBufferSize is the default buffer size for the interceptor. 8K was +// chosen to match Postgres' send and receive buffer sizes. +// +// See: https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134-L135. +const defaultBufferSize = 2 << 13 // 8K + // ErrSmallBuffer indicates that the requested buffer for the interceptor is // too small. var ErrSmallBuffer = errors.New("buffer is too small") @@ -58,17 +64,18 @@ type pgInterceptor struct { } // newPgInterceptor creates a new instance of the interceptor with an internal -// buffer of bufSize bytes. bufSize must be at least the size of a pgwire -// message header. -func newPgInterceptor(src io.Reader, bufSize int) (*pgInterceptor, error) { - // The internal buffer must be able to fit the header. +// buffer of bufSize bytes. If bufSize is smaller than 5 bytes, the interceptor +// will default to an 8K buffer size. +func newPgInterceptor(src io.Reader, bufSize int) *pgInterceptor { + // The internal buffer must be able to fit the header. If bufSize is smaller + // than 5 bytes, just default to 8K, or else the interceptor is unusable. if bufSize < pgHeaderSizeBytes { - return nil, ErrSmallBuffer + bufSize = defaultBufferSize } return &pgInterceptor{ src: src, buf: make([]byte, bufSize), - }, nil + } } // PeekMsg returns the header of the current pgwire message without advancing diff --git a/pkg/ccl/sqlproxyccl/interceptor/base_test.go b/pkg/ccl/sqlproxyccl/interceptor/base_test.go index 21c47a72b9c3..aca49e46f97b 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/base_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/base_test.go @@ -29,28 +29,20 @@ func TestNewPgInterceptor(t *testing.T) { reader, _ := io.Pipe() - // Negative buffer size. - pgi, err := newPgInterceptor(reader, -1 /* bufSize */) - require.EqualError(t, err, ErrSmallBuffer.Error()) - require.Nil(t, pgi) - - // Small buffer size. - pgi, err = newPgInterceptor(reader, pgHeaderSizeBytes-1) - require.EqualError(t, err, ErrSmallBuffer.Error()) - require.Nil(t, pgi) - - // Buffer that fits the header exactly. - pgi, err = newPgInterceptor(reader, pgHeaderSizeBytes) - require.NoError(t, err) - require.NotNil(t, pgi) - require.Len(t, pgi.buf, pgHeaderSizeBytes) - - // Normal buffer size. - pgi, err = newPgInterceptor(reader, 1024 /* bufSize */) - require.NoError(t, err) - require.NotNil(t, pgi) - require.Len(t, pgi.buf, 1024) - require.Equal(t, reader, pgi.src) + for _, tc := range []struct { + bufSize int + normalizedBufSize int + }{ + {-1, defaultBufferSize}, + {pgHeaderSizeBytes - 1, defaultBufferSize}, + {pgHeaderSizeBytes, pgHeaderSizeBytes}, + {1024, 1024}, + } { + pgi := newPgInterceptor(reader, tc.bufSize) + require.NotNil(t, pgi) + require.Len(t, pgi.buf, tc.normalizedBufSize) + require.Equal(t, reader, pgi.src) + } } func TestPGInterceptor_PeekMsg(t *testing.T) { @@ -59,8 +51,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { t.Run("read_error", func(t *testing.T) { r := iotest.ErrReader(errors.New("read error")) - pgi, err := newPgInterceptor(r, 10 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(r, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.EqualError(t, err, "read error") @@ -75,8 +66,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { _, err := buf.Write(data[:]) require.NoError(t, err) - pgi, err := newPgInterceptor(buf, 10 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.EqualError(t, err, ErrProtocolError.Error()) @@ -92,8 +82,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { _, err := buf.Write(data[:]) require.NoError(t, err) - pgi, err := newPgInterceptor(buf, 10 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.EqualError(t, err, ErrProtocolError.Error()) @@ -109,8 +98,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { _, err := buf.Write(data[:]) require.NoError(t, err) - pgi, err := newPgInterceptor(buf, 10 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.EqualError(t, err, ErrProtocolError.Error()) @@ -128,8 +116,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { _, err := buf.Write(data[:]) require.NoError(t, err) - pgi, err := newPgInterceptor(buf, 5 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 5 /* bufSize */) typ, size, err := pgi.PeekMsg() require.NoError(t, err) @@ -141,8 +128,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) { t.Run("successful", func(t *testing.T) { buf := buildSrc(t, 1) - pgi, err := newPgInterceptor(buf, 10 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) typ, size, err := pgi.PeekMsg() require.NoError(t, err) @@ -168,8 +154,7 @@ func TestPGInterceptor_ReadMsg(t *testing.T) { // Use a LimitReader to allow PeekMsg to read 5 bytes, then update src // back to the original version. src := &errReadWriter{r: buf, count: 2} - pgi, err := newPgInterceptor(io.LimitReader(src, 5), 32 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(io.LimitReader(src, 5), 32 /* bufSize */) // Call PeekMsg here to populate internal buffer with header. typ, size, err := pgi.PeekMsg() @@ -195,8 +180,7 @@ func TestPGInterceptor_ReadMsg(t *testing.T) { // testSelect1Bytes has 14 bytes, but only 6 bytes within internal // buffer, so overflow. src := &errReadWriter{r: buf, count: 2} - pgi, err := newPgInterceptor(src, 6 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(src, 6 /* bufSize */) msg, err := pgi.ReadMsg() require.EqualError(t, err, io.ErrClosedPipe.Error()) @@ -211,8 +195,7 @@ func TestPGInterceptor_ReadMsg(t *testing.T) { // Set buffer's size to be a multiple of the message so that we'll // always hit the case where the message fits. - pgi, err := newPgInterceptor(buf, len(testSelect1Bytes)*3) - require.NoError(t, err) + pgi := newPgInterceptor(buf, len(testSelect1Bytes)*3) c := 0 n := testing.AllocsPerRun(count-1, func() { @@ -246,8 +229,7 @@ func TestPGInterceptor_ReadMsg(t *testing.T) { // Set the buffer to be large enough to fit more bytes than the header, // but not the entire message. - pgi, err := newPgInterceptor(buf, 7 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 7 /* bufSize */) c := 0 n := testing.AllocsPerRun(count-1, func() { @@ -286,8 +268,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) { dst := new(bytes.Buffer) dstWriter := &errReadWriter{w: dst, count: 1} - pgi, err := newPgInterceptor(src, 32 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(src, 32 /* bufSize */) n, err := pgi.ForwardMsg(dstWriter) require.EqualError(t, err, io.ErrClosedPipe.Error()) @@ -306,8 +287,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) { // testSelect1Bytes has 14 bytes, but only 6 bytes within internal // buffer, so partially buffered. - pgi, err := newPgInterceptor(src, 6 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(src, 6 /* bufSize */) n, err := pgi.ForwardMsg(dstWriter) require.EqualError(t, err, io.ErrClosedPipe.Error()) @@ -324,8 +304,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) { // Set buffer's size to be a multiple of the message so that we'll // always hit the case where the message fits. - pgi, err := newPgInterceptor(src, len(testSelect1Bytes)*3) - require.NoError(t, err) + pgi := newPgInterceptor(src, len(testSelect1Bytes)*3) // Forward all the messages, and ensure 0 allocations. n := testing.AllocsPerRun(count-1, func() { @@ -349,8 +328,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) { // Set the buffer to be large enough to fit more bytes than the header, // but not the entire message. - pgi, err := newPgInterceptor(src, 7 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(src, 7 /* bufSize */) n := testing.AllocsPerRun(count-1, func() { n, err := pgi.ForwardMsg(dst) @@ -376,8 +354,7 @@ func TestPGInterceptor_readSize(t *testing.T) { defer leaktest.AfterTest(t)() buf := bytes.NewBufferString("foobarbazz") - pgi, err := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */) // No reads to internal buffer. require.Equal(t, 0, pgi.readSize()) @@ -395,8 +372,7 @@ func TestPGInterceptor_writeSize(t *testing.T) { defer leaktest.AfterTest(t)() buf := bytes.NewBufferString("foobarbazz") - pgi, err := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */) // No writes to internal buffer. require.Equal(t, 10, pgi.writeSize()) @@ -414,8 +390,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { defer leaktest.AfterTest(t)() t.Run("invalid n", func(t *testing.T) { - pgi, err := newPgInterceptor(nil /* src */, 8 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(nil /* src */, 8 /* bufSize */) require.EqualError(t, pgi.ensureNextNBytes(-1), "invalid number of bytes -1 for buffer size 8") @@ -426,8 +401,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { t.Run("buffer already has n bytes", func(t *testing.T) { buf := bytes.NewBufferString("foobarbaz") - pgi, err := newPgInterceptor(iotest.OneByteReader(buf), 8 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(iotest.OneByteReader(buf), 8 /* bufSize */) // Read "foo" into buffer". require.NoError(t, pgi.ensureNextNBytes(3)) @@ -448,8 +422,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { t.Run("bytes are realigned", func(t *testing.T) { buf := bytes.NewBufferString("foobarbazcar") - pgi, err := newPgInterceptor(iotest.OneByteReader(buf), 9 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(iotest.OneByteReader(buf), 9 /* bufSize */) // Read "foobarb" into buffer. require.NoError(t, pgi.ensureNextNBytes(7)) @@ -470,8 +443,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { // if there was a Read call. buf := bytes.NewBufferString("foobarbaz") - pgi, err := newPgInterceptor(buf, 10 /* bufSize */) - require.NoError(t, err) + pgi := newPgInterceptor(buf, 10 /* bufSize */) // Request for only 1 byte. require.NoError(t, pgi.ensureNextNBytes(1)) @@ -480,7 +452,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) { require.Equal(t, "foobarbaz", string(pgi.buf[pgi.readPos:pgi.writePos])) // Should be a no-op. - _, err = buf.WriteString("car") + _, err := buf.WriteString("car") require.NoError(t, err) require.NoError(t, pgi.ensureNextNBytes(9)) require.Equal(t, 3, buf.Len()) diff --git a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go index 8e580b77beff..d6dff9d1ac4d 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go +++ b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go @@ -15,18 +15,16 @@ import ( "github.com/jackc/pgproto3/v2" ) -// FrontendInterceptor is a client interceptor for the Postgres frontend -// protocol. +// FrontendInterceptor is a client interceptor for the Postgres frontend protocol. type FrontendInterceptor pgInterceptor -// NewFrontendInterceptor creates a FrontendInterceptor. bufSize must be at -// least the size of a pgwire message header. -func NewFrontendInterceptor(src io.Reader, bufSize int) (*FrontendInterceptor, error) { - pgi, err := newPgInterceptor(src, bufSize) - if err != nil { - return nil, err - } - return (*FrontendInterceptor)(pgi), nil +// NewFrontendInterceptor creates a FrontendInterceptor. If bufSize is smaller +// than 5 bytes, the defaults (8K) will be used. +// +// NOTE: For future improvement, we can use the options pattern here if there's +// a need for more than one field. +func NewFrontendInterceptor(src io.Reader, bufSize int) *FrontendInterceptor { + return (*FrontendInterceptor)(newPgInterceptor(src, bufSize)) } // PeekMsg returns the header of the current pgwire message without advancing diff --git a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go index 494f9690563b..3b7c28ad0dfb 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor_test.go @@ -35,17 +35,15 @@ func TestFrontendInterceptor(t *testing.T) { return src } - t.Run("bufSize too small", func(t *testing.T) { - fi, err := interceptor.NewFrontendInterceptor(nil /* src */, 1 /* bufSize */) - require.Error(t, err) - require.Nil(t, fi) + t.Run("small bufSize", func(t *testing.T) { + fi := interceptor.NewFrontendInterceptor(nil /* src */, 1 /* bufSize */) + require.NotNil(t, fi) }) t.Run("PeekMsg returns the right message type", func(t *testing.T) { src := buildSrc(t, 1) - fi, err := interceptor.NewFrontendInterceptor(src, 16 /* bufSize */) - require.NoError(t, err) + fi := interceptor.NewFrontendInterceptor(src, 16 /* bufSize */) require.NotNil(t, fi) typ, size, err := fi.PeekMsg() @@ -56,8 +54,7 @@ func TestFrontendInterceptor(t *testing.T) { t.Run("WriteMsg writes data to dst", func(t *testing.T) { dst := new(bytes.Buffer) - fi, err := interceptor.NewFrontendInterceptor(nil /* src */, 10 /* bufSize */) - require.NoError(t, err) + fi := interceptor.NewFrontendInterceptor(nil /* src */, 10 /* bufSize */) require.NotNil(t, fi) // This is a frontend interceptor, so writing goes to the client. @@ -71,8 +68,7 @@ func TestFrontendInterceptor(t *testing.T) { t.Run("ReadMsg decodes the message correctly", func(t *testing.T) { src := buildSrc(t, 1) - fi, err := interceptor.NewFrontendInterceptor(src, 16 /* bufSize */) - require.NoError(t, err) + fi := interceptor.NewFrontendInterceptor(src, 16 /* bufSize */) require.NotNil(t, fi) msg, err := fi.ReadMsg() @@ -86,8 +82,7 @@ func TestFrontendInterceptor(t *testing.T) { src := buildSrc(t, 1) dst := new(bytes.Buffer) - fi, err := interceptor.NewFrontendInterceptor(src, 16 /* bufSize */) - require.NoError(t, err) + fi := interceptor.NewFrontendInterceptor(src, 16 /* bufSize */) require.NotNil(t, fi) n, err := fi.ForwardMsg(dst) diff --git a/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go b/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go index b682df666ef6..d5cbb1776314 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/interceptor_test.go @@ -33,10 +33,8 @@ func TestSimpleProxy(t *testing.T) { toServer := new(bytes.Buffer) // Create client and server interceptors. - clientInt, err := interceptor.NewBackendInterceptor(fromClient, bufferSize) - require.NoError(t, err) - serverInt, err := interceptor.NewFrontendInterceptor(fromServer, bufferSize) - require.NoError(t, err) + clientInt := interceptor.NewBackendInterceptor(fromClient, bufferSize) + serverInt := interceptor.NewFrontendInterceptor(fromServer, bufferSize) t.Run("client to server", func(t *testing.T) { // Client sends a list of SQL queries.