Skip to content

Commit

Permalink
ccl/sqlproxyccl: use a default buffer size of 8K within the interceptors
Browse files Browse the repository at this point in the history
Previously, we returned an error whenever callers attempt to create
interceptors with a small buffer size. This case is very uncommon, and the API
can be awkward since we now need to handle the error case. To address that,
this commit updates the interceptor's behavior such that we default to an 8K
buffer whenever a buffer size smaller than 5 bytes is used. Since sqlproxy is
the only user, this seems to be a reasonable tradeoff.

Release note: None
  • Loading branch information
jaylim-crl committed Feb 15, 2022
1 parent 179cbe3 commit db4533b
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 118 deletions.
15 changes: 7 additions & 8 deletions pkg/ccl/sqlproxyccl/interceptor/backend_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ import (
// BackendInterceptor is a server int/erceptor for the Postgres backend protocol.
type BackendInterceptor pgInterceptor

// NewBackendInterceptor creates a BackendInterceptor. bufSize must be at least
// the size of a pgwire message header.
func NewBackendInterceptor(src io.Reader, bufSize int) (*BackendInterceptor, error) {
pgi, err := newPgInterceptor(src, bufSize)
if err != nil {
return nil, err
}
return (*BackendInterceptor)(pgi), nil
// NewBackendInterceptor creates a BackendInterceptor. If bufSize is smaller
// than 5 bytes, the defaults (8K) will be used.
//
// NOTE: For future improvement, we can use the options pattern here if there's
// a need for more than one field.
func NewBackendInterceptor(src io.Reader, bufSize int) *BackendInterceptor {
return (*BackendInterceptor)(newPgInterceptor(src, bufSize))
}

// PeekMsg returns the header of the current pgwire message without advancing
Expand Down
19 changes: 7 additions & 12 deletions pkg/ccl/sqlproxyccl/interceptor/backend_interceptor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,15 @@ func TestBackendInterceptor(t *testing.T) {
return src
}

t.Run("bufSize too small", func(t *testing.T) {
bi, err := interceptor.NewBackendInterceptor(nil /* src */, 1 /* bufSize */)
require.Error(t, err)
require.Nil(t, bi)
t.Run("small bufSize", func(t *testing.T) {
bi := interceptor.NewBackendInterceptor(nil /* src */, 1 /* bufSize */)
require.NotNil(t, bi)
})

t.Run("PeekMsg returns the right message type", func(t *testing.T) {
src := buildSrc(t, 1)

bi, err := interceptor.NewBackendInterceptor(src, 16 /* bufSize */)
require.NoError(t, err)
bi := interceptor.NewBackendInterceptor(src, 16 /* bufSize */)
require.NotNil(t, bi)

typ, size, err := bi.PeekMsg()
Expand All @@ -56,8 +54,7 @@ func TestBackendInterceptor(t *testing.T) {

t.Run("WriteMsg writes data to dst", func(t *testing.T) {
dst := new(bytes.Buffer)
bi, err := interceptor.NewBackendInterceptor(nil /* src */, 10 /* bufSize */)
require.NoError(t, err)
bi := interceptor.NewBackendInterceptor(nil /* src */, 10 /* bufSize */)
require.NotNil(t, bi)

// This is a backend interceptor, so writing goes to the server.
Expand All @@ -71,8 +68,7 @@ func TestBackendInterceptor(t *testing.T) {
t.Run("ReadMsg decodes the message correctly", func(t *testing.T) {
src := buildSrc(t, 1)

bi, err := interceptor.NewBackendInterceptor(src, 16 /* bufSize */)
require.NoError(t, err)
bi := interceptor.NewBackendInterceptor(src, 16 /* bufSize */)
require.NotNil(t, bi)

msg, err := bi.ReadMsg()
Expand All @@ -86,8 +82,7 @@ func TestBackendInterceptor(t *testing.T) {
src := buildSrc(t, 1)
dst := new(bytes.Buffer)

bi, err := interceptor.NewBackendInterceptor(src, 16 /* bufSize */)
require.NoError(t, err)
bi := interceptor.NewBackendInterceptor(src, 16 /* bufSize */)
require.NotNil(t, bi)

n, err := bi.ForwardMsg(dst)
Expand Down
21 changes: 12 additions & 9 deletions pkg/ccl/sqlproxyccl/interceptor/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import (
// length itself).
const pgHeaderSizeBytes = 5

// ErrSmallBuffer indicates that the requested buffer for the interceptor is
// too small.
var ErrSmallBuffer = errors.New("buffer is too small")
// defaultBufferSize is the default buffer size for the interceptor. 8K was
// chosen to match Postgres' send and receive buffer sizes.
//
// See: https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134-L135.
const defaultBufferSize = 2 << 13 // 8K

// ErrProtocolError indicates that the packets are malformed, and are not as
// expected.
Expand Down Expand Up @@ -58,17 +60,18 @@ type pgInterceptor struct {
}

// newPgInterceptor creates a new instance of the interceptor with an internal
// buffer of bufSize bytes. bufSize must be at least the size of a pgwire
// message header.
func newPgInterceptor(src io.Reader, bufSize int) (*pgInterceptor, error) {
// The internal buffer must be able to fit the header.
// buffer of bufSize bytes. If bufSize is smaller than 5 bytes, the interceptor
// will default to an 8K buffer size.
func newPgInterceptor(src io.Reader, bufSize int) *pgInterceptor {
// The internal buffer must be able to fit the header. If bufSize is smaller
// than 5 bytes, just default to 8K, or else the interceptor is unusable.
if bufSize < pgHeaderSizeBytes {
return nil, ErrSmallBuffer
bufSize = defaultBufferSize
}
return &pgInterceptor{
src: src,
buf: make([]byte, bufSize),
}, nil
}
}

// PeekMsg returns the header of the current pgwire message without advancing
Expand Down
98 changes: 35 additions & 63 deletions pkg/ccl/sqlproxyccl/interceptor/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,20 @@ func TestNewPgInterceptor(t *testing.T) {

reader, _ := io.Pipe()

// Negative buffer size.
pgi, err := newPgInterceptor(reader, -1 /* bufSize */)
require.EqualError(t, err, ErrSmallBuffer.Error())
require.Nil(t, pgi)

// Small buffer size.
pgi, err = newPgInterceptor(reader, pgHeaderSizeBytes-1)
require.EqualError(t, err, ErrSmallBuffer.Error())
require.Nil(t, pgi)

// Buffer that fits the header exactly.
pgi, err = newPgInterceptor(reader, pgHeaderSizeBytes)
require.NoError(t, err)
require.NotNil(t, pgi)
require.Len(t, pgi.buf, pgHeaderSizeBytes)

// Normal buffer size.
pgi, err = newPgInterceptor(reader, 1024 /* bufSize */)
require.NoError(t, err)
require.NotNil(t, pgi)
require.Len(t, pgi.buf, 1024)
require.Equal(t, reader, pgi.src)
for _, tc := range []struct {
bufSize int
normalizedBufSize int
}{
{-1, defaultBufferSize},
{pgHeaderSizeBytes - 1, defaultBufferSize},
{pgHeaderSizeBytes, pgHeaderSizeBytes},
{1024, 1024},
} {
pgi := newPgInterceptor(reader, tc.bufSize)
require.NotNil(t, pgi)
require.Len(t, pgi.buf, tc.normalizedBufSize)
require.Equal(t, reader, pgi.src)
}
}

func TestPGInterceptor_PeekMsg(t *testing.T) {
Expand All @@ -59,8 +51,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) {
t.Run("read_error", func(t *testing.T) {
r := iotest.ErrReader(errors.New("read error"))

pgi, err := newPgInterceptor(r, 10 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(r, 10 /* bufSize */)

typ, size, err := pgi.PeekMsg()
require.EqualError(t, err, "read error")
Expand All @@ -75,8 +66,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) {
_, err := buf.Write(data[:])
require.NoError(t, err)

pgi, err := newPgInterceptor(buf, 10 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(buf, 10 /* bufSize */)

typ, size, err := pgi.PeekMsg()
require.EqualError(t, err, ErrProtocolError.Error())
Expand All @@ -92,8 +82,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) {
_, err := buf.Write(data[:])
require.NoError(t, err)

pgi, err := newPgInterceptor(buf, 10 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(buf, 10 /* bufSize */)

typ, size, err := pgi.PeekMsg()
require.EqualError(t, err, ErrProtocolError.Error())
Expand All @@ -109,8 +98,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) {
_, err := buf.Write(data[:])
require.NoError(t, err)

pgi, err := newPgInterceptor(buf, 10 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(buf, 10 /* bufSize */)

typ, size, err := pgi.PeekMsg()
require.EqualError(t, err, ErrProtocolError.Error())
Expand All @@ -128,8 +116,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) {
_, err := buf.Write(data[:])
require.NoError(t, err)

pgi, err := newPgInterceptor(buf, 5 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(buf, 5 /* bufSize */)

typ, size, err := pgi.PeekMsg()
require.NoError(t, err)
Expand All @@ -141,8 +128,7 @@ func TestPGInterceptor_PeekMsg(t *testing.T) {
t.Run("successful", func(t *testing.T) {
buf := buildSrc(t, 1)

pgi, err := newPgInterceptor(buf, 10 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(buf, 10 /* bufSize */)

typ, size, err := pgi.PeekMsg()
require.NoError(t, err)
Expand All @@ -168,8 +154,7 @@ func TestPGInterceptor_ReadMsg(t *testing.T) {
// Use a LimitReader to allow PeekMsg to read 5 bytes, then update src
// back to the original version.
src := &errReadWriter{r: buf, count: 2}
pgi, err := newPgInterceptor(io.LimitReader(src, 5), 32 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(io.LimitReader(src, 5), 32 /* bufSize */)

// Call PeekMsg here to populate internal buffer with header.
typ, size, err := pgi.PeekMsg()
Expand All @@ -195,8 +180,7 @@ func TestPGInterceptor_ReadMsg(t *testing.T) {
// testSelect1Bytes has 14 bytes, but only 6 bytes within internal
// buffer, so overflow.
src := &errReadWriter{r: buf, count: 2}
pgi, err := newPgInterceptor(src, 6 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(src, 6 /* bufSize */)

msg, err := pgi.ReadMsg()
require.EqualError(t, err, io.ErrClosedPipe.Error())
Expand All @@ -211,8 +195,7 @@ func TestPGInterceptor_ReadMsg(t *testing.T) {

// Set buffer's size to be a multiple of the message so that we'll
// always hit the case where the message fits.
pgi, err := newPgInterceptor(buf, len(testSelect1Bytes)*3)
require.NoError(t, err)
pgi := newPgInterceptor(buf, len(testSelect1Bytes)*3)

c := 0
n := testing.AllocsPerRun(count-1, func() {
Expand Down Expand Up @@ -246,8 +229,7 @@ func TestPGInterceptor_ReadMsg(t *testing.T) {

// Set the buffer to be large enough to fit more bytes than the header,
// but not the entire message.
pgi, err := newPgInterceptor(buf, 7 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(buf, 7 /* bufSize */)

c := 0
n := testing.AllocsPerRun(count-1, func() {
Expand Down Expand Up @@ -286,8 +268,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) {
dst := new(bytes.Buffer)
dstWriter := &errReadWriter{w: dst, count: 1}

pgi, err := newPgInterceptor(src, 32 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(src, 32 /* bufSize */)

n, err := pgi.ForwardMsg(dstWriter)
require.EqualError(t, err, io.ErrClosedPipe.Error())
Expand All @@ -306,8 +287,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) {

// testSelect1Bytes has 14 bytes, but only 6 bytes within internal
// buffer, so partially buffered.
pgi, err := newPgInterceptor(src, 6 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(src, 6 /* bufSize */)

n, err := pgi.ForwardMsg(dstWriter)
require.EqualError(t, err, io.ErrClosedPipe.Error())
Expand All @@ -324,8 +304,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) {

// Set buffer's size to be a multiple of the message so that we'll
// always hit the case where the message fits.
pgi, err := newPgInterceptor(src, len(testSelect1Bytes)*3)
require.NoError(t, err)
pgi := newPgInterceptor(src, len(testSelect1Bytes)*3)

// Forward all the messages, and ensure 0 allocations.
n := testing.AllocsPerRun(count-1, func() {
Expand All @@ -349,8 +328,7 @@ func TestPGInterceptor_ForwardMsg(t *testing.T) {

// Set the buffer to be large enough to fit more bytes than the header,
// but not the entire message.
pgi, err := newPgInterceptor(src, 7 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(src, 7 /* bufSize */)

n := testing.AllocsPerRun(count-1, func() {
n, err := pgi.ForwardMsg(dst)
Expand All @@ -376,8 +354,7 @@ func TestPGInterceptor_readSize(t *testing.T) {
defer leaktest.AfterTest(t)()

buf := bytes.NewBufferString("foobarbazz")
pgi, err := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */)

// No reads to internal buffer.
require.Equal(t, 0, pgi.readSize())
Expand All @@ -395,8 +372,7 @@ func TestPGInterceptor_writeSize(t *testing.T) {
defer leaktest.AfterTest(t)()

buf := bytes.NewBufferString("foobarbazz")
pgi, err := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(iotest.OneByteReader(buf), 10 /* bufSize */)

// No writes to internal buffer.
require.Equal(t, 10, pgi.writeSize())
Expand All @@ -414,8 +390,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) {
defer leaktest.AfterTest(t)()

t.Run("invalid n", func(t *testing.T) {
pgi, err := newPgInterceptor(nil /* src */, 8 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(nil /* src */, 8 /* bufSize */)

require.EqualError(t, pgi.ensureNextNBytes(-1),
"invalid number of bytes -1 for buffer size 8")
Expand All @@ -426,8 +401,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) {
t.Run("buffer already has n bytes", func(t *testing.T) {
buf := bytes.NewBufferString("foobarbaz")

pgi, err := newPgInterceptor(iotest.OneByteReader(buf), 8 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(iotest.OneByteReader(buf), 8 /* bufSize */)

// Read "foo" into buffer".
require.NoError(t, pgi.ensureNextNBytes(3))
Expand All @@ -448,8 +422,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) {
t.Run("bytes are realigned", func(t *testing.T) {
buf := bytes.NewBufferString("foobarbazcar")

pgi, err := newPgInterceptor(iotest.OneByteReader(buf), 9 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(iotest.OneByteReader(buf), 9 /* bufSize */)

// Read "foobarb" into buffer.
require.NoError(t, pgi.ensureNextNBytes(7))
Expand All @@ -470,8 +443,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) {
// if there was a Read call.
buf := bytes.NewBufferString("foobarbaz")

pgi, err := newPgInterceptor(buf, 10 /* bufSize */)
require.NoError(t, err)
pgi := newPgInterceptor(buf, 10 /* bufSize */)

// Request for only 1 byte.
require.NoError(t, pgi.ensureNextNBytes(1))
Expand All @@ -480,7 +452,7 @@ func TestPGInterceptor_ensureNextNBytes(t *testing.T) {
require.Equal(t, "foobarbaz", string(pgi.buf[pgi.readPos:pgi.writePos]))

// Should be a no-op.
_, err = buf.WriteString("car")
_, err := buf.WriteString("car")
require.NoError(t, err)
require.NoError(t, pgi.ensureNextNBytes(9))
require.Equal(t, 3, buf.Len())
Expand Down
18 changes: 8 additions & 10 deletions pkg/ccl/sqlproxyccl/interceptor/frontend_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@ import (
"github.com/jackc/pgproto3/v2"
)

// FrontendInterceptor is a client interceptor for the Postgres frontend
// protocol.
// FrontendInterceptor is a client interceptor for the Postgres frontend protocol.
type FrontendInterceptor pgInterceptor

// NewFrontendInterceptor creates a FrontendInterceptor. bufSize must be at
// least the size of a pgwire message header.
func NewFrontendInterceptor(src io.Reader, bufSize int) (*FrontendInterceptor, error) {
pgi, err := newPgInterceptor(src, bufSize)
if err != nil {
return nil, err
}
return (*FrontendInterceptor)(pgi), nil
// NewFrontendInterceptor creates a FrontendInterceptor. If bufSize is smaller
// than 5 bytes, the defaults (8K) will be used.
//
// NOTE: For future improvement, we can use the options pattern here if there's
// a need for more than one field.
func NewFrontendInterceptor(src io.Reader, bufSize int) *FrontendInterceptor {
return (*FrontendInterceptor)(newPgInterceptor(src, bufSize))
}

// PeekMsg returns the header of the current pgwire message without advancing
Expand Down
Loading

0 comments on commit db4533b

Please sign in to comment.