diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index c2f2f81bae29..8eec3dd6ba84 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -21,10 +21,12 @@ go_library( deps = [ "//pkg/ccl/sqlproxyccl/denylist", "//pkg/ccl/sqlproxyccl/idle", + "//pkg/ccl/sqlproxyccl/interceptor", "//pkg/ccl/sqlproxyccl/tenant", "//pkg/ccl/sqlproxyccl/throttler", "//pkg/roachpb", "//pkg/security/certmgr", + "//pkg/sql/pgwire", "//pkg/sql/pgwire/pgcode", "//pkg/util/contextutil", "//pkg/util/grpcutil", diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go index 7b28a529d9e6..9810ef695a7f 100644 --- a/pkg/ccl/sqlproxyccl/forwarder.go +++ b/pkg/ccl/sqlproxyccl/forwarder.go @@ -11,6 +11,10 @@ package sqlproxyccl import ( "context" "net" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire" + "github.com/cockroachdb/errors" ) // forwarder is used to forward pgwire messages from the client to the server, @@ -31,9 +35,28 @@ type forwarder struct { // connection. In the context of a connection migration, serverConn is only // replaced once the session has successfully been deserialized, and the // old connection will be closed. + // + // All reads from these connections must go through the interceptors. It is + // not safe to read from these directly as the interceptors may have + // buffered data. clientConn net.Conn // client <-> proxy serverConn net.Conn // proxy <-> server + // clientInterceptor and serverInterceptor provide a convenient way to + // read and forward Postgres messages, while minimizing IO reads and memory + // allocations. + // + // These interceptors have to match clientConn and serverConn. See comment + // above on when those fields will be updated. + // + // setClientConn and setServerConn set a conn and its matching interceptor + // at the same time; 
at the moment, only the forward function uses them. + // + // TODO(jaylim-crl): When connection migration happens, we will need to + // create a new serverInterceptor, and remember to close old serverConn. + clientInterceptor *interceptor.BackendInterceptor // clientConn -> serverConn + serverInterceptor *interceptor.FrontendInterceptor // serverConn -> clientConn + // errChan is a buffered channel that contains the first forwarder error. // This channel may receive nil errors. errChan chan error @@ -42,51 +65,52 @@ type forwarder struct { // forward returns a new instance of forwarder, and starts forwarding messages // from clientConn to serverConn. When this is called, it is expected that the // caller passes ownership of serverConn to the forwarder, which implies that -// the forwarder will clean up serverConn. -// -// All goroutines spun up must check on f.ctx to prevent leaks, if possible. If -// there was an error within the goroutines, the forwarder will be closed, and -// the first error can be found in f.errChan. +// the forwarder will clean up serverConn. clientConn and serverConn must not +// be nil in all cases except for testing. // -// clientConn and serverConn must not be nil in all cases except testing. -// -// Note that callers MUST call Close in all cases, and should not rely on the -// fact that ctx was passed into forward(). There could be a possibility where -// the top-level context was cancelled, but the forwarder has not cleaned up. +// Note that callers MUST call Close in all cases, even if ctx was cancelled. 
func forward(ctx context.Context, clientConn, serverConn net.Conn) *forwarder { ctx, cancelFn := context.WithCancel(ctx) f := &forwarder{ - ctx: ctx, - ctxCancel: cancelFn, - clientConn: clientConn, - serverConn: serverConn, - errChan: make(chan error, 1), + ctx: ctx, + ctxCancel: cancelFn, + errChan: make(chan error, 1), } + // The net.Conn object for the client is switched to a net.Conn that + // unblocks Read every second on idle to check for exit conditions. + // This is mainly used to unblock the request processor whenever the + // forwarder has stopped, or a transfer has been requested. + clientConn = pgwire.NewReadTimeoutConn(clientConn, func() error { + // Context was cancelled. + if f.ctx.Err() != nil { + return f.ctx.Err() + } + // TODO(jaylim-crl): Check for transfer state here. + return nil + }) + + f.setClientConn(clientConn) + f.setServerConn(serverConn) + + // Start request (client to server) and response (server to client) + // processors. We will copy all pgwire messages from client to server + // (and vice-versa) until we encounter an error or a shutdown signal + // (i.e. context cancellation). go func() { - // Block until context is done. - <-f.ctx.Done() - - // Close the forwarder to clean up. This goroutine is temporarily here - // because the only way to unblock io.Copy is to close one of the ends, - // which will be done through closing the forwarder. Once we replace - // io.Copy with the interceptors, we could use f.ctx directly, and no - // longer need this goroutine. - // - // Note that if f.Close was called externally, this will result - // in two f.Close calls in total, i.e. one externally, and one here - // once the context gets cancelled. This is fine for now since we'll - // be removing this soon anyway. - f.Close() - }() + defer f.Close() - // Copy all pgwire messages from frontend to backend connection until we - // encounter an error or shutdown signal. 
+		err := wrapClientToServerError(f.handleClientToServer())
+		select {
+		case f.errChan <- err: /* error reported */
+		default: /* the channel already contains an error */
+		}
+	}()
 	go func() {
 		defer f.Close()
 
-		err := ConnectionCopy(f.serverConn, f.clientConn)
+		err := wrapServerToClientError(f.handleServerToClient())
 		select {
 		case f.errChan <- err: /* error reported */
 		default: /* the channel already contains an error */
@@ -105,3 +129,88 @@ func (f *forwarder) Close() {
 	// has already been closed.
 	f.serverConn.Close()
 }
+
+// handleClientToServer handles the communication from the client to the server.
+// It forwards messages until forwarding fails, in which case the forwarding
+// error is returned, or until the forwarder's context is cancelled, in which
+// case the context's error is returned. When ForwardMsg gets blocked on Read,
+// we will unblock that through our custom readTimeoutConn wrapper, which
+// checks for context cancellation whenever Read is idle.
+func (f *forwarder) handleClientToServer() error {
+	for f.ctx.Err() == nil {
+		if _, err := f.clientInterceptor.ForwardMsg(f.serverConn); err != nil {
+			return err
+		}
+	}
+	return f.ctx.Err()
+}
+
+// handleServerToClient handles the communication from the server to the client.
+// It forwards messages until forwarding fails, in which case the forwarding
+// error is returned, or until the forwarder's context is cancelled, in which
+// case the context's error is returned. When ForwardMsg gets blocked on Read,
+// we will unblock that by closing serverConn through f.Close().
+func (f *forwarder) handleServerToClient() error {
+	for f.ctx.Err() == nil {
+		if _, err := f.serverInterceptor.ForwardMsg(f.clientConn); err != nil {
+			return err
+		}
+	}
+	return f.ctx.Err()
+}
+
+// setClientConn is a convenient helper to update clientConn, and will also
+// create a matching interceptor for the given connection. It is the caller's
+// responsibility to close the old connection before calling this, or there
+// may be a leak.
+func (f *forwarder) setClientConn(clientConn net.Conn) {
+	f.clientConn = clientConn
+	f.clientInterceptor = interceptor.NewBackendInterceptor(f.clientConn)
+}
+
+// setServerConn is a convenient helper to update serverConn, and will also
+// create a matching interceptor for the given connection. It is the caller's
+// responsibility to close the old connection before calling this, or there
+// may be a leak.
+func (f *forwarder) setServerConn(serverConn net.Conn) {
+	f.serverConn = serverConn
+	f.serverInterceptor = interceptor.NewFrontendInterceptor(f.serverConn)
+}
+
+// wrapClientToServerError overrides client to server errors for external
+// consumption.
+//
+// TODO(jaylim-crl): We don't send any of these to the client today,
+// unfortunately. At the moment, this is only used for metrics. See TODO in
+// proxy_handler about sending safely to avoid corrupted packets. Handle these
+// errors in a friendly manner.
+func wrapClientToServerError(err error) error {
+	if err == nil ||
+		errors.IsAny(err, context.Canceled, context.DeadlineExceeded) {
+		return nil
+	}
+	return newErrorf(codeClientDisconnected, "copying from client to target server: %v", err)
+}
+
+// wrapServerToClientError overrides server to client errors for external
+// consumption. Expired client connections and idle timeouts keep their
+// dedicated error codes so that they are not counted as backend disconnects.
+//
+// TODO(jaylim-crl): We don't send any of these to the client today,
+// unfortunately. At the moment, this is only used for metrics. See TODO in
+// proxy_handler about sending safely to avoid corrupted packets. Handle these
+// errors in a friendly manner.
+func wrapServerToClientError(err error) error {
+	if err == nil ||
+		errors.IsAny(err, context.Canceled, context.DeadlineExceeded) {
+		return nil
+	}
+	if codeErr := (*codeError)(nil); errors.As(err, &codeErr) &&
+		codeErr.code == codeExpiredClientConnection {
+		return codeErr
+	}
+	if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() {
+		return newErrorf(codeIdleDisconnect, "terminating connection due to idle timeout: %v", err)
+	}
+	return newErrorf(codeBackendDisconnected, "copying from target server to client: %s", err)
+}
diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go
index 762402873c14..0acdf12e5c00 100644
--- a/pkg/ccl/sqlproxyccl/forwarder_test.go
+++ b/pkg/ccl/sqlproxyccl/forwarder_test.go
@@ -9,6 +9,7 @@
 package sqlproxyccl
 
 import (
+	"bytes"
 	"context"
 	"net"
 	"testing"
@@ -186,3 +187,115 @@ func TestForwarder_Close(t *testing.T) {
 	f.Close()
 	require.EqualError(t, f.ctx.Err(), context.Canceled.Error())
 }
+
+func TestForwarder_setClientConn(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	f := &forwarder{serverConn: nil, serverInterceptor: nil}
+
+	w, r := net.Pipe()
+	defer w.Close()
+	defer r.Close()
+
+	f.setClientConn(r)
+	require.Equal(t, r, f.clientConn)
+
+	dst := new(bytes.Buffer)
+	errChan := make(chan error, 1)
+	go func() {
+		_, err := f.clientInterceptor.ForwardMsg(dst)
+		errChan <- err
+	}()
+
+	_, err := w.Write((&pgproto3.Query{String: "SELECT 1"}).Encode(nil))
+	require.NoError(t, err)
+
+	// Block until message has been forwarded. This checks that we are creating
+	// our interceptor properly.
+	err = <-errChan
+	require.NoError(t, err)
+	require.Equal(t, 14, dst.Len())
+}
+
+func TestForwarder_setServerConn(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	f := &forwarder{serverConn: nil, serverInterceptor: nil}
+
+	w, r := net.Pipe()
+	defer w.Close()
+	defer r.Close()
+
+	f.setServerConn(r)
+	require.Equal(t, r, f.serverConn)
+
+	dst := new(bytes.Buffer)
+	errChan := make(chan error, 1)
+	go func() {
+		_, err := f.serverInterceptor.ForwardMsg(dst)
+		errChan <- err
+	}()
+
+	_, err := w.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
+	require.NoError(t, err)
+
+	// Block until message has been forwarded. 
This checks that we are creating + // our interceptor properly. + err = <-errChan + require.NoError(t, err) + require.Equal(t, 6, dst.Len()) +} + +func TestWrapClientToServerError(t *testing.T) { + defer leaktest.AfterTest(t)() + + for _, tc := range []struct { + input error + output error + }{ + // Nil errors. + {nil, nil}, + {context.Canceled, nil}, + {context.DeadlineExceeded, nil}, + {errors.Mark(errors.New("foo"), context.Canceled), nil}, + {errors.Wrap(context.DeadlineExceeded, "foo"), nil}, + // Forwarding errors. + {errors.New("foo"), newErrorf( + codeClientDisconnected, + "copying from client to target server: foo", + )}, + } { + err := wrapClientToServerError(tc.input) + if tc.output == nil { + require.NoError(t, err) + } else { + require.EqualError(t, err, tc.output.Error()) + } + } +} + +func TestWrapServerToClientError(t *testing.T) { + defer leaktest.AfterTest(t)() + + for _, tc := range []struct { + input error + output error + }{ + // Nil errors. + {nil, nil}, + {context.Canceled, nil}, + {context.DeadlineExceeded, nil}, + {errors.Mark(errors.New("foo"), context.Canceled), nil}, + {errors.Wrap(context.DeadlineExceeded, "foo"), nil}, + // Forwarding errors. + {errors.New("foo"), newErrorf( + codeBackendDisconnected, + "copying from target server to client: foo", + )}, + } { + err := wrapServerToClientError(tc.input) + if tc.output == nil { + require.NoError(t, err) + } else { + require.EqualError(t, err, tc.output.Error()) + } + } +} diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index e9fdf2dfda8c..ad77ca2ba7d9 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -9,7 +9,6 @@ package sqlproxyccl import ( - "io" "net" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" @@ -80,45 +79,3 @@ var SendErrToClient = func(conn net.Conn, err error) { } _, _ = conn.Write(toPgError(err).Encode(nil)) } - -// ConnectionCopy does a bi-directional copy between the backend and frontend -// connections. 
It terminates when one of connections terminate. -func ConnectionCopy(crdbConn, conn net.Conn) error { - errOutgoing := make(chan error, 1) - errIncoming := make(chan error, 1) - - go func() { - _, err := io.Copy(crdbConn, conn) - errOutgoing <- err - }() - go func() { - _, err := io.Copy(conn, crdbConn) - errIncoming <- err - }() - - select { - // NB: when using pgx, we see a nil errIncoming first on clean connection - // termination. Using psql I see a nil errOutgoing first. I think the PG - // protocol stipulates sending a message to the server at which point the - // server closes the connection (errIncoming), but presumably the client - // gets to close the connection once it's sent that message, meaning either - // case is possible. - case err := <-errIncoming: - if err == nil { - return nil - } else if codeErr := (*codeError)(nil); errors.As(err, &codeErr) && - codeErr.code == codeExpiredClientConnection { - return codeErr - } else if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() { - return newErrorf(codeIdleDisconnect, "terminating connection due to idle timeout: %v", err) - } else { - return newErrorf(codeBackendDisconnected, "copying from target server to client: %s", err) - } - case err := <-errOutgoing: - // The incoming connection got closed. - if err != nil { - return newErrorf(codeClientDisconnected, "copying from target server to client: %v", err) - } - return nil - } -} diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 7d6305548ef8..3219931f3612 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -326,8 +326,10 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn } var f *forwarder defer func() { - // Only close crdbConn if the forwarder hasn't been started. When the - // forwarder has been created, crdbConn is owned by the forwarder. + // Only close crdbConn if the forwarder hasn't been started. 
We want + // this here to ensure that we close the idle monitor wrapped + // connection. If the forwarder has been created, crdbConn is owned by + // the forwarder. if f == nil { _ = crdbConn.Close() }