Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-20.2: sql/pgwire: fix statement buffer memory leak when using suspended portals #67370

Merged
merged 1 commit into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pkg/sql/conn_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -1543,6 +1543,12 @@ func (ex *connExecutor) execCmd(ctx context.Context) error {
if err != nil {
return err
}
// Update the cmd and pos in the stmtBuf as limitedCommandResult will have
// advanced the position if the the portal is repeatedly executed with a limit
cmd, pos, err = ex.stmtBuf.CurCmd()
if err != nil {
return err
}

case PrepareStmt:
ex.curStmt = tcmd.AST
Expand Down Expand Up @@ -1672,7 +1678,7 @@ func (ex *connExecutor) execCmd(ctx context.Context) error {

if rewindCapability, canRewind := ex.getRewindTxnCapability(); !canRewind {
// Trim statements that cannot be retried to reclaim memory.
ex.stmtBuf.ltrim(ctx, pos)
ex.stmtBuf.Ltrim(ctx, pos)
} else {
rewindCapability.close()
}
Expand Down Expand Up @@ -1798,7 +1804,7 @@ func (ex *connExecutor) setTxnRewindPos(ctx context.Context, pos CmdPos) {
"Was: %d; new value: %d", ex.extraTxnState.txnRewindPos, pos))
}
ex.extraTxnState.txnRewindPos = pos
ex.stmtBuf.ltrim(ctx, pos)
ex.stmtBuf.Ltrim(ctx, pos)
ex.commitPrepStmtNamespace(ctx)
ex.extraTxnState.savepointsAtTxnRewindPos = ex.extraTxnState.savepoints.clone()
}
Expand Down
88 changes: 88 additions & 0 deletions pkg/sql/conn_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/tests"
"github.com/cockroachdb/cockroach/pkg/testutils"
"github.com/cockroachdb/cockroach/pkg/testutils/pgtest"
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/testutils/skip"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/cockroachdb/cockroach/pkg/util/ctxgroup"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/cockroachdb/redact"
"github.com/jackc/pgx"
Expand Down Expand Up @@ -829,6 +831,92 @@ func TestTrimFlushedStatements(t *testing.T) {
require.NoError(t, tx.Commit())
}

func TestTrimSuspendedPortals(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

const (
// select generates 10 rows of results which the test retrieves using ExecPortal
selectStmt = "SELECT generate_series(1, 10)"

// stmtBufMaxLen is the maximum length the statement buffer should be during
// execution
stmtBufMaxLen = 2

// The name of the portal, used to get a handle on the statement buffer
portalName = "C_1"
)

ctx := context.Background()

var stmtBuff *sql.StmtBuf
s, _, _ := serverutils.StartServer(t, base.TestServerArgs{
Knobs: base.TestingKnobs{
SQLExecutor: &sql.ExecutorTestingKnobs{
// get a handle to the statement buffer during the Bind phase
AfterExecCmd: func(_ context.Context, cmd sql.Command, buf *sql.StmtBuf) {
switch tcmd := cmd.(type) {
case sql.BindStmt:
if tcmd.PortalName == portalName {
stmtBuff = buf
}
default:
}
},
},
},
Insecure: true,
})
defer s.Stopper().Stop(ctx)

// Connect to the cluster via the PGWire client.
p, err := pgtest.NewPGTest(ctx, s.SQLAddr(), security.RootUser)
require.NoError(t, err)

// setup the portal
require.NoError(t, p.SendOneLine(`Query {"String": "BEGIN"}`))
require.NoError(t, p.SendOneLine(fmt.Sprintf(`Parse {"Query": "%s"}`, selectStmt)))
require.NoError(t, p.SendOneLine(fmt.Sprintf(`Bind {"DestinationPortal": "%s"}`, portalName)))

// wait for ready
until := pgtest.ParseMessages("ReadyForQuery")
_, err = p.Until(false /* keepErrMsg */, until...)
require.NoError(t, err)

// Execute the portal 10 times
for i := 1; i <= 10; i++ {

// Exec the portal
require.NoError(t, p.SendOneLine(fmt.Sprintf(`Execute {"Portal": "%s", "MaxRows": 1}`, portalName)))
require.NoError(t, p.SendOneLine(`Sync`))

// wait for ready
msg, _ := p.Until(false /* keepErrMsg */, until...)

// received messages should include a data row with the correct value
received := pgtest.MsgsToJSONWithIgnore(msg, &datadriven.TestData{})
require.Equal(t, 1, strings.Count(received, fmt.Sprintf(`"Type":"DataRow","Values":[{"text":"%d"}]`, i)))

// assert that the stmtBuff never exceeds the expected size
stmtBufLen := stmtBuff.Len()
if stmtBufLen > stmtBufMaxLen {
t.Fatalf("statement buffer grew to %d (> %d) after %dth execution", stmtBufLen, stmtBufMaxLen, i)
}
}

// explicitly close portal
require.NoError(t, p.SendOneLine(fmt.Sprintf(`Close {"ObjectType": 80,"Name": "%s"}`, portalName)))

// send commit
require.NoError(t, p.SendOneLine(`Query {"String": "COMMIT"}`))

// wait for ready
msg, _ := p.Until(false /* keepErrMsg */, until...)
received := pgtest.MsgsToJSONWithIgnore(msg, &datadriven.TestData{})
require.Equal(t, 1, strings.Count(received, `"Type":"CommandComplete","CommandTag":"COMMIT"`))

}

