Skip to content

Commit

Permalink
sql/pgwire: fix statement buffer memory leak when using suspended por…
Browse files Browse the repository at this point in the history
…tals

The connection statement buffer grows indefinitely when the client uses the
execute portal with limit feature of the Postgres protocol, eventually causing
the node to crash out of memory. Any long running query that uses the limit
feature will cause this memory leak such as the `EXPERIMENTAL CHANGEFEED FOR`
statement. The execute portal with limit feature of the Postgres protocol is
used by the JDBC Postgres driver to fetch a limited number of rows at a time.

The leak is caused by commands accumulating in the buffer and never getting
cleared out. The client sends 2 commands every time it wants to receive more
rows:

- `Execute {"Portal": "C_1", "MaxRows": 1}`
- `Sync`

The server processes the commands and leaves them in the buffer, every
iteration causes 2 more commands to leak.

A similar memory leak was fixed by #48859, however the execute with limit
feature is implemented by a side state machine in limitedCommandResult. The
cleanup routine added by #48859 is never executed for suspended portals as
they never return to the main conn_executor loop.

After this change the statement buffer gets trimmed to reclaim memory after
each client command is processed in the limitedCommandResult side state
machine. The StmtBuf.Ltrim function was changed to be public visibility to
enable this. While this is not ideal, it does scope the fix to the
limitedCommandResult side state machine and could be addressed when the
limitedCommandResult functionality is refactored into the conn_executor.

Added a unit test which causes the leak, used the PGWire client in the test as
neither the pg or pgx clients use execute with limit, so cant be used to
demonstrate the leak. Also tested the fix in a cluster by following the steps
outlined in #66849.

Resolves: #66849

See also: #48859

Release note (bug fix): fix statement buffer memory leak when using
suspended portals
  • Loading branch information
joesankey committed Jul 8, 2021
1 parent 1e26809 commit 2cd1ebd
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 12 deletions.
10 changes: 8 additions & 2 deletions pkg/sql/conn_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1543,6 +1543,12 @@ func (ex *connExecutor) execCmd(ctx context.Context) error {
if err != nil {
return err
}
// Update the cmd and pos in the stmtBuf as limitedCommandResult will have
// advanced the position if the the portal is repeatedly executed with a limit
cmd, pos, err = ex.stmtBuf.CurCmd()
if err != nil {
return err
}

case PrepareStmt:
ex.curStmt = tcmd.AST
Expand Down Expand Up @@ -1672,7 +1678,7 @@ func (ex *connExecutor) execCmd(ctx context.Context) error {

if rewindCapability, canRewind := ex.getRewindTxnCapability(); !canRewind {
// Trim statements that cannot be retried to reclaim memory.
ex.stmtBuf.ltrim(ctx, pos)
ex.stmtBuf.Ltrim(ctx, pos)
} else {
rewindCapability.close()
}
Expand Down Expand Up @@ -1798,7 +1804,7 @@ func (ex *connExecutor) setTxnRewindPos(ctx context.Context, pos CmdPos) {
"Was: %d; new value: %d", ex.extraTxnState.txnRewindPos, pos))
}
ex.extraTxnState.txnRewindPos = pos
ex.stmtBuf.ltrim(ctx, pos)
ex.stmtBuf.Ltrim(ctx, pos)
ex.commitPrepStmtNamespace(ctx)
ex.extraTxnState.savepointsAtTxnRewindPos = ex.extraTxnState.savepoints.clone()
}
Expand Down
88 changes: 88 additions & 0 deletions pkg/sql/conn_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/tests"
"github.com/cockroachdb/cockroach/pkg/testutils"
"github.com/cockroachdb/cockroach/pkg/testutils/pgtest"
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/testutils/skip"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/cockroachdb/cockroach/pkg/util/ctxgroup"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
"github.com/jackc/pgx"
Expand Down Expand Up @@ -829,6 +831,92 @@ func TestTrimFlushedStatements(t *testing.T) {
require.NoError(t, tx.Commit())
}

