From fddfd308ec3106e1be705dbd26c99e38fb6efa26 Mon Sep 17 00:00:00 2001 From: Jay Date: Thu, 10 Mar 2022 11:03:55 -0500 Subject: [PATCH] ccl/sqlproxyccl: complete connection migration support in the forwarder Informs cockroachdb#76000. This commit completes the connection migration feature in the the forwarder within sqlproxy. The idea is as described in the RFC. A couple of new sqlproxy metrics have been added as well: - proxy.conn_migration.success - proxy.conn_migration.error_fatal - proxy.conn_migration.error_recoverable - proxy.conn_migration.attempted For more details, see metrics.go in the sqlproxyccl package. Release justification: This completes the first half of the connection migration feature. This is low risk as part of the code is guarded behind the connection migration feature, which is currently not being used in production. To add on, CockroachCloud is the only user of sqlproxy. Release note: None --- pkg/ccl/sqlproxyccl/BUILD.bazel | 4 + pkg/ccl/sqlproxyccl/conn_migration.go | 241 +++++++++++ pkg/ccl/sqlproxyccl/conn_migration_test.go | 5 + pkg/ccl/sqlproxyccl/forwarder.go | 109 +++-- pkg/ccl/sqlproxyccl/forwarder_test.go | 10 +- pkg/ccl/sqlproxyccl/interceptor/pg_conn.go | 14 + .../sqlproxyccl/interceptor/pg_conn_test.go | 10 + pkg/ccl/sqlproxyccl/metrics.go | 58 ++- pkg/ccl/sqlproxyccl/proxy_handler.go | 14 + pkg/ccl/sqlproxyccl/proxy_handler_test.go | 373 ++++++++++++++++++ 10 files changed, 787 insertions(+), 51 deletions(-) diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 905d0e160461..a59a6e0df844 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -39,6 +39,7 @@ go_library( "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", + "//pkg/util/uuid", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_logtags//:logtags", "@com_github_jackc_pgproto3_v2//:pgproto3", @@ -78,6 +79,8 @@ go_test( "//pkg/sql", "//pkg/sql/pgwire", "//pkg/sql/pgwire/pgerror", + "//pkg/sql/pgwire/pgwirebase", + "//pkg/sql/tests", "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/skip", @@ -90,6 +93,7 @@ go_test( "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", + "@com_github_cockroachdb_cockroach_go_v2//crdb", "@com_github_cockroachdb_errors//:errors", "@com_github_jackc_pgconn//:pgconn", "@com_github_jackc_pgproto3_v2//:pgproto3", diff --git a/pkg/ccl/sqlproxyccl/conn_migration.go b/pkg/ccl/sqlproxyccl/conn_migration.go index 1858690741cc..dcc3ae776e61 100644 --- a/pkg/ccl/sqlproxyccl/conn_migration.go +++ b/pkg/ccl/sqlproxyccl/conn_migration.go @@ -13,13 +13,254 @@ import ( "encoding/json" "fmt" "io" + "time" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" pgproto3 "github.com/jackc/pgproto3/v2" ) +// defaultTransferTimeout corresponds to the timeout period for the connection +// migration process. If the timeout gets triggered, and we're in a non +// recoverable state, the connection will be closed. +// +// This is a variable instead of a constant to support testing hooks. +var defaultTransferTimeout = 15 * time.Second + +var errTransferCannotStart = errors.New("transfer cannot be started") + +// tryBeginTransfer returns true if the transfer can be started, and false +// otherwise. If the transfer can be started, it updates the state of the +// forwarder to indicate that a transfer is in progress. +func (f *forwarder) tryBeginTransfer() bool { + f.mu.Lock() + defer f.mu.Unlock() + + // Transfer is already in progress. No concurrent transfers are allowed. + if f.mu.isTransferring { + return false + } + + if !isSafeTransferPoint(f.mu.request, f.mu.response) { + return false + } + + f.mu.isTransferring = true + return true +} + +type transferContext struct { + context.Context + mu struct { + syncutil.Mutex + recoverableConn bool + } +} + +func newTransferContext(backgroundCtx context.Context) (*transferContext, context.CancelFunc) { + transferCtx, cancel := context.WithTimeout(backgroundCtx, defaultTransferTimeout) // nolint:context + ctx := &transferContext{ + Context: transferCtx, + } + ctx.mu.recoverableConn = true + return ctx, cancel +} + +func (t *transferContext) markRecoverable(r bool) { + t.mu.Lock() + defer t.mu.Unlock() + t.mu.recoverableConn = r +} + +func (t *transferContext) isRecoverable() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.mu.recoverableConn +} + +func (f *forwarder) runTransfer() (retErr error) { + // A previous non-recoverable transfer would have closed the forwarder, so + // return right away. + if f.ctx.Err() != nil { + return f.ctx.Err() + } + + if !f.tryBeginTransfer() { + return errTransferCannotStart + } + defer func() { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.isTransferring = false + }() + + f.metrics.ConnMigrationAttemptedCount.Inc(1) + + // Create a transfer context, and timeout handler which gets triggered + // whenever the context expires. We have to close the forwarder because + // the transfer may be blocked on I/O, and the only way for now is to close + // the connections. This then allow runTransfer to return and cleanup. + ctx, cancel := newTransferContext(f.ctx) + defer cancel() + + go func() { + <-ctx.Done() + if !ctx.isRecoverable() { + f.Close() + } + }() + + // Use a separate context for logging because f.ctx will be closed whenever + // the connection is non-recoverable. + logCtx := logtags.WithTags(context.Background(), logtags.FromContext(f.ctx)) + defer func() { + if !ctx.isRecoverable() { + log.Infof(logCtx, "transfer failed: connection closed, err=%v", retErr) + f.metrics.ConnMigrationErrorFatalCount.Inc(1) + } else { + // Transfer was successful. + if retErr == nil { + log.Infof(logCtx, "transfer successful") + f.metrics.ConnMigrationSuccessCount.Inc(1) + } else { + log.Infof(logCtx, "transfer failed: connection recovered, err=%v", retErr) + f.metrics.ConnMigrationErrorRecoverableCount.Inc(1) + } + f.resumeProcessors() + } + }() + + f.mu.Lock() + defer f.mu.Unlock() + + // Suspend both processors before starting the transfer. + f.mu.request.suspend(ctx) + f.mu.response.suspend(ctx) + + // Transfer the connection. + newServerConn, err := transferConnection(ctx, f.connector, f.mu.clientConn, f.mu.serverConn) + if err != nil { + return errors.Wrap(err, "transferring connection") + } + + // Transfer was successful. + clockFn := makeLogicalClockFn() + f.mu.serverConn.Close() + f.mu.serverConn = newServerConn + f.mu.request = newProcessor(clockFn, f.mu.clientConn, f.mu.serverConn) + f.mu.response = newProcessor(clockFn, f.mu.serverConn, f.mu.clientConn) + return nil +} + +// transferConnection performs the transfer operation for the current server +// connection, and returns the a new connection to the server that the +// connection got transferred to. +func transferConnection( + ctx *transferContext, connector *connector, clientConn, serverConn *interceptor.PGConn, +) (_ *interceptor.PGConn, retErr error) { + ctx.markRecoverable(true) + + // Context was cancelled. + if ctx.Err() != nil { + return nil, ctx.Err() + } + + transferKey := uuid.MakeV4().String() + + // Send the SHOW TRANSFER STATE statement. At this point, connection is + // non-recoverable because the message has already been sent to the server. + ctx.markRecoverable(false) + if err := runShowTransferState(serverConn, transferKey); err != nil { + return nil, errors.Wrap(err, "sending transfer request") + } + + transferErr, state, revivalToken, err := waitForShowTransferState( + ctx, serverConn.ToFrontendConn(), clientConn, transferKey) + if err != nil { + return nil, errors.Wrap(err, "waiting for transfer state") + } + + // Failures after this point are recoverable, and connections should not be + // terminated. + ctx.markRecoverable(true) + + // If we consumed until ReadyForQuery without errors, but the transfer state + // response returns an error, we could still resume the connection, but the + // transfer process will need to be aborted. + // + // This case may happen pretty frequently (e.g. open transactions, temporary + // tables, etc.). + if transferErr != "" { + return nil, errors.Newf("%s", transferErr) + } + + // Connect to a new SQL pod. + // + // TODO(jaylim-crl): There is a possibility where the same pod will get + // selected. Some ideas to solve this: pass in the remote address of + // serverConn to avoid choosing that pod, or maybe a filter callback? + // We can also consider adding a target pod as an argument to RequestTransfer. + // That way a central component gets to choose where the connections go. + netConn, err := connector.OpenTenantConnWithToken(ctx, revivalToken) + if err != nil { + return nil, errors.Wrap(err, "opening connection") + } + defer func() { + if retErr != nil { + netConn.Close() + } + }() + newServerConn := interceptor.NewPGConn(netConn) + + // Deserialize session state within the new SQL pod. + if err := runAndWaitForDeserializeSession( + ctx, newServerConn.ToFrontendConn(), state, + ); err != nil { + return nil, errors.Wrap(err, "deserializing session") + } + + return newServerConn, nil +} + +// isSafeTransferPoint returns true if we're at a point where we're safe to +// transfer, and false otherwise. +var isSafeTransferPoint = func(request *processor, response *processor) bool { + request.mu.Lock() + response.mu.Lock() + defer request.mu.Unlock() + defer response.mu.Unlock() + + // Three conditions when evaluating a safe transfer point: + // 1. The last message sent to the SQL pod was a Sync(S) or SimpleQuery(Q), + // and a ReadyForQuery(Z) has been received after. + // 2. The last message sent to the SQL pod was a CopyDone(c), and a + // ReadyForQuery(Z) has been received after. + // 3. The last message sent to the SQL pod was a CopyFail(f), and a + // ReadyForQuery(Z) has been received after. + + // The conditions above are not possible if this is true. They cannot be + // equal since the same logical clock is used. + if request.mu.lastMessageTransferredAt > response.mu.lastMessageTransferredAt { + return false + } + + switch pgwirebase.ClientMessageType(request.mu.lastMessageType) { + case pgwirebase.ClientMessageType(0), + pgwirebase.ClientMsgSync, + pgwirebase.ClientMsgSimpleQuery, + pgwirebase.ClientMsgCopyDone, + pgwirebase.ClientMsgCopyFail: + return pgwirebase.ServerMessageType(response.mu.lastMessageType) == pgwirebase.ServerMsgReady + default: + return false + } +} + // runShowTransferState sends a SHOW TRANSFER STATE query with the input // transferKey to the given writer. The transferKey will be used to uniquely // identify the request when parsing the response messages in diff --git a/pkg/ccl/sqlproxyccl/conn_migration_test.go b/pkg/ccl/sqlproxyccl/conn_migration_test.go index d8c1d70da4a7..699244dd243f 100644 --- a/pkg/ccl/sqlproxyccl/conn_migration_test.go +++ b/pkg/ccl/sqlproxyccl/conn_migration_test.go @@ -23,6 +23,11 @@ import ( "github.com/stretchr/testify/require" ) +func TestIsSafeTransferPoint(t *testing.T) { + defer leaktest.AfterTest(t)() + // TODO(jaylim-crl): Tests. +} + func TestRunShowTransferState(t *testing.T) { defer leaktest.AfterTest(t)() diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go index 2e47523b65e7..c69f0663f01e 100644 --- a/pkg/ccl/sqlproxyccl/forwarder.go +++ b/pkg/ccl/sqlproxyccl/forwarder.go @@ -42,34 +42,48 @@ type forwarder struct { // the same as the metrics field in the proxyHandler instance. metrics *metrics - // clientConn and serverConn provide a convenient way to read and forward - // Postgres messages, while minimizing IO reads and memory allocations. - // - // clientConn is set once during initialization, and stays the same - // throughout the lifetime of the forwarder. - // - // serverConn is set during initialization, which happens after the - // authentication phase, and will be replaced if a connection migration - // occurs. During a connection migration, serverConn is only replaced once - // the session has successfully been deserialized, and the old connection - // will be closed. - // - // All reads from these connections must go through the PG interceptors. - // It is not safe to call Read directly as the interceptors may have - // buffered data. - clientConn *interceptor.PGConn // client <-> proxy - serverConn *interceptor.PGConn // proxy <-> server - - // request and response both represent the processors used to handle - // client-to-server and server-to-client messages. - request *processor // client -> server - response *processor // server -> client - // errCh is a buffered channel that contains the first forwarder error. // This channel may receive nil errors. When an error is written to this // channel, it is guaranteed that the forwarder and all connections will // be closed. errCh chan error + + // While not all of these fields may need to be guarded by a mutex, we do + // so for consistency. Fields like clientConn and serverConn need them + // because Close can be invoked anytime from a different goroutine while + // the connection migration is in progress. On the other hand, the processor + // fields will only be updated during connection migration, and we can + // guarantee that processors will be suspended, so we don't need mutexes + // for them. + mu struct { + syncutil.Mutex + + // isTransferring indicates that a connection migration is in progress. + isTransferring bool + + // clientConn and serverConn provide a convenient way to read and forward + // Postgres messages, while minimizing IO reads and memory allocations. + // + // clientConn is set once during initialization, and stays the same + // throughout the lifetime of the forwarder. + // + // serverConn is set during initialization, which happens after the + // authentication phase, and will be replaced if a connection migration + // occurs. During a connection migration, serverConn is only replaced once + // the session has successfully been deserialized, and the old connection + // will be closed. + // + // All reads from these connections must go through the PG interceptors. + // It is not safe to call Read directly as the interceptors may have + // buffered data. + clientConn *interceptor.PGConn // client <-> proxy + serverConn *interceptor.PGConn // proxy <-> server + + // request and response both represent the processors used to handle + // client-to-server and server-to-client messages. + request *processor // client -> server + response *processor // server -> client + } } // forward returns a new instance of forwarder, and starts forwarding messages @@ -89,17 +103,18 @@ func forward( ) (*forwarder, error) { ctx, cancelFn := context.WithCancel(ctx) f := &forwarder{ - ctx: ctx, - ctxCancel: cancelFn, - errCh: make(chan error, 1), - connector: connector, - metrics: metrics, - clientConn: interceptor.NewPGConn(clientConn), - serverConn: interceptor.NewPGConn(serverConn), + ctx: ctx, + ctxCancel: cancelFn, + errCh: make(chan error, 1), + connector: connector, + metrics: metrics, } + f.mu.clientConn = interceptor.NewPGConn(clientConn) + f.mu.serverConn = interceptor.NewPGConn(serverConn) + clockFn := makeLogicalClockFn() - f.request = newProcessor(clockFn, f.clientConn, f.serverConn) // client -> server - f.response = newProcessor(clockFn, f.serverConn, f.clientConn) // server -> client + f.mu.request = newProcessor(clockFn, f.mu.clientConn, f.mu.serverConn) // client -> server + f.mu.response = newProcessor(clockFn, f.mu.serverConn, f.mu.clientConn) // server -> client if err := f.resumeProcessors(); err != nil { return nil, err } @@ -124,8 +139,18 @@ func (f *forwarder) Close() { // Since Close is idempotent, we'll ignore the error from Close calls in // case they have already been closed. - f.clientConn.Close() - f.serverConn.Close() + f.mu.Lock() + defer f.mu.Unlock() + f.mu.clientConn.Close() + f.mu.serverConn.Close() +} + +// RequestTransfer requests that the forwarder performs a best-effort connection +// migration whenever it can. It is best-effort because this will be a no-op if +// the forwarder is not in a state that is eligible for a connection migration. +// If a transfer is already in progress, or has been requested, this is a no-op. +func (f *forwarder) RequestTransfer() { + go f.runTransfer() } // resumeProcessors starts both the request and response processors @@ -133,20 +158,25 @@ func (f *forwarder) Close() { // return an error while resuming. This is idempotent as resume() will return // nil if the processor has already been started. func (f *forwarder) resumeProcessors() error { + f.mu.Lock() + defer f.mu.Unlock() + + requestProc := f.mu.request + responseProc := f.mu.response go func() { - if err := f.request.resume(f.ctx); err != nil { + if err := requestProc.resume(f.ctx); err != nil { f.tryReportError(wrapClientToServerError(err)) } }() go func() { - if err := f.response.resume(f.ctx); err != nil { + if err := responseProc.resume(f.ctx); err != nil { f.tryReportError(wrapServerToClientError(err)) } }() - if err := f.request.waitResumed(f.ctx); err != nil { + if err := requestProc.waitResumed(f.ctx); err != nil { return err } - if err := f.response.waitResumed(f.ctx); err != nil { + if err := responseProc.waitResumed(f.ctx); err != nil { return err } return nil @@ -200,7 +230,8 @@ func wrapServerToClientError(err error) error { // This implementation could overflow in theory, but it doesn't matter for the // forwarder since the worst that could happen is that we are unable to transfer // for an extremely short period of time until all the processors have wrapped -// around. That said, this situation is rare since uint64 is a huge number. +// around. That said, this situation is rare since uint64 is a huge number, and +// we restart the clock on each transfer. func makeLogicalClockFn() func() uint64 { var counter uint64 return func() uint64 { diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go index c8ab9ec444cb..eb340a77cb04 100644 --- a/pkg/ccl/sqlproxyccl/forwarder_test.go +++ b/pkg/ccl/sqlproxyccl/forwarder_test.go @@ -66,7 +66,10 @@ func TestForward(t *testing.T) { require.NoError(t, err) require.Nil(t, f.ctx.Err()) - requestProc := f.request + f.mu.Lock() + requestProc := f.mu.request + f.mu.Unlock() + initialClock := requestProc.logicalClockFn() barrier := make(chan struct{}) requestProc.testingKnobs.beforeForwardMsg = func() { @@ -162,7 +165,10 @@ func TestForward(t *testing.T) { require.NoError(t, err) require.Nil(t, f.ctx.Err()) - responseProc := f.response + f.mu.Lock() + responseProc := f.mu.response + f.mu.Unlock() + initialClock := responseProc.logicalClockFn() barrier := make(chan struct{}) responseProc.testingKnobs.beforeForwardMsg = func() { diff --git a/pkg/ccl/sqlproxyccl/interceptor/pg_conn.go b/pkg/ccl/sqlproxyccl/interceptor/pg_conn.go index 1155e93dcda6..dad8b542a068 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/pg_conn.go +++ b/pkg/ccl/sqlproxyccl/interceptor/pg_conn.go @@ -24,3 +24,17 @@ func NewPGConn(conn net.Conn) *PGConn { pgInterceptor: newPgInterceptor(conn, defaultBufferSize), } } + +// ToFrontendConn converts a PGConn to a FrontendConn. Callers should be aware +// of the underlying type of net.Conn before calling this, or else there will be +// an error during parsing. +func (c *PGConn) ToFrontendConn() *FrontendConn { + return &FrontendConn{Conn: c.Conn, interceptor: c.pgInterceptor} +} + +// ToBackendConn converts a PGConn to a BackendConn. Callers should be aware +// of the underlying type of net.Conn before calling this, or else there will be +// an error during parsing. +func (c *PGConn) ToBackendConn() *BackendConn { + return &BackendConn{Conn: c.Conn, interceptor: c.pgInterceptor} +} diff --git a/pkg/ccl/sqlproxyccl/interceptor/pg_conn_test.go b/pkg/ccl/sqlproxyccl/interceptor/pg_conn_test.go index 8e0105733ddf..ab1754b94bca 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/pg_conn_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/pg_conn_test.go @@ -70,3 +70,13 @@ func TestPGConn(t *testing.T) { require.Nil(t, err) }) } + +func TestPGConn_ToFrontendConn(t *testing.T) { + defer leaktest.AfterTest(t)() + // TODO(jaylim-crl): Tests. +} + +func TestPGConn_ToBackendConn(t *testing.T) { + defer leaktest.AfterTest(t)() + // TODO(jaylim-crl): Tests. +} diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index 7b5afdf2c012..dd12e852001a 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -15,16 +15,20 @@ import ( // metrics contains pointers to the metrics for monitoring proxy operations. type metrics struct { - BackendDisconnectCount *metric.Counter - IdleDisconnectCount *metric.Counter - BackendDownCount *metric.Counter - ClientDisconnectCount *metric.Counter - CurConnCount *metric.Gauge - RoutingErrCount *metric.Counter - RefusedConnCount *metric.Counter - SuccessfulConnCount *metric.Counter - AuthFailedCount *metric.Counter - ExpiredClientConnCount *metric.Counter + BackendDisconnectCount *metric.Counter + IdleDisconnectCount *metric.Counter + BackendDownCount *metric.Counter + ClientDisconnectCount *metric.Counter + CurConnCount *metric.Gauge + RoutingErrCount *metric.Counter + RefusedConnCount *metric.Counter + SuccessfulConnCount *metric.Counter + AuthFailedCount *metric.Counter + ExpiredClientConnCount *metric.Counter + ConnMigrationSuccessCount *metric.Counter + ConnMigrationErrorFatalCount *metric.Counter + ConnMigrationErrorRecoverableCount *metric.Counter + ConnMigrationAttemptedCount *metric.Counter } // MetricStruct implements the metrics.Struct interface. @@ -93,6 +97,35 @@ var ( Measurement: "Expired Client Connections", Unit: metric.Unit_COUNT, } + // Connection migration metrics. + // + // attempted = success + error_fatal + error_recoverable + metaConnMigrationAttemptedCount = metric.Metadata{ + Name: "proxy.conn_migration.attempted", + Help: "Number of attempted connection migrations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationSuccessCount = metric.Metadata{ + Name: "proxy.conn_migration.success", + Help: "Number of successful connection migrations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationErrorFatalCount = metric.Metadata{ + // When connection migrations errored out, connections will be closed. + Name: "proxy.conn_migration.error_fatal", + Help: "Number of failed connection migrations which resulted in terminations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationErrorRecoverableCount = metric.Metadata{ + // Connections are recoverable, so they won't be closed. + Name: "proxy.conn_migration.error_recoverable", + Help: "Number of failed connection migrations that were recoverable", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } ) // makeProxyMetrics instantiates the metrics holder for proxy monitoring. @@ -108,6 +141,11 @@ func makeProxyMetrics() metrics { SuccessfulConnCount: metric.NewCounter(metaSuccessfulConnCount), AuthFailedCount: metric.NewCounter(metaAuthFailedCount), ExpiredClientConnCount: metric.NewCounter(metaExpiredClientConnCount), + // Connection migration metrics. + ConnMigrationSuccessCount: metric.NewCounter(metaConnMigrationSuccessCount), + ConnMigrationErrorFatalCount: metric.NewCounter(metaConnMigrationErrorFatalCount), + ConnMigrationErrorRecoverableCount: metric.NewCounter(metaConnMigrationErrorRecoverableCount), + ConnMigrationAttemptedCount: metric.NewCounter(metaConnMigrationAttemptedCount), } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 7ed39617724a..a84012588f4d 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -96,6 +96,11 @@ type ProxyOptions struct { // ThrottleBaseDelay is the initial exponential backoff triggered in // response to the first connection failure. ThrottleBaseDelay time.Duration + + // Used for testing. + testingKnobs struct { + afterForward func(*forwarder) error + } } // proxyHandler is the default implementation of a proxy handler. @@ -341,6 +346,15 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn return err } + if handler.testingKnobs.afterForward != nil { + if err := handler.testingKnobs.afterForward(f); err != nil { + select { + case errConnection <- err: /* error reported */ + default: /* the channel already contains an error */ + } + } + } + // Block until an error is received, or when the stopper starts quiescing, // whichever that happens first. // diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index a439d5321177..eabf4d39c508 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -11,6 +11,7 @@ package sqlproxyccl import ( "context" "crypto/tls" + gosql "database/sql" "fmt" "io/ioutil" "net" @@ -20,6 +21,7 @@ import ( "testing" "time" + "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/denylist" @@ -30,6 +32,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/pgwire" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/tests" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" @@ -755,6 +758,376 @@ func TestDirectoryConnect(t *testing.T) { }) } +func TestConnectionMigration(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + defer log.Scope(t).Close(t) + + params, _ := tests.CreateTestServerParams() + s, mainDB, _ := serverutils.StartServer(t, params) + defer s.Stopper().Stop(ctx) + tenantID := serverutils.TestTenantID() + + // TODO(rafi): use ALTER TENANT ALL when available. + _, err := mainDB.Exec(`INSERT INTO system.tenant_settings (tenant_id, name, value, value_type) VALUES + (0, 'server.user_login.session_revival_token.enabled', 'true', 'b')`) + require.NoError(t, err) + + // Start first SQL pod. + tenant1, tenantDB1 := serverutils.StartTenant(t, s, tests.CreateTestTenantParams(tenantID)) + tenant1.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + defer tenant1.Stopper().Stop(ctx) + defer tenantDB1.Close() + + // Start second SQL pod. + params2 := tests.CreateTestTenantParams(tenantID) + params2.Existing = true + tenant2, tenantDB2 := serverutils.StartTenant(t, s, params2) + tenant2.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + defer tenant2.Stopper().Stop(ctx) + defer tenantDB2.Close() + + _, err = tenantDB1.Exec("CREATE USER testuser WITH PASSWORD 'hunter2'") + require.NoError(t, err) + _, err = tenantDB1.Exec("GRANT admin TO testuser") + require.NoError(t, err) + + // Create a proxy server without using a directory. The directory is very + // difficult to work with, and there isn't a way to easily stub out fake + // loads. For this test, we will stub out lookupAddr in the connector. We + // will alternate between tenant1 and tenant2, starting with tenant1. + forwarderCh := make(chan *forwarder) + opts := &ProxyOptions{SkipVerify: true, RoutingRule: tenant1.SQLAddr()} + opts.testingKnobs.afterForward = func(f *forwarder) error { + select { + case forwarderCh <- f: + case <-time.After(10 * time.Second): + return errors.New("no receivers for forwarder") + } + return nil + } + _, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + + // The tenant ID does not matter here since we stubbed RoutingRule. + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-28", addr) + + type queryer interface { + QueryRowContext(context.Context, string, ...interface{}) *gosql.Row + } + // queryAddr queries the SQL node that `db` is connected to for its address. + queryAddr := func(t *testing.T, ctx context.Context, db queryer) string { + t.Helper() + var host, port string + require.NoError(t, db.QueryRowContext(ctx, ` + SELECT + a.value AS "host", b.value AS "port" + FROM crdb_internal.node_runtime_info a, crdb_internal.node_runtime_info b + WHERE a.component = 'DB' AND a.field = 'Host' + AND b.component = 'DB' AND b.field = 'Port' + `).Scan(&host, &port)) + return fmt.Sprintf("%s:%s", host, port) + } + + // Test that connection transfers are successful. Note that if one sub-test + // fails, the remaining will fail as well since they all use the same + // forwarder instance. + t.Run("successful", func(t *testing.T) { + tCtx, cancel := context.WithCancel(ctx) + defer cancel() + + db, err := gosql.Open("postgres", connectionString) + db.SetMaxOpenConns(1) + defer db.Close() + require.NoError(t, err) + + // Spin up a goroutine to trigger the initial connection. + go func() { + _ = db.PingContext(tCtx) + }() + + var f *forwarder + select { + case f = <-forwarderCh: + case <-time.After(10 * time.Second): + t.Fatal("no connection") + } + + // Set up forwarder hooks. + prevTenant1 := true + var lookupAddrDelayDuration time.Duration + f.connector.testingKnobs.lookupAddr = func(ctx context.Context) (string, error) { + if lookupAddrDelayDuration != 0 { + select { + case <-ctx.Done(): + return "", errors.Wrap(ctx.Err(), "injected delays") + case <-time.After(lookupAddrDelayDuration): + } + } + if prevTenant1 { + prevTenant1 = false + return tenant2.SQLAddr(), nil + } + prevTenant1 = true + return tenant1.SQLAddr(), nil + } + + t.Run("normal_transfer", func(t *testing.T) { + require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db)) + + _, err = db.Exec("SET application_name = 'foo'") + require.NoError(t, err) + + // Show that we get alternating SQL pods when we transfer. + f.RequestTransfer() + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationSuccessCount.Count() == 1 + }, 20*time.Second, 25*time.Millisecond) + require.Equal(t, tenant2.SQLAddr(), queryAddr(t, tCtx, db)) + + var name string + require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name)) + require.Equal(t, "foo", name) + + _, err = db.Exec("SET application_name = 'bar'") + require.NoError(t, err) + + f.RequestTransfer() + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationSuccessCount.Count() == 2 + }, 20*time.Second, 25*time.Millisecond) + require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db)) + + require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name)) + require.Equal(t, "bar", name) + + // Now attempt a transfer concurrently with requests. + closerCh := make(chan struct{}) + go func() { + for i := 0; i < 10 && tCtx.Err() == nil; i++ { + f.RequestTransfer() + time.Sleep(500 * time.Millisecond) + } + closerCh <- struct{}{} + }() + + // This test runs for 5 seconds. + var tenant1Addr, tenant2Addr int + for i := 0; i < 100; i++ { + addr := queryAddr(t, tCtx, db) + if addr == tenant1.SQLAddr() { + tenant1Addr++ + } else { + require.Equal(t, tenant2.SQLAddr(), addr) + tenant2Addr++ + } + time.Sleep(50 * time.Millisecond) + } + + // In 5s, we should have at least 10 successful transfers. Just do + // an approximation here. + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationSuccessCount.Count() >= 5 + }, 20*time.Second, 25*time.Millisecond) + require.True(t, tenant1Addr > 2) + require.True(t, tenant2Addr > 2) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorRecoverableCount.Count()) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorFatalCount.Count()) + + // Ensure that the goroutine terminates so other subtests are not + // affected. + <-closerCh + + // There's a chance that we still have an in-progress transfer, so + // attempt to wait. + require.Eventually(t, func() bool { + f.mu.Lock() + defer f.mu.Unlock() + return !f.mu.isTransferring + }, 10*time.Second, 25*time.Millisecond) + + require.Equal(t, f.metrics.ConnMigrationAttemptedCount.Count(), + f.metrics.ConnMigrationSuccessCount.Count(), + ) + }) + + // Transfers should fail if there is an open transaction. These failed + // transfers should not close the connection. + t.Run("failed_transfers_with_tx", func(t *testing.T) { + initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count() + initAddr := queryAddr(t, tCtx, db) + + err = crdb.ExecuteTx(tCtx, db, nil /* txopts */, func(tx *gosql.Tx) error { + for i := 0; i < 10; i++ { + f.RequestTransfer() + addr := queryAddr(t, tCtx, tx) + if initAddr != addr { + return errors.Newf( + "address does not match, expected %s, found %s", + initAddr, + addr, + ) + } + time.Sleep(50 * time.Millisecond) + } + return nil + }) + require.NoError(t, err) + + // Make sure there are no pending transfers. + func() { + f.mu.Lock() + defer f.mu.Unlock() + require.False(t, f.mu.isTransferring) + }() + + // Just check that we have half of what we requested since we cannot + // guarantee that the transfer will run within 50ms. + require.True(t, f.metrics.ConnMigrationErrorRecoverableCount.Count() >= 5) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorFatalCount.Count()) + require.Equal(t, initSuccessCount, f.metrics.ConnMigrationSuccessCount.Count()) + prevErrorRecoverableCount := f.metrics.ConnMigrationErrorRecoverableCount.Count() + + // Once the transaction is closed, transfers should work. + f.RequestTransfer() + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationSuccessCount.Count() == initSuccessCount+1 + }, 20*time.Second, 25*time.Millisecond) + require.NotEqual(t, initAddr, queryAddr(t, tCtx, db)) + require.Equal(t, prevErrorRecoverableCount, f.metrics.ConnMigrationErrorRecoverableCount.Count()) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorFatalCount.Count()) + + // We have already asserted metrics above, so transfer must have + // been completed. + f.mu.Lock() + defer f.mu.Unlock() + require.False(t, f.mu.isTransferring) + }) + + // Transfer timeout caused by dial issues should not close the session. + // We will test this by introducing delays when connecting to the SQL + // pod. + t.Run("failed_transfers_with_dial_issues", func(t *testing.T) { + initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count() + initErrorRecoverableCount := f.metrics.ConnMigrationErrorRecoverableCount.Count() + initAddr := queryAddr(t, tCtx, db) + + // Set the delay longer than the timeout. + lookupAddrDelayDuration = 10 * time.Second + defer testutils.TestingHook(&defaultTransferTimeout, 3*time.Second)() + + f.RequestTransfer() + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationErrorRecoverableCount.Count() == initErrorRecoverableCount+1 + }, 20*time.Second, 25*time.Millisecond) + require.Equal(t, initAddr, queryAddr(t, tCtx, db)) + require.Equal(t, initSuccessCount, f.metrics.ConnMigrationSuccessCount.Count()) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorFatalCount.Count()) + + // We have already asserted metrics above, so transfer must have + // been completed. + f.mu.Lock() + defer f.mu.Unlock() + require.False(t, f.mu.isTransferring) + }) + }) + + // Test transfer timeouts caused by waiting for a transfer state response. + // In reality, this can only be caused by pipelined queries. Consider the + // folllowing: + // 1. short-running simple query + // 2. long-running simple query + // 3. SHOW TRANSFER STATE + // When (1) returns a response, the forwarder will see that we're in a + // safe transfer point, and initiate (3). But (2) may block until we hit + // a timeout. + // + // There's no easy way to simulate pipelined queries. pgtest (that allows + // us to send individual pgwire messages) does not support authentication, + // which is what the proxy needs, so we will stub isSafeTransferPoint + // instead. + t.Run("transfer_timeout_in_response", func(t *testing.T) { + tCtx, cancel := context.WithCancel(ctx) + defer cancel() + + db, err := gosql.Open("postgres", connectionString) + db.SetMaxOpenConns(1) + defer db.Close() + require.NoError(t, err) + + // Use a single connection so that we don't reopen when the connection + // is closed. + conn, err := db.Conn(tCtx) + require.NoError(t, err) + + // Spin up a goroutine to trigger the initial connection. + go func() { + _ = conn.PingContext(tCtx) + }() + + var f *forwarder + select { + case f = <-forwarderCh: + case <-time.After(10 * time.Second): + t.Fatal("no connection") + } + + // Set up forwarder hooks. + prevTenant1 := true + f.connector.testingKnobs.lookupAddr = func(ctx context.Context) (string, error) { + if prevTenant1 { + prevTenant1 = false + return tenant2.SQLAddr(), nil + } + prevTenant1 = true + return tenant1.SQLAddr(), nil + } + defer testutils.TestingHook(&isSafeTransferPoint, func(req *processor, res *processor) bool { + return true + })() + // Transfer timeout is 3s, and we'll run pg_sleep for 10s. + defer testutils.TestingHook(&defaultTransferTimeout, 3*time.Second)() + + goCh := make(chan struct{}, 1) + errCh := make(chan error, 1) + go func() { + goCh <- struct{}{} + _, err = conn.ExecContext(tCtx, "SELECT pg_sleep(10)") + errCh <- err + }() + + // Block until goroutine is started. We want to make sure we run the + // transfer request *after* sending the query. This doesn't guarantee, + // but is the best that we can do. We also added a sleep call here. + // + // Alternatively, we could open another connection, and query the server + // to make sure pg_sleep is running, but that seems unnecessary for just + // one test. + <-goCh + time.Sleep(250 * time.Millisecond) + f.RequestTransfer() + + // Connection should be closed because this is a non-recoverable error, + // i.e. timeout after sending the request, but before fully receiving + // its response. + require.Eventually(t, func() bool { + err := conn.PingContext(tCtx) + return err != nil && strings.Contains(err.Error(), "bad connection") + }, 20*time.Second, 25*time.Millisecond) + + select { + case <-time.After(10 * time.Second): + t.Fatalf("require that pg_sleep query terminates") + case err = <-errCh: + require.NotNil(t, err) + require.Regexp(t, "bad connection", err.Error()) + } + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationErrorFatalCount.Count() == 1 + }, 30*time.Second, 25*time.Millisecond) + require.NotNil(t, f.ctx.Err()) + }) +} + func TestClusterNameAndTenantFromParams(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t)