func TestShowLastQueryStatistics(t *testing.T) {
defer leaktest.AfterTest(t)()

Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/conn_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,11 @@ func (buf *StmtBuf) translatePosLocked(pos CmdPos) (int, error) {
return int(pos - buf.mu.startPos), nil
}

// ltrim iterates over the buffer forward and removes all commands up to
// Ltrim iterates over the buffer forward and removes all commands up to
// (not including) the command at pos.
//
// It's illegal to ltrim to a position higher than the current cursor.
func (buf *StmtBuf) ltrim(ctx context.Context, pos CmdPos) {
// It's illegal to Ltrim to a position higher than the current cursor.
func (buf *StmtBuf) Ltrim(ctx context.Context, pos CmdPos) {
buf.mu.Lock()
defer buf.mu.Unlock()
if pos < buf.mu.startPos {
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/conn_io_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func TestStmtBufLtrim(t *testing.T) {
buf.AdvanceOne()
buf.AdvanceOne()
trimPos := CmdPos(2)
buf.ltrim(ctx, trimPos)
buf.Ltrim(ctx, trimPos)
if l := buf.mu.data.Len(); l != 3 {
t.Fatalf("expected 3 left, got: %d", l)
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/sql/pgwire/command_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,10 @@ func (r *limitedCommandResult) moreResultsNeeded(ctx context.Context) error {
// The client wants to see a ready for query message
// back. Send it then run the for loop again.
r.conn.stmtBuf.AdvanceOne()
// Trim old statements to reclaim memory. We need to perform this clean up
// here as the conn_executor cleanup is not executed because of the
// limitedCommandResult side state machine.
r.conn.stmtBuf.Ltrim(ctx, prevPos)
// We can hard code InTxnBlock here because we don't
// support implicit transactions, so we know we're in
// a transaction.
Expand Down
18 changes: 12 additions & 6 deletions pkg/testutils/pgtest/datadriven.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,23 @@ func RunTest(t *testing.T, path, addr, user string) {
(d.HasArg("noncrdb_only") && p.isCockroachDB) {
return d.Expected
}
until := parseMessages(d.Input)
until := ParseMessages(d.Input)
msgs, err := p.Receive(hasKeepErrMsg(d), until...)
if err != nil {
t.Fatalf("%s: %+v", d.Pos, err)
}
return msgsToJSONWithIgnore(msgs, d)
return MsgsToJSONWithIgnore(msgs, d)
case "until":
if (d.HasArg("crdb_only") && !p.isCockroachDB) ||
(d.HasArg("noncrdb_only") && p.isCockroachDB) {
return d.Expected
}
until := parseMessages(d.Input)
until := ParseMessages(d.Input)
msgs, err := p.Until(hasKeepErrMsg(d), until...)
if err != nil {
t.Fatalf("%s: %+v", d.Pos, err)
}
return msgsToJSONWithIgnore(msgs, d)
return MsgsToJSONWithIgnore(msgs, d)
default:
t.Fatalf("unknown command %s", d.Cmd)
return ""
Expand All @@ -141,7 +141,10 @@ func RunTest(t *testing.T, path, addr, user string) {
}
}

func parseMessages(s string) []pgproto3.BackendMessage {
// ParseMessages parses a string containing multiple pgproto3 messages separated
// by the newline symbol. See testdata for examples ("until" or "receive"
// commands).
func ParseMessages(s string) []pgproto3.BackendMessage {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you will also need to add comments to these functions now that they are exported (in order to pass linting checks)

// ParseMessages parses a string containing multiple pgproto3 messages separated
// by the newline symbol. See testdata for examples ("until" or "receive"
// commands).
// MsgsToJSONWithIgnore converts the pgproto3 messages to JSON format. The
// second argument can specify how to adjust the messages (e.g. to make them
// more deterministic) if needed, see testdata for examples.

var msgs []pgproto3.BackendMessage
for _, typ := range strings.Split(s, "\n") {
msgs = append(msgs, toMessage(typ).(pgproto3.BackendMessage))
Expand All @@ -158,7 +161,10 @@ func hasKeepErrMsg(d *datadriven.TestData) bool {
return false
}

func msgsToJSONWithIgnore(msgs []pgproto3.BackendMessage, args *datadriven.TestData) string {
// MsgsToJSONWithIgnore converts the pgproto3 messages to JSON format. The
// second argument can specify how to adjust the messages (e.g. to make them
// more deterministic) if needed, see testdata for examples.
func MsgsToJSONWithIgnore(msgs []pgproto3.BackendMessage, args *datadriven.TestData) string {
ignore := map[string]bool{}
errs := map[string]string{}
for _, arg := range args.CmdArgs {
Expand Down
25 changes: 25 additions & 0 deletions pkg/testutils/pgtest/pgtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ import (
"bytes"
"context"
"encoding/gob"
"encoding/json"
"fmt"
"net"
"reflect"
"strings"
"testing"

"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -80,6 +82,29 @@ func (p *PGTest) Close() error {
return p.fe.Send(&pgproto3.Terminate{})
}

// SendOneLine sends a single msg to the server represented as a single string
// in the format `<msg type> <msg body in JSON>`. See testdata for examples.
func (p *PGTest) SendOneLine(line string) error {
sp := strings.SplitN(line, " ", 2)
msg := toMessage(sp[0])
if len(sp) == 2 {
msgBytes := []byte(sp[1])
switch msg := msg.(type) {
case *pgproto3.CopyData:
var data struct{ Data string }
if err := json.Unmarshal(msgBytes, &data); err != nil {
return err
}
msg.Data = []byte(data.Data)
default:
if err := json.Unmarshal(msgBytes, msg); err != nil {
return err
}
}
}
return p.Send(msg.(pgproto3.FrontendMessage))
}

// Send sends msg to the serrver.
func (p *PGTest) Send(msg pgproto3.FrontendMessage) error {
if testing.Verbose() {
Expand Down