func TestTrimSuspendedPortals(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

const (
// select generates 10 rows of results which the test retrieves using ExecPortal
selectStmt = "SELECT generate_series(1, 10)"

// stmtBufMaxLen is the maximum length the statement buffer should be during
// execution
stmtBufMaxLen = 2

// The name of the portal, used to get a handle on the statement buffer
portalName = "C_1"
)

ctx := context.Background()

var stmtBuff *sql.StmtBuf
s, _, _ := serverutils.StartServer(t, base.TestServerArgs{
Knobs: base.TestingKnobs{
SQLExecutor: &sql.ExecutorTestingKnobs{
// get a handle to the statement buffer during the Bind phase
AfterExecCmd: func(_ context.Context, cmd sql.Command, buf *sql.StmtBuf) {
switch tcmd := cmd.(type) {
case sql.BindStmt:
if tcmd.PortalName == portalName {
stmtBuff = buf
}
default:
}
},
},
},
Insecure: true,
})
defer s.Stopper().Stop(ctx)

// Connect to the cluster via the PGWire client.
p, err := pgtest.NewPGTest(ctx, s.SQLAddr(), security.RootUser)
require.NoError(t, err)

// setup the portal
require.NoError(t, p.SendOneLine(`Query {"String": "BEGIN"}`))
require.NoError(t, p.SendOneLine(fmt.Sprintf(`Parse {"Query": "%s"}`, selectStmt)))
require.NoError(t, p.SendOneLine(fmt.Sprintf(`Bind {"DestinationPortal": "%s"}`, portalName)))

// wait for ready
until := pgtest.ParseMessages("ReadyForQuery")
_, err = p.Until(false /* keepErrMsg */, until...)
require.NoError(t, err)

// Execute the portal 10 times
for i := 1; i <= 10; i++ {

// Exec the portal
require.NoError(t, p.SendOneLine(fmt.Sprintf(`Execute {"Portal": "%s", "MaxRows": 1}`, portalName)))
require.NoError(t, p.SendOneLine(`Sync`))

// wait for ready
msg, _ := p.Until(false /* keepErrMsg */, until...)

// received messages should include a data row with the correct value
received := pgtest.MsgsToJSONWithIgnore(msg, &datadriven.TestData{})
require.Equal(t, 1, strings.Count(received, fmt.Sprintf(`"Type":"DataRow","Values":[{"text":"%d"}]`, i)))

// assert that the stmtBuff never exceeds the expected size
stmtBufLen := stmtBuff.Len()
if stmtBufLen > stmtBufMaxLen {
t.Fatalf("statement buffer grew to %d (> %d) after %dth execution", stmtBufLen, stmtBufMaxLen, i)
}
}

// explicitly close portal
require.NoError(t, p.SendOneLine(fmt.Sprintf(`Close {"ObjectType": 80,"Name": "%s"}`, portalName)))

// send commit
require.NoError(t, p.SendOneLine(`Query {"String": "COMMIT"}`))

// wait for ready
msg, _ := p.Until(false /* keepErrMsg */, until...)
received := pgtest.MsgsToJSONWithIgnore(msg, &datadriven.TestData{})
require.Equal(t, 1, strings.Count(received, `"Type":"CommandComplete","CommandTag":"COMMIT"`))

}

