From df932de684d760fcb48bc1f327c4cbb2d899c8c9 Mon Sep 17 00:00:00 2001 From: Jay Date: Mon, 7 Feb 2022 11:12:16 -0500 Subject: [PATCH] ccl/sqlproxyccl: add basic forwarder component Informs #76000. This commit refactors the ConnectionCopy call in proxy_handler.go into a new forwarder component, which was described in the connection migration RFC. At the moment, this forwarder component does basic forwarding through ConnectionCopy, just like before, so there should be no behavioral changes to the proxy. This will serve as a building block for subsequent PRs. Release note: None --- pkg/ccl/sqlproxyccl/BUILD.bazel | 2 + pkg/ccl/sqlproxyccl/forwarder.go | 107 +++++++++++++++ pkg/ccl/sqlproxyccl/forwarder_test.go | 188 ++++++++++++++++++++++++++ pkg/ccl/sqlproxyccl/proxy_handler.go | 55 ++++---- 4 files changed, 326 insertions(+), 26 deletions(-) create mode 100644 pkg/ccl/sqlproxyccl/forwarder.go create mode 100644 pkg/ccl/sqlproxyccl/forwarder_test.go diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 496d126a74ac..c2f2f81bae29 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -8,6 +8,7 @@ go_library( "backend_dialer.go", "connector.go", "error.go", + "forwarder.go", "frontend_admitter.go", "metrics.go", "proxy.go", @@ -50,6 +51,7 @@ go_test( srcs = [ "authentication_test.go", "connector_test.go", + "forwarder_test.go", "frontend_admitter_test.go", "main_test.go", "proxy_handler_test.go", diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go new file mode 100644 index 000000000000..7b28a529d9e6 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/forwarder.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "net" +) + +// forwarder is used to forward pgwire messages from the client to the server, +// and vice-versa. At the moment, this does direct proxying, and there is +// no intercepting. Once https://github.com/cockroachdb/cockroach/issues/76000 +// has been addressed, we will start intercepting pgwire messages at their +// boundaries here. +// +// The forwarder instance should always be constructed through the forward +// function, which also starts the forwarder. +type forwarder struct { + // ctx is a single context used to control all goroutines spawned by the + // forwarder. + ctx context.Context + ctxCancel context.CancelFunc + + // serverConn is only set after the authentication phase for the initial + // connection. In the context of a connection migration, serverConn is only + // replaced once the session has successfully been deserialized, and the + // old connection will be closed. + clientConn net.Conn // client <-> proxy + serverConn net.Conn // proxy <-> server + + // errChan is a buffered channel that contains the first forwarder error. + // This channel may receive nil errors. + errChan chan error +} + +// forward returns a new instance of forwarder, and starts forwarding messages +// from clientConn to serverConn. When this is called, it is expected that the +// caller passes ownership of serverConn to the forwarder, which implies that +// the forwarder will clean up serverConn. +// +// All goroutines spun up must check on f.ctx to prevent leaks, if possible. If +// there was an error within the goroutines, the forwarder will be closed, and +// the first error can be found in f.errChan. +// +// clientConn and serverConn must not be nil in all cases except testing. 
+// +// Note that callers MUST call Close in all cases, and should not rely on the +// fact that ctx was passed into forward(). It is possible that +// the top-level context was cancelled, but the forwarder has not cleaned up. +func forward(ctx context.Context, clientConn, serverConn net.Conn) *forwarder { + ctx, cancelFn := context.WithCancel(ctx) + + f := &forwarder{ + ctx: ctx, + ctxCancel: cancelFn, + clientConn: clientConn, + serverConn: serverConn, + errChan: make(chan error, 1), + } + + go func() { + // Block until context is done. + <-f.ctx.Done() + + // Close the forwarder to clean up. This goroutine is temporarily here + // because the only way to unblock io.Copy is to close one of the ends, + // which will be done through closing the forwarder. Once we replace + // io.Copy with the interceptors, we could use f.ctx directly, and no + // longer need this goroutine. + // + // Note that if f.Close was called externally, this will result + // in two f.Close calls in total, i.e. one externally, and one here + // once the context gets cancelled. This is fine for now since we'll + // be removing this soon anyway. + f.Close() + }() + + // Copy all pgwire messages from frontend to backend connection until we + // encounter an error or shutdown signal. + go func() { + defer f.Close() + + err := ConnectionCopy(f.serverConn, f.clientConn) + select { + case f.errChan <- err: /* error reported */ + default: /* the channel already contains an error */ + } + }() + + return f +} + +// Close closes the forwarder, and stops the forwarding process. This is +// idempotent. +func (f *forwarder) Close() { + f.ctxCancel() + + // Since Close is idempotent, we'll ignore the error from Close in case it + // has already been closed. 
+ f.serverConn.Close() +} diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go new file mode 100644 index 000000000000..762402873c14 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/forwarder_test.go @@ -0,0 +1,188 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "net" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/require" +) + +func TestForward(t *testing.T) { + defer leaktest.AfterTest(t)() + + bgCtx := context.Background() + + t.Run("closed_when_processors_error", func(t *testing.T) { + p1, p2 := net.Pipe() + // Close the connection right away. p2 is owned by the forwarder. + p1.Close() + + f := forward(bgCtx, p1, p2) + defer f.Close() + + // We have to wait for the goroutine to run. Once the forwarder stops, + // we're good. + testutils.SucceedsSoon(t, func() error { + if f.ctx.Err() != nil { + return nil + } + return errors.New("forwarder is still running") + }) + }) + + t.Run("client_to_server", func(t *testing.T) { + ctx, cancel := context.WithTimeout(bgCtx, 5*time.Second) + defer cancel() + + clientW, clientR := net.Pipe() + serverW, serverR := net.Pipe() + // We don't close clientW and serverR here since we have no control + // over those. serverW is not closed since the forwarder is responsible + // for that. + defer clientR.Close() + + f := forward(ctx, clientR, serverW) + defer f.Close() + require.Nil(t, f.ctx.Err()) + + // Client writes some pgwire messages. 
+ errChan := make(chan error, 1) + go func() { + _, err := clientW.Write((&pgproto3.Query{ + String: "SELECT 1", + }).Encode(nil)) + if err != nil { + errChan <- err + return + } + + if _, err := clientW.Write((&pgproto3.Execute{ + Portal: "foobar", + MaxRows: 42, + }).Encode(nil)); err != nil { + errChan <- err + return + } + + if _, err := clientW.Write((&pgproto3.Close{ + ObjectType: 'P', + }).Encode(nil)); err != nil { + errChan <- err + return + } + }() + + // Server should receive messages in order. + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(serverR), serverR) + + msg, err := backend.Receive() + require.NoError(t, err) + m1, ok := msg.(*pgproto3.Query) + require.True(t, ok) + require.Equal(t, "SELECT 1", m1.String) + + msg, err = backend.Receive() + require.NoError(t, err) + m2, ok := msg.(*pgproto3.Execute) + require.True(t, ok) + require.Equal(t, "foobar", m2.Portal) + require.Equal(t, uint32(42), m2.MaxRows) + + msg, err = backend.Receive() + require.NoError(t, err) + m3, ok := msg.(*pgproto3.Close) + require.True(t, ok) + require.Equal(t, byte('P'), m3.ObjectType) + + select { + case err = <-errChan: + t.Fatalf("require no error, but got %v", err) + default: + } + }) + + t.Run("server_to_client", func(t *testing.T) { + ctx, cancel := context.WithTimeout(bgCtx, 5*time.Second) + defer cancel() + + clientW, clientR := net.Pipe() + serverW, serverR := net.Pipe() + // We don't close clientW and serverR here since we have no control + // over those. serverW is not closed since the forwarder is responsible + // for that. + defer clientR.Close() + + f := forward(ctx, clientR, serverW) + defer f.Close() + require.Nil(t, f.ctx.Err()) + + // Server writes some pgwire messages. 
+ errChan := make(chan error, 1) + go func() { + if _, err := serverR.Write((&pgproto3.ErrorResponse{ + Code: "100", + Message: "foobarbaz", + }).Encode(nil)); err != nil { + errChan <- err + return + } + + if _, err := serverR.Write((&pgproto3.ReadyForQuery{ + TxStatus: 'I', + }).Encode(nil)); err != nil { + errChan <- err + return + } + }() + + // Client should receive messages in order. + frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(clientW), clientW) + + msg, err := frontend.Receive() + require.NoError(t, err) + m1, ok := msg.(*pgproto3.ErrorResponse) + require.True(t, ok) + require.Equal(t, "100", m1.Code) + require.Equal(t, "foobarbaz", m1.Message) + + msg, err = frontend.Receive() + require.NoError(t, err) + m2, ok := msg.(*pgproto3.ReadyForQuery) + require.True(t, ok) + require.Equal(t, byte('I'), m2.TxStatus) + + select { + case err = <-errChan: + t.Fatalf("require no error, but got %v", err) + default: + } + }) +} + +func TestForwarder_Close(t *testing.T) { + defer leaktest.AfterTest(t)() + + p1, p2 := net.Pipe() + defer p1.Close() // p2 is owned by the forwarder. + + f := forward(context.Background(), p1, p2) + defer f.Close() + require.Nil(t, f.ctx.Err()) + + f.Close() + require.EqualError(t, f.ctx.Err(), context.Canceled.Error()) +} diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index f0e039501ac8..7d6305548ef8 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -324,46 +324,49 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn } return err } - defer crdbConn.Close() + var f *forwarder + defer func() { + // Only close crdbConn if the forwarder hasn't been started. When the + // forwarder has been created, crdbConn is owned by the forwarder. 
+ if f == nil { + _ = crdbConn.Close() + } + }() handler.metrics.SuccessfulConnCount.Inc(1) - ctx, cancel := context.WithCancel(ctx) - defer cancel() - log.Infof(ctx, "new connection") connBegin := timeutil.Now() defer func() { log.Infof(ctx, "closing after %.2fs", timeutil.Since(connBegin).Seconds()) }() - // Copy all pgwire messages from frontend to backend connection until we - // encounter an error or shutdown signal. - go func() { - err := ConnectionCopy(crdbConn, conn) - select { - case errConnection <- err: /* error reported */ - default: /* the channel already contains an error */ - } - }() + // Pass ownership of crdbConn to the forwarder. + f = forward(ctx, conn, crdbConn) + defer f.Close() + // Block until an error is received, or when the stopper starts quiescing, + // whichever happens first. + // + // TODO(jaylim-crl): We should handle all these errors properly, and + // propagate them back to the client if we're in a safe position to send. + // This PR https://github.com/cockroachdb/cockroach/pull/66205 removed error + // injections after connection handoff because there was a possibility of + // corrupted packets. + // + // TODO(jaylim-crl): It would be nice to have more consistency in how we + // manage background goroutines, communicate errors, etc. select { - case err := <-errConnection: + case err := <-f.errChan: // From forwarder. + handler.metrics.updateForError(err) + return err + case err := <-errConnection: // From denyListWatcher or idleMonitor. handler.metrics.updateForError(err) return err - case <-ctx.Done(): - err := ctx.Err() - if err != nil { - // The client connection expired. - codeErr := newErrorf( - codeExpiredClientConnection, "expired client conn: %v", err, - ) - handler.metrics.updateForError(codeErr) - return codeErr - } - return nil case <-handler.stopper.ShouldQuiesce(): - return nil + err := context.Canceled + handler.metrics.updateForError(err) + return err } }