From d700758e12d468435ae57305a8a8596b48ba84e0 Mon Sep 17 00:00:00 2001 From: Jay Date: Tue, 15 Feb 2022 23:41:52 -0500 Subject: [PATCH] ccl/sqlproxyccl: replace ConnectionCopy logic in forwarder with interceptors Informs #76000. Previously, we were using io.Copy through ConnectionCopy to forward messages between the client and SQL server. Now that the interceptors have been merged, we will update the forwarder to use these interceptors instead of the old approach. There are a few notable changes in this commit: 1. We wrap clientConn with the a readTimeoutConn that was exposed in the previous commit. This allows us to unblock on Read whenever an activity occurs (e.g. context cancellation, and in the future, when transfer is requested). 2. There are two goroutines per connection within the forwarder: one for the request processor (client to server), and the other for the response processor (server to client). 3. We also removed unnecessary checks for codeExpiredClientConnection and codeIdleDisconnect errors when copying from server to client. These are already handled by the idle monitor's callback as well as the denylist watcher's callback. When those constructs were implemented back then, we did not remove them from ConnectionCopy. We need to clean up error messages, and how we propagate them back to the user because today we just close the connection without returning a response, resulting in, I believe, a broken pipe error. Release note: None --- pkg/ccl/sqlproxyccl/BUILD.bazel | 2 + pkg/ccl/sqlproxyccl/forwarder.go | 165 ++++++++++++++++++++------ pkg/ccl/sqlproxyccl/forwarder_test.go | 113 ++++++++++++++++++ pkg/ccl/sqlproxyccl/proxy.go | 43 ------- pkg/ccl/sqlproxyccl/proxy_handler.go | 6 +- 5 files changed, 251 insertions(+), 78 deletions(-) diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index c2f2f81bae29..8eec3dd6ba84 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -21,10 +21,12 @@ go_library( deps = [ "//pkg/ccl/sqlproxyccl/denylist", "//pkg/ccl/sqlproxyccl/idle", + "//pkg/ccl/sqlproxyccl/interceptor", "//pkg/ccl/sqlproxyccl/tenant", "//pkg/ccl/sqlproxyccl/throttler", "//pkg/roachpb", "//pkg/security/certmgr", + "//pkg/sql/pgwire", "//pkg/sql/pgwire/pgcode", "//pkg/util/contextutil", "//pkg/util/grpcutil", diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go index 7b28a529d9e6..9810ef695a7f 100644 --- a/pkg/ccl/sqlproxyccl/forwarder.go +++ b/pkg/ccl/sqlproxyccl/forwarder.go @@ -11,6 +11,10 @@ package sqlproxyccl import ( "context" "net" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire" + "github.com/cockroachdb/errors" ) // forwarder is used to forward pgwire messages from the client to the server, @@ -31,9 +35,28 @@ type forwarder struct { // connection. In the context of a connection migration, serverConn is only // replaced once the session has successfully been deserialized, and the // old connection will be closed. + // + // All reads from these connections must go through the interceptors. It is + // not safe to read from these directly as the interceptors may have + // buffered data. clientConn net.Conn // client <-> proxy serverConn net.Conn // proxy <-> server + // clientInterceptor and serverInterceptor provides a convenient way to + // read and forward Postgres messages, while minimizing IO reads and memory + // allocations. + // + // These interceptors have to match clientConn and serverConn. See comment + // above on when those fields will be updated. + // + // TODO(jaylim-crl): Add updater functions that sets both conn and + // interceptor fields at the same time. At the moment, there's no use case + // besides the forward function. When connection migration happens, we + // will need to create a new serverInterceptor. We should remember to close + // old serverConn as well. + clientInterceptor *interceptor.BackendInterceptor // clientConn -> serverConn + serverInterceptor *interceptor.FrontendInterceptor // serverConn -> clientConn + // errChan is a buffered channel that contains the first forwarder error. // This channel may receive nil errors. errChan chan error @@ -42,51 +65,52 @@ type forwarder struct { // forward returns a new instance of forwarder, and starts forwarding messages // from clientConn to serverConn. When this is called, it is expected that the // caller passes ownership of serverConn to the forwarder, which implies that -// the forwarder will clean up serverConn. -// -// All goroutines spun up must check on f.ctx to prevent leaks, if possible. If -// there was an error within the goroutines, the forwarder will be closed, and -// the first error can be found in f.errChan. +// the forwarder will clean up serverConn. clientConn and serverConn must not +// be nil in all cases except for testing. // -// clientConn and serverConn must not be nil in all cases except testing. -// -// Note that callers MUST call Close in all cases, and should not rely on the -// fact that ctx was passed into forward(). There could be a possibility where -// the top-level context was cancelled, but the forwarder has not cleaned up. +// Note that callers MUST call Close in all cases, even if ctx was cancelled. func forward(ctx context.Context, clientConn, serverConn net.Conn) *forwarder { ctx, cancelFn := context.WithCancel(ctx) f := &forwarder{ - ctx: ctx, - ctxCancel: cancelFn, - clientConn: clientConn, - serverConn: serverConn, - errChan: make(chan error, 1), + ctx: ctx, + ctxCancel: cancelFn, + errChan: make(chan error, 1), } + // The net.Conn object for the client is switched to a net.Conn that + // unblocks Read every second on idle to check for exit conditions. + // This is mainly used to unblock the request processor whenever the + // forwarder has stopped, or a transfer has been requested. + clientConn = pgwire.NewReadTimeoutConn(clientConn, func() error { + // Context was cancelled. + if f.ctx.Err() != nil { + return f.ctx.Err() + } + // TODO(jaylim-crl): Check for transfer state here. + return nil + }) + + f.setClientConn(clientConn) + f.setServerConn(serverConn) + + // Start request (client to server) and response (server to client) + // processors. We will copy all pgwire messages/ from client to server + // (and vice-versa) until we encounter an error or a shutdown signal + // (i.e. context cancellation). go func() { - // Block until context is done. - <-f.ctx.Done() - - // Close the forwarder to clean up. This goroutine is temporarily here - // because the only way to unblock io.Copy is to close one of the ends, - // which will be done through closing the forwarder. Once we replace - // io.Copy with the interceptors, we could use f.ctx directly, and no - // longer need this goroutine. - // - // Note that if f.Close was called externally, this will result - // in two f.Close calls in total, i.e. one externally, and one here - // once the context gets cancelled. This is fine for now since we'll - // be removing this soon anyway. - f.Close() - }() + defer f.Close() - // Copy all pgwire messages from frontend to backend connection until we - // encounter an error or shutdown signal. + err := wrapClientToServerError(f.handleClientToServer()) + select { + case f.errChan <- err: /* error reported */ + default: /* the channel already contains an error */ + } + }() go func() { defer f.Close() - err := ConnectionCopy(f.serverConn, f.clientConn) + err := wrapServerToClientError(f.handleServerToClient()) select { case f.errChan <- err: /* error reported */ default: /* the channel already contains an error */ @@ -105,3 +129,78 @@ func (f *forwarder) Close() { // has already been closed. f.serverConn.Close() } + +// handleClientToServer handles the communication from the client to the server. +// This returns a context cancellation error whenever the forwarder's context +// is cancelled, or whenever forwarding fails. When ForwardMsg gets blocked on +// Read, we will unblock that through our custom readTimeoutConn wrapper, which +// gets triggered when context is cancelled. +func (f *forwarder) handleClientToServer() error { + for f.ctx.Err() == nil { + if _, err := f.clientInterceptor.ForwardMsg(f.serverConn); err != nil { + return err + } + } + return f.ctx.Err() +} + +// handleServerToClient handles the communication from the server to the client. +// This returns a context cancellation error whenever the forwarder's context +// is cancelled, or whenever forwarding fails. When ForwardMsg gets blocked on +// Read, we will unblock that by closing serverConn through f.Close(). +func (f *forwarder) handleServerToClient() error { + for f.ctx.Err() == nil { + if _, err := f.serverInterceptor.ForwardMsg(f.clientConn); err != nil { + return err + } + } + return f.ctx.Err() +} + +// setClientConn is a convenient helper to update clientConn, and will also +// create a matching interceptor for the given connection. It is the caller's +// responsibility to close the old connection before calling this, or there +// may be a leak. +func (f *forwarder) setClientConn(clientConn net.Conn) { + f.clientConn = clientConn + f.clientInterceptor = interceptor.NewBackendInterceptor(f.clientConn) +} + +// setServerConn is a convenient helper to update serverConn, and will also +// create a matching interceptor for the given connection. It is the caller's +// responsibility to close the old connection before calling this, or there +// may be a leak. +func (f *forwarder) setServerConn(serverConn net.Conn) { + f.serverConn = serverConn + f.serverInterceptor = interceptor.NewFrontendInterceptor(f.serverConn) +} + +// wrapClientToServerError overrides client to server errors for external +// consumption. +// +// TODO(jaylim-crl): We don't send any of these to the client today, +// unfortunately. At the moment, this is only used for metrics. See TODO in +// proxy_handler about sending safely to avoid corrupted packets. Handle these +// errors in a friendly manner. +func wrapClientToServerError(err error) error { + if err == nil || + errors.IsAny(err, context.Canceled, context.DeadlineExceeded) { + return nil + } + return newErrorf(codeClientDisconnected, "copying from client to target server: %v", err) +} + +// wrapServerToClientError overrides server to client errors for external +// consumption. +// +// TODO(jaylim-crl): We don't send any of these to the client today, +// unfortunately. At the moment, this is only used for metrics. See TODO in +// proxy_handler about sending safely to avoid corrupted packets. Handle these +// errors in a friendly manner. +func wrapServerToClientError(err error) error { + if err == nil || + errors.IsAny(err, context.Canceled, context.DeadlineExceeded) { + return nil + } + return newErrorf(codeBackendDisconnected, "copying from target server to client: %s", err) +} diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go index 762402873c14..0acdf12e5c00 100644 --- a/pkg/ccl/sqlproxyccl/forwarder_test.go +++ b/pkg/ccl/sqlproxyccl/forwarder_test.go @@ -9,6 +9,7 @@ package sqlproxyccl import ( + "bytes" "context" "net" "testing" @@ -186,3 +187,115 @@ func TestForwarder_Close(t *testing.T) { f.Close() require.EqualError(t, f.ctx.Err(), context.Canceled.Error()) } + +func TestForwarder_setClientConn(t *testing.T) { + defer leaktest.AfterTest(t)() + f := &forwarder{serverConn: nil, serverInterceptor: nil} + + w, r := net.Pipe() + defer w.Close() + defer r.Close() + + f.setClientConn(r) + require.Equal(t, r, f.clientConn) + + dst := new(bytes.Buffer) + errChan := make(chan error, 1) + go func() { + _, err := f.clientInterceptor.ForwardMsg(dst) + errChan <- err + }() + + _, err := w.Write((&pgproto3.Query{String: "SELECT 1"}).Encode(nil)) + require.NoError(t, err) + + // Block until message has been forwarded. This checks that we are creating + // our interceptor properly. + err = <-errChan + require.NoError(t, err) + require.Equal(t, 14, dst.Len()) +} + +func TestForwarder_setServerConn(t *testing.T) { + defer leaktest.AfterTest(t)() + f := &forwarder{serverConn: nil, serverInterceptor: nil} + + w, r := net.Pipe() + defer w.Close() + defer r.Close() + + f.setServerConn(r) + require.Equal(t, r, f.serverConn) + + dst := new(bytes.Buffer) + errChan := make(chan error, 1) + go func() { + _, err := f.serverInterceptor.ForwardMsg(dst) + errChan <- err + }() + + _, err := w.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + require.NoError(t, err) + + // Block until message has been forwarded. This checks that we are creating + // our interceptor properly. + err = <-errChan + require.NoError(t, err) + require.Equal(t, 6, dst.Len()) +} + +func TestWrapClientToServerError(t *testing.T) { + defer leaktest.AfterTest(t)() + + for _, tc := range []struct { + input error + output error + }{ + // Nil errors. + {nil, nil}, + {context.Canceled, nil}, + {context.DeadlineExceeded, nil}, + {errors.Mark(errors.New("foo"), context.Canceled), nil}, + {errors.Wrap(context.DeadlineExceeded, "foo"), nil}, + // Forwarding errors. + {errors.New("foo"), newErrorf( + codeClientDisconnected, + "copying from client to target server: foo", + )}, + } { + err := wrapClientToServerError(tc.input) + if tc.output == nil { + require.NoError(t, err) + } else { + require.EqualError(t, err, tc.output.Error()) + } + } +} + +func TestWrapServerToClientError(t *testing.T) { + defer leaktest.AfterTest(t)() + + for _, tc := range []struct { + input error + output error + }{ + // Nil errors. + {nil, nil}, + {context.Canceled, nil}, + {context.DeadlineExceeded, nil}, + {errors.Mark(errors.New("foo"), context.Canceled), nil}, + {errors.Wrap(context.DeadlineExceeded, "foo"), nil}, + // Forwarding errors. + {errors.New("foo"), newErrorf( + codeBackendDisconnected, + "copying from target server to client: foo", + )}, + } { + err := wrapServerToClientError(tc.input) + if tc.output == nil { + require.NoError(t, err) + } else { + require.EqualError(t, err, tc.output.Error()) + } + } +} diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index e9fdf2dfda8c..ad77ca2ba7d9 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -9,7 +9,6 @@ package sqlproxyccl import ( - "io" "net" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" @@ -80,45 +79,3 @@ var SendErrToClient = func(conn net.Conn, err error) { } _, _ = conn.Write(toPgError(err).Encode(nil)) } - -// ConnectionCopy does a bi-directional copy between the backend and frontend -// connections. It terminates when one of connections terminate. -func ConnectionCopy(crdbConn, conn net.Conn) error { - errOutgoing := make(chan error, 1) - errIncoming := make(chan error, 1) - - go func() { - _, err := io.Copy(crdbConn, conn) - errOutgoing <- err - }() - go func() { - _, err := io.Copy(conn, crdbConn) - errIncoming <- err - }() - - select { - // NB: when using pgx, we see a nil errIncoming first on clean connection - // termination. Using psql I see a nil errOutgoing first. I think the PG - // protocol stipulates sending a message to the server at which point the - // server closes the connection (errIncoming), but presumably the client - // gets to close the connection once it's sent that message, meaning either - // case is possible. - case err := <-errIncoming: - if err == nil { - return nil - } else if codeErr := (*codeError)(nil); errors.As(err, &codeErr) && - codeErr.code == codeExpiredClientConnection { - return codeErr - } else if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() { - return newErrorf(codeIdleDisconnect, "terminating connection due to idle timeout: %v", err) - } else { - return newErrorf(codeBackendDisconnected, "copying from target server to client: %s", err) - } - case err := <-errOutgoing: - // The incoming connection got closed. - if err != nil { - return newErrorf(codeClientDisconnected, "copying from target server to client: %v", err) - } - return nil - } -} diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 7d6305548ef8..3219931f3612 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -326,8 +326,10 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn } var f *forwarder defer func() { - // Only close crdbConn if the forwarder hasn't been started. When the - // forwarder has been created, crdbConn is owned by the forwarder. + // Only close crdbConn if the forwarder hasn't been started. We want + // this here to ensure that we close the idle monitor wrapped + // connection. If the forwarder has been created, crdbConn is owned by + // the forwarder. if f == nil { _ = crdbConn.Close() }