From a0fb94e95c22045aa38c00721bc3bc24801fd5bd Mon Sep 17 00:00:00 2001 From: joesankey Date: Thu, 1 Jul 2021 13:49:56 -0700 Subject: [PATCH] sql/pgwire: fix statement buffer memory leak when using suspended portals The connection statement buffer grows indefinitely when the client uses the execute portal with limit feature of the Postgres protocol, eventually causing the node to crash out of memory. Any long running query that uses the limit feature will cause this memory leak such as the `EXPERIMENTAL CHANGEFEED FOR` statement. The execute portal with limit feature of the Postgres protocol is used by the JDBC Postgres driver to fetch a limited number of rows at a time. The leak is caused by commands accumulating in the buffer and never getting cleared out. The client sends 2 commands every time it wants to receive more rows: - `Execute {"Portal": "C_1", "MaxRows": 1}` - `Sync` The server processes the commands and leaves them in the buffer, every iteration causes 2 more commands to leak. A similar memory leak was fixed by #48859, however the execute with limit feature is implemented by a side state machine in limitedCommandResult. The cleanup routine added by #48859 is never executed for suspended portals as they never return to the main conn_executor loop. After this change the statement buffer gets trimmed to reclaim memory after each client command is processed in the limitedCommandResult side state machine. The StmtBuf.Ltrim function was changed to be public visibility to enable this. While this is not ideal, it does scope the fix to the limitedCommandResult side state machine and could be addressed when the limitedCommandResult functionality is refactored into the conn_executor. Added a unit test which causes the leak, used the PGWire client in the test as neither the pg or pgx clients use execute with limit, so cant be used to demonstrate the leak. Also tested the fix in a cluster by following the steps outlined in #66849. Resolves: #66849 See also: #48859 Release note (bug fix): fix statement buffer memory leak when using suspended portals --- pkg/sql/conn_executor.go | 10 +++- pkg/sql/conn_executor_test.go | 88 ++++++++++++++++++++++++++++++ pkg/sql/conn_io.go | 6 +- pkg/sql/conn_io_test.go | 2 +- pkg/sql/pgwire/command_result.go | 4 ++ pkg/testutils/pgtest/datadriven.go | 18 ++++-- pkg/testutils/pgtest/pgtest.go | 25 +++++++++ 7 files changed, 141 insertions(+), 12 deletions(-) diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index 94b59113e36d..b1642a88c88d 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -1543,6 +1543,12 @@ func (ex *connExecutor) execCmd(ctx context.Context) error { if err != nil { return err } + // Update the cmd and pos in the stmtBuf as limitedCommandResult will have + // advanced the position if the the portal is repeatedly executed with a limit + cmd, pos, err = ex.stmtBuf.CurCmd() + if err != nil { + return err + } case PrepareStmt: ex.curStmt = tcmd.AST @@ -1672,7 +1678,7 @@ func (ex *connExecutor) execCmd(ctx context.Context) error { if rewindCapability, canRewind := ex.getRewindTxnCapability(); !canRewind { // Trim statements that cannot be retried to reclaim memory. - ex.stmtBuf.ltrim(ctx, pos) + ex.stmtBuf.Ltrim(ctx, pos) } else { rewindCapability.close() } @@ -1798,7 +1804,7 @@ func (ex *connExecutor) setTxnRewindPos(ctx context.Context, pos CmdPos) { "Was: %d; new value: %d", ex.extraTxnState.txnRewindPos, pos)) } ex.extraTxnState.txnRewindPos = pos - ex.stmtBuf.ltrim(ctx, pos) + ex.stmtBuf.Ltrim(ctx, pos) ex.commitPrepStmtNamespace(ctx) ex.extraTxnState.savepointsAtTxnRewindPos = ex.extraTxnState.savepoints.clone() } diff --git a/pkg/sql/conn_executor_test.go b/pkg/sql/conn_executor_test.go index 4723a5b373c9..27846debdc9b 100644 --- a/pkg/sql/conn_executor_test.go +++ b/pkg/sql/conn_executor_test.go @@ -38,6 +38,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/tests" "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/testutils/pgtest" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" @@ -45,6 +46,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/datadriven" "github.com/cockroachdb/errors" "github.com/cockroachdb/redact" "github.com/jackc/pgx" @@ -829,6 +831,92 @@ func TestTrimFlushedStatements(t *testing.T) { require.NoError(t, tx.Commit()) } +func TestTrimSuspendedPortals(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + const ( + // select generates 10 rows of results which the test retrieves using ExecPortal + selectStmt = "SELECT generate_series(1, 10)" + + // stmtBufMaxLen is the maximum length the statement buffer should be during + // execution + stmtBufMaxLen = 2 + + // The name of the portal, used to get a handle on the statement buffer + portalName = "C_1" + ) + + ctx := context.Background() + + var stmtBuff *sql.StmtBuf + s, _, _ := serverutils.StartServer(t, base.TestServerArgs{ + Knobs: base.TestingKnobs{ + SQLExecutor: &sql.ExecutorTestingKnobs{ + // get a handle to the statement buffer during the Bind phase + AfterExecCmd: func(_ context.Context, cmd sql.Command, buf *sql.StmtBuf) { + switch tcmd := cmd.(type) { + case sql.BindStmt: + if tcmd.PortalName == portalName { + stmtBuff = buf + } + default: + } + }, + }, + }, + Insecure: true, + }) + defer s.Stopper().Stop(ctx) + + // Connect to the cluster via the PGWire client. + p, err := pgtest.NewPGTest(ctx, s.SQLAddr(), security.RootUser) + require.NoError(t, err) + + // setup the portal + require.NoError(t, p.SendOneLine(`Query {"String": "BEGIN"}`)) + require.NoError(t, p.SendOneLine(fmt.Sprintf(`Parse {"Query": "%s"}`, selectStmt))) + require.NoError(t, p.SendOneLine(fmt.Sprintf(`Bind {"DestinationPortal": "%s"}`, portalName))) + + // wait for ready + until := pgtest.ParseMessages("ReadyForQuery") + _, err = p.Until(false /* keepErrMsg */, until...) + require.NoError(t, err) + + // Execute the portal 10 times + for i := 1; i <= 10; i++ { + + // Exec the portal + require.NoError(t, p.SendOneLine(fmt.Sprintf(`Execute {"Portal": "%s", "MaxRows": 1}`, portalName))) + require.NoError(t, p.SendOneLine(`Sync`)) + + // wait for ready + msg, _ := p.Until(false /* keepErrMsg */, until...) + + // received messages should include a data row with the correct value + received := pgtest.MsgsToJSONWithIgnore(msg, &datadriven.TestData{}) + require.Equal(t, 1, strings.Count(received, fmt.Sprintf(`"Type":"DataRow","Values":[{"text":"%d"}]`, i))) + + // assert that the stmtBuff never exceeds the expected size + stmtBufLen := stmtBuff.Len() + if stmtBufLen > stmtBufMaxLen { + t.Fatalf("statement buffer grew to %d (> %d) after %dth execution", stmtBufLen, stmtBufMaxLen, i) + } + } + + // explicitly close portal + require.NoError(t, p.SendOneLine(fmt.Sprintf(`Close {"ObjectType": 80,"Name": "%s"}`, portalName))) + + // send commit + require.NoError(t, p.SendOneLine(`Query {"String": "COMMIT"}`)) + + // wait for ready + msg, _ := p.Until(false /* keepErrMsg */, until...) + received := pgtest.MsgsToJSONWithIgnore(msg, &datadriven.TestData{}) + require.Equal(t, 1, strings.Count(received, `"Type":"CommandComplete","CommandTag":"COMMIT"`)) + +} + func TestShowLastQueryStatistics(t *testing.T) { defer leaktest.AfterTest(t)() diff --git a/pkg/sql/conn_io.go b/pkg/sql/conn_io.go index a402e4f292eb..5f343d128e7b 100644 --- a/pkg/sql/conn_io.go +++ b/pkg/sql/conn_io.go @@ -455,11 +455,11 @@ func (buf *StmtBuf) translatePosLocked(pos CmdPos) (int, error) { return int(pos - buf.mu.startPos), nil } -// ltrim iterates over the buffer forward and removes all commands up to +// Ltrim iterates over the buffer forward and removes all commands up to // (not including) the command at pos. // -// It's illegal to ltrim to a position higher than the current cursor. -func (buf *StmtBuf) ltrim(ctx context.Context, pos CmdPos) { +// It's illegal to Ltrim to a position higher than the current cursor. +func (buf *StmtBuf) Ltrim(ctx context.Context, pos CmdPos) { buf.mu.Lock() defer buf.mu.Unlock() if pos < buf.mu.startPos { diff --git a/pkg/sql/conn_io_test.go b/pkg/sql/conn_io_test.go index 49d98052e656..537e0559c7ff 100644 --- a/pkg/sql/conn_io_test.go +++ b/pkg/sql/conn_io_test.go @@ -182,7 +182,7 @@ func TestStmtBufLtrim(t *testing.T) { buf.AdvanceOne() buf.AdvanceOne() trimPos := CmdPos(2) - buf.ltrim(ctx, trimPos) + buf.Ltrim(ctx, trimPos) if l := buf.mu.data.Len(); l != 3 { t.Fatalf("expected 3 left, got: %d", l) } diff --git a/pkg/sql/pgwire/command_result.go b/pkg/sql/pgwire/command_result.go index 15bd095a4e92..84c49a755a3c 100644 --- a/pkg/sql/pgwire/command_result.go +++ b/pkg/sql/pgwire/command_result.go @@ -452,6 +452,10 @@ func (r *limitedCommandResult) moreResultsNeeded(ctx context.Context) error { // The client wants to see a ready for query message // back. Send it then run the for loop again. r.conn.stmtBuf.AdvanceOne() + // Trim old statements to reclaim memory. We need to perform this clean up + // here as the conn_executor cleanup is not executed because of the + // limitedCommandResult side state machine. + r.conn.stmtBuf.Ltrim(ctx, prevPos) // We can hard code InTxnBlock here because we don't // support implicit transactions, so we know we're in // a transaction. diff --git a/pkg/testutils/pgtest/datadriven.go b/pkg/testutils/pgtest/datadriven.go index bd9c733b0e41..ce071bf79b0f 100644 --- a/pkg/testutils/pgtest/datadriven.go +++ b/pkg/testutils/pgtest/datadriven.go @@ -114,23 +114,23 @@ func RunTest(t *testing.T, path, addr, user string) { (d.HasArg("noncrdb_only") && p.isCockroachDB) { return d.Expected } - until := parseMessages(d.Input) + until := ParseMessages(d.Input) msgs, err := p.Receive(hasKeepErrMsg(d), until...) if err != nil { t.Fatalf("%s: %+v", d.Pos, err) } - return msgsToJSONWithIgnore(msgs, d) + return MsgsToJSONWithIgnore(msgs, d) case "until": if (d.HasArg("crdb_only") && !p.isCockroachDB) || (d.HasArg("noncrdb_only") && p.isCockroachDB) { return d.Expected } - until := parseMessages(d.Input) + until := ParseMessages(d.Input) msgs, err := p.Until(hasKeepErrMsg(d), until...) if err != nil { t.Fatalf("%s: %+v", d.Pos, err) } - return msgsToJSONWithIgnore(msgs, d) + return MsgsToJSONWithIgnore(msgs, d) default: t.Fatalf("unknown command %s", d.Cmd) return "" @@ -141,7 +141,10 @@ func RunTest(t *testing.T, path, addr, user string) { } } -func parseMessages(s string) []pgproto3.BackendMessage { +// ParseMessages parses a string containing multiple pgproto3 messages separated +// by the newline symbol. See testdata for examples ("until" or "receive" +// commands). +func ParseMessages(s string) []pgproto3.BackendMessage { var msgs []pgproto3.BackendMessage for _, typ := range strings.Split(s, "\n") { msgs = append(msgs, toMessage(typ).(pgproto3.BackendMessage)) @@ -158,7 +161,10 @@ func hasKeepErrMsg(d *datadriven.TestData) bool { return false } -func msgsToJSONWithIgnore(msgs []pgproto3.BackendMessage, args *datadriven.TestData) string { +// MsgsToJSONWithIgnore converts the pgproto3 messages to JSON format. The +// second argument can specify how to adjust the messages (e.g. to make them +// more deterministic) if needed, see testdata for examples. +func MsgsToJSONWithIgnore(msgs []pgproto3.BackendMessage, args *datadriven.TestData) string { ignore := map[string]bool{} errs := map[string]string{} for _, arg := range args.CmdArgs { diff --git a/pkg/testutils/pgtest/pgtest.go b/pkg/testutils/pgtest/pgtest.go index 7f3c23bed893..53a82a6972f1 100644 --- a/pkg/testutils/pgtest/pgtest.go +++ b/pkg/testutils/pgtest/pgtest.go @@ -14,9 +14,11 @@ import ( "bytes" "context" "encoding/gob" + "encoding/json" "fmt" "net" "reflect" + "strings" "testing" "github.com/cockroachdb/errors" @@ -80,6 +82,29 @@ func (p *PGTest) Close() error { return p.fe.Send(&pgproto3.Terminate{}) } +// SendOneLine sends a single msg to the server represented as a single string +// in the format ` `. See testdata for examples. +func (p *PGTest) SendOneLine(line string) error { + sp := strings.SplitN(line, " ", 2) + msg := toMessage(sp[0]) + if len(sp) == 2 { + msgBytes := []byte(sp[1]) + switch msg := msg.(type) { + case *pgproto3.CopyData: + var data struct{ Data string } + if err := json.Unmarshal(msgBytes, &data); err != nil { + return err + } + msg.Data = []byte(data.Data) + default: + if err := json.Unmarshal(msgBytes, msg); err != nil { + return err + } + } + } + return p.Send(msg.(pgproto3.FrontendMessage)) +} + // Send sends msg to the serrver. func (p *PGTest) Send(msg pgproto3.FrontendMessage) error { if testing.Verbose() {