diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel
index 496d126a74ac..c2f2f81bae29 100644
--- a/pkg/ccl/sqlproxyccl/BUILD.bazel
+++ b/pkg/ccl/sqlproxyccl/BUILD.bazel
@@ -8,6 +8,7 @@ go_library(
         "backend_dialer.go",
         "connector.go",
         "error.go",
+        "forwarder.go",
         "frontend_admitter.go",
         "metrics.go",
         "proxy.go",
@@ -50,6 +51,7 @@ go_test(
     srcs = [
         "authentication_test.go",
         "connector_test.go",
+        "forwarder_test.go",
         "frontend_admitter_test.go",
         "main_test.go",
         "proxy_handler_test.go",
diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go
new file mode 100644
index 000000000000..7b28a529d9e6
--- /dev/null
+++ b/pkg/ccl/sqlproxyccl/forwarder.go
@@ -0,0 +1,107 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+package sqlproxyccl
+
+import (
+	"context"
+	"net"
+)
+
+// forwarder is used to forward pgwire messages from the client to the server,
+// and vice-versa. At the moment, it proxies the bytes directly, without any
+// interception. Once https://github.com/cockroachdb/cockroach/issues/76000
+// has been addressed, we will start intercepting pgwire messages at their
+// boundaries here.
+//
+// The forwarder instance should always be constructed through the forward
+// function, which also starts the forwarder.
+type forwarder struct {
+	// ctx is a single context used to control all goroutines spawned by the
+	// forwarder.
+	ctx       context.Context
+	ctxCancel context.CancelFunc
+
+	// serverConn is only set after the authentication phase for the initial
+	// connection. In the context of a connection migration, serverConn is only
+	// replaced once the session has successfully been deserialized, and the
+	// old connection will be closed.
+	clientConn net.Conn // client <-> proxy
+	serverConn net.Conn // proxy <-> server
+
+	// errChan is a buffered channel that contains the first forwarder error.
+	// This channel may receive nil errors.
+	errChan chan error
+}
+
+// forward returns a new instance of forwarder, and starts forwarding messages
+// from clientConn to serverConn. When this is called, the caller is expected
+// to pass ownership of serverConn to the forwarder, which implies that the
+// forwarder will clean up serverConn.
+//
+// All goroutines spun up must check f.ctx to prevent leaks, where possible.
+// If an error occurs in any of the goroutines, the forwarder will be closed,
+// and the first error can be found in f.errChan.
+//
+// clientConn and serverConn must not be nil, except in tests.
+//
+// Note that callers MUST call Close in all cases, and should not rely on the
+// fact that ctx was passed into forward(). It is possible for the top-level
+// context to be cancelled before the forwarder has cleaned up.
+func forward(ctx context.Context, clientConn, serverConn net.Conn) *forwarder {
+	ctx, cancelFn := context.WithCancel(ctx)
+
+	f := &forwarder{
+		ctx:        ctx,
+		ctxCancel:  cancelFn,
+		clientConn: clientConn,
+		serverConn: serverConn,
+		errChan:    make(chan error, 1),
+	}
+
+	go func() {
+		// Block until context is done.
+		<-f.ctx.Done()
+
+		// Close the forwarder to clean up. This goroutine is temporarily here
+		// because the only way to unblock io.Copy is to close one of the ends,
+		// which will be done through closing the forwarder. Once we replace
+		// io.Copy with the interceptors, we could use f.ctx directly, and no
+		// longer need this goroutine.
+		//
+		// Note that if f.Close was called externally, this will result
+		// in two f.Close calls in total, i.e. one externally, and one here
+		// once the context gets cancelled. This is fine for now since we'll
+		// be removing this soon anyway.
+		f.Close()
+	}()
+
+	// Copy all pgwire messages from frontend to backend connection until we
+	// encounter an error or shutdown signal.
+	go func() {
+		defer f.Close()
+
+		err := ConnectionCopy(f.serverConn, f.clientConn)
+		select {
+		case f.errChan <- err: /* error reported */
+		default: /* the channel already contains an error */
+		}
+	}()
+
+	return f
+}
+
+// Close closes the forwarder, and stops the forwarding process. This is
+// idempotent.
+func (f *forwarder) Close() {
+	f.ctxCancel()
+
+	// Since Close is idempotent, we'll ignore the error from Close in case
+	// the connection has already been closed.
+	f.serverConn.Close()
+}
diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go
new file mode 100644
index 000000000000..762402873c14
--- /dev/null
+++ b/pkg/ccl/sqlproxyccl/forwarder_test.go
@@ -0,0 +1,188 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+package sqlproxyccl
+
+import (
+	"context"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/errors"
+	"github.com/jackc/pgproto3/v2"
+	"github.com/stretchr/testify/require"
+)
+
+func TestForward(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	bgCtx := context.Background()
+
+	t.Run("closed_when_processors_error", func(t *testing.T) {
+		p1, p2 := net.Pipe()
+		// Close the connection right away. p2 is owned by the forwarder.
+		p1.Close()
+
+		f := forward(bgCtx, p1, p2)
+		defer f.Close()
+
+		// We have to wait for the goroutine to run. Once the forwarder stops,
+		// we're good.
+		testutils.SucceedsSoon(t, func() error {
+			if f.ctx.Err() != nil {
+				return nil
+			}
+			return errors.New("forwarder is still running")
+		})
+	})
+
+	t.Run("client_to_server", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(bgCtx, 5*time.Second)
+		defer cancel()
+
+		clientW, clientR := net.Pipe()
+		serverW, serverR := net.Pipe()
+		// We don't close clientW and serverR here since we have no control
+		// over those. serverW is not closed since the forwarder is responsible
+		// for that.
+		defer clientR.Close()
+
+		f := forward(ctx, clientR, serverW)
+		defer f.Close()
+		require.Nil(t, f.ctx.Err())
+
+		// Client writes some pgwire messages.
+		errChan := make(chan error, 1)
+		go func() {
+			_, err := clientW.Write((&pgproto3.Query{
+				String: "SELECT 1",
+			}).Encode(nil))
+			if err != nil {
+				errChan <- err
+				return
+			}
+
+			if _, err := clientW.Write((&pgproto3.Execute{
+				Portal:  "foobar",
+				MaxRows: 42,
+			}).Encode(nil)); err != nil {
+				errChan <- err
+				return
+			}
+
+			if _, err := clientW.Write((&pgproto3.Close{
+				ObjectType: 'P',
+			}).Encode(nil)); err != nil {
+				errChan <- err
+				return
+			}
+		}()
+
+		// Server should receive messages in order.
+		backend := pgproto3.NewBackend(pgproto3.NewChunkReader(serverR), serverR)
+
+		msg, err := backend.Receive()
+		require.NoError(t, err)
+		m1, ok := msg.(*pgproto3.Query)
+		require.True(t, ok)
+		require.Equal(t, "SELECT 1", m1.String)
+
+		msg, err = backend.Receive()
+		require.NoError(t, err)
+		m2, ok := msg.(*pgproto3.Execute)
+		require.True(t, ok)
+		require.Equal(t, "foobar", m2.Portal)
+		require.Equal(t, uint32(42), m2.MaxRows)
+
+		msg, err = backend.Receive()
+		require.NoError(t, err)
+		m3, ok := msg.(*pgproto3.Close)
+		require.True(t, ok)
+		require.Equal(t, byte('P'), m3.ObjectType)
+
+		select {
+		case err = <-errChan:
+			t.Fatalf("expected no error, but got %v", err)
+		default:
+		}
+	})
+
+	t.Run("server_to_client", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(bgCtx, 5*time.Second)
+		defer cancel()
+
+		clientW, clientR := net.Pipe()
+		serverW, serverR := net.Pipe()
+		// We don't close clientW and serverR here since we have no control
+		// over those. serverW is not closed since the forwarder is responsible
+		// for that.
+		defer clientR.Close()
+
+		f := forward(ctx, clientR, serverW)
+		defer f.Close()
+		require.Nil(t, f.ctx.Err())
+
+		// Server writes some pgwire messages.
+		errChan := make(chan error, 1)
+		go func() {
+			if _, err := serverR.Write((&pgproto3.ErrorResponse{
+				Code:    "100",
+				Message: "foobarbaz",
+			}).Encode(nil)); err != nil {
+				errChan <- err
+				return
+			}
+
+			if _, err := serverR.Write((&pgproto3.ReadyForQuery{
+				TxStatus: 'I',
+			}).Encode(nil)); err != nil {
+				errChan <- err
+				return
+			}
+		}()
+
+		// Client should receive messages in order.
+		frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(clientW), clientW)
+
+		msg, err := frontend.Receive()
+		require.NoError(t, err)
+		m1, ok := msg.(*pgproto3.ErrorResponse)
+		require.True(t, ok)
+		require.Equal(t, "100", m1.Code)
+		require.Equal(t, "foobarbaz", m1.Message)
+
+		msg, err = frontend.Receive()
+		require.NoError(t, err)
+		m2, ok := msg.(*pgproto3.ReadyForQuery)
+		require.True(t, ok)
+		require.Equal(t, byte('I'), m2.TxStatus)
+
+		select {
+		case err = <-errChan:
+			t.Fatalf("expected no error, but got %v", err)
+		default:
+		}
+	})
+}
+
+func TestForwarder_Close(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	p1, p2 := net.Pipe()
+	defer p1.Close() // p2 is owned by the forwarder.
+
+	f := forward(context.Background(), p1, p2)
+	defer f.Close()
+	require.Nil(t, f.ctx.Err())
+
+	f.Close()
+	require.EqualError(t, f.ctx.Err(), context.Canceled.Error())
+}
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go
index f0e039501ac8..7d6305548ef8 100644
--- a/pkg/ccl/sqlproxyccl/proxy_handler.go
+++ b/pkg/ccl/sqlproxyccl/proxy_handler.go
@@ -324,46 +324,49 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 		}
 		return err
 	}
-	defer crdbConn.Close()
+	var f *forwarder
+	defer func() {
+		// Only close crdbConn if the forwarder hasn't been started. When the
+		// forwarder has been created, crdbConn is owned by the forwarder.
+		if f == nil {
+			_ = crdbConn.Close()
+		}
+	}()
 
 	handler.metrics.SuccessfulConnCount.Inc(1)
 
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
 	log.Infof(ctx, "new connection")
 	connBegin := timeutil.Now()
 	defer func() {
 		log.Infof(ctx, "closing after %.2fs", timeutil.Since(connBegin).Seconds())
 	}()
 
-	// Copy all pgwire messages from frontend to backend connection until we
-	// encounter an error or shutdown signal.
-	go func() {
-		err := ConnectionCopy(crdbConn, conn)
-		select {
-		case errConnection <- err: /* error reported */
-		default: /* the channel already contains an error */
-		}
-	}()
+	// Pass ownership of crdbConn to the forwarder.
+	f = forward(ctx, conn, crdbConn)
+	defer f.Close()
 
+	// Block until an error is received, or until the stopper starts
+	// quiescing, whichever happens first.
+	//
+	// TODO(jaylim-crl): We should handle all these errors properly, and
+	// propagate them back to the client if we're in a safe position to send.
+	// This PR https://github.com/cockroachdb/cockroach/pull/66205 removed error
+	// injections after connection handoff because there was a possibility of
+	// corrupted packets.
+	//
+	// TODO(jaylim-crl): It would be nice to have more consistency in how we
+	// manage background goroutines, communicate errors, etc.
 	select {
-	case err := <-errConnection:
+	case err := <-f.errChan: // From forwarder.
+		handler.metrics.updateForError(err)
+		return err
+	case err := <-errConnection: // From denyListWatcher or idleMonitor.
 		handler.metrics.updateForError(err)
 		return err
-	case <-ctx.Done():
-		err := ctx.Err()
-		if err != nil {
-			// The client connection expired.
-			codeErr := newErrorf(
-				codeExpiredClientConnection, "expired client conn: %v", err,
-			)
-			handler.metrics.updateForError(codeErr)
-			return codeErr
-		}
-		return nil
 	case <-handler.stopper.ShouldQuiesce():
-		return nil
+		err := context.Canceled
+		handler.metrics.updateForError(err)
+		return err
 	}
 }
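For reviewers, the sketch below illustrates the calling pattern the new forwarder API expects: construct it via forward (which transfers ownership of the server connection), always defer Close, and block on errChan for the first error. This is a minimal, hypothetical example and not code from this patch; handleConn, the hard-coded address, and the direct net.Dial are assumptions for illustration, and the snippet is written as if it lived inside package sqlproxyccl since forward is unexported.

```go
package sqlproxyccl

import (
	"context"
	"net"
)

// handleConn is a hypothetical caller (not part of this patch) that mirrors
// how proxy_handler.go drives the forwarder.
func handleConn(ctx context.Context, clientConn net.Conn) error {
	// Assumed SQL pod address; the real proxy dials through the connector.
	serverConn, err := net.Dial("tcp", "127.0.0.1:26257")
	if err != nil {
		return err
	}

	// Ownership of serverConn moves to the forwarder, which will close it.
	f := forward(ctx, clientConn, serverConn)
	// Close must always be called, even if ctx gets cancelled first.
	defer f.Close()

	// Block on the first forwarder error (which may be nil), or on the
	// caller's context, whichever happens first.
	select {
	case err := <-f.errChan:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
```

Note the design choice this relies on: errChan is buffered with capacity one and written with a non-blocking send, so only the first error is retained and later goroutine failures are dropped rather than blocking shutdown.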