func TestShowLastQueryStatistics(t *testing.T) {
defer leaktest.AfterTest(t)()

Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/conn_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,11 @@ func (buf *StmtBuf) translatePosLocked(pos CmdPos) (int, error) {
return int(pos - buf.mu.startPos), nil
}

// ltrim iterates over the buffer forward and removes all commands up to
// Ltrim iterates over the buffer forward and removes all commands up to
// (not including) the command at pos.
//
// It's illegal to ltrim to a position higher than the current cursor.
func (buf *StmtBuf) ltrim(ctx context.Context, pos CmdPos) {
// It's illegal to Ltrim to a position higher than the current cursor.
func (buf *StmtBuf) Ltrim(ctx context.Context, pos CmdPos) {
buf.mu.Lock()
defer buf.mu.Unlock()
if pos < buf.mu.startPos {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/conn_io_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func TestStmtBufLtrim(t *testing.T) {
buf.AdvanceOne()
buf.AdvanceOne()
trimPos := CmdPos(2)
buf.ltrim(ctx, trimPos)
buf.Ltrim(ctx, trimPos)
if l := buf.mu.data.Len(); l != 3 {
t.Fatalf("expected 3 left, got: %d", l)
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/pgwire/command_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ func (r *limitedCommandResult) moreResultsNeeded(ctx context.Context) error {
// The client wants to see a ready for query message
// back. Send it then run the for loop again.
r.conn.stmtBuf.AdvanceOne()
// Trim old statements to reclaim memory. We need to perform this clean up
// here as the conn_executor cleanup is not executed because of the
// limitedCommandResult side state machine.
r.conn.stmtBuf.Ltrim(ctx, prevPos)
// We can hard code InTxnBlock here because we don't
// support implicit transactions, so we know we're in
// a transaction.
Expand Down
12 changes: 6 additions & 6 deletions pkg/testutils/pgtest/datadriven.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,23 @@ func RunTest(t *testing.T, path, addr, user string) {
(d.HasArg("noncrdb_only") && p.isCockroachDB) {
return d.Expected
}
until := parseMessages(d.Input)
until := ParseMessages(d.Input)
msgs, err := p.Receive(hasKeepErrMsg(d), until...)
if err != nil {
t.Fatalf("%s: %+v", d.Pos, err)
}
return msgsToJSONWithIgnore(msgs, d)
return MsgsToJSONWithIgnore(msgs, d)
case "until":
if (d.HasArg("crdb_only") && !p.isCockroachDB) ||
(d.HasArg("noncrdb_only") && p.isCockroachDB) {
return d.Expected
}
until := parseMessages(d.Input)
until := ParseMessages(d.Input)
msgs, err := p.Until(hasKeepErrMsg(d), until...)
if err != nil {
t.Fatalf("%s: %+v", d.Pos, err)
}
return msgsToJSONWithIgnore(msgs, d)
return MsgsToJSONWithIgnore(msgs, d)
default:
t.Fatalf("unknown command %s", d.Cmd)
return ""
Expand All @@ -141,7 +141,7 @@ func RunTest(t *testing.T, path, addr, user string) {
}
}

func parseMessages(s string) []pgproto3.BackendMessage {
func ParseMessages(s string) []pgproto3.BackendMessage {
var msgs []pgproto3.BackendMessage
for _, typ := range strings.Split(s, "\n") {
msgs = append(msgs, toMessage(typ).(pgproto3.BackendMessage))
Expand All @@ -158,7 +158,7 @@ func hasKeepErrMsg(d *datadriven.TestData) bool {
return false
}

func msgsToJSONWithIgnore(msgs []pgproto3.BackendMessage, args *datadriven.TestData) string {
func MsgsToJSONWithIgnore(msgs []pgproto3.BackendMessage, args *datadriven.TestData) string {
ignore := map[string]bool{}
errs := map[string]string{}
for _, arg := range args.CmdArgs {
Expand Down
25 changes: 25 additions & 0 deletions pkg/testutils/pgtest/pgtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (
"bytes"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"net"
"reflect"
"strings"
"testing"

"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -80,6 +82,29 @@ func (p *PGTest) Close() error {
return p.fe.Send(&pgproto3.Terminate{})
}

// SendOneLine sends a single msg to the server represented as a single string
// in the format `<msg type> <msg body in JSON>`. See testdata for examples.
func (p *PGTest) SendOneLine(line string) error {
sp := strings.SplitN(line, " ", 2)
msg := toMessage(sp[0])
if len(sp) == 2 {
msgBytes := []byte(sp[1])
switch msg := msg.(type) {
case *pgproto3.CopyData:
var data struct{ Data string }
if err := json.Unmarshal(msgBytes, &data); err != nil {
return err
}
msg.Data = []byte(data.Data)
default:
if err := json.Unmarshal(msgBytes, msg); err != nil {
return err
}
}
}
return p.Send(msg.(pgproto3.FrontendMessage))
}

// Send sends msg to the serrver.
func (p *PGTest) Send(msg pgproto3.FrontendMessage) error {
if testing.Verbose() {
Expand Down

0 comments on commit 2cd1ebd

Please sign in to comment.