diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 905d0e160461..a59a6e0df844 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -39,6 +39,7 @@ go_library( "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", + "//pkg/util/uuid", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_logtags//:logtags", "@com_github_jackc_pgproto3_v2//:pgproto3", @@ -78,6 +79,8 @@ go_test( "//pkg/sql", "//pkg/sql/pgwire", "//pkg/sql/pgwire/pgerror", + "//pkg/sql/pgwire/pgwirebase", + "//pkg/sql/tests", "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/skip", @@ -90,6 +93,7 @@ go_test( "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", + "@com_github_cockroachdb_cockroach_go_v2//crdb", "@com_github_cockroachdb_errors//:errors", "@com_github_jackc_pgconn//:pgconn", "@com_github_jackc_pgproto3_v2//:pgproto3", diff --git a/pkg/ccl/sqlproxyccl/conn_migration.go b/pkg/ccl/sqlproxyccl/conn_migration.go index 1858690741cc..dcc3ae776e61 100644 --- a/pkg/ccl/sqlproxyccl/conn_migration.go +++ b/pkg/ccl/sqlproxyccl/conn_migration.go @@ -13,13 +13,254 @@ import ( "encoding/json" "fmt" "io" + "time" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" + "github.com/cockroachdb/logtags" pgproto3 "github.com/jackc/pgproto3/v2" ) +// defaultTransferTimeout corresponds to the timeout period for the connection +// migration process. If the timeout gets triggered, and we're in a +// non-recoverable state, the connection will be closed. +// +// This is a variable instead of a constant to support testing hooks. 
+var defaultTransferTimeout = 15 * time.Second + +var errTransferCannotStart = errors.New("transfer cannot be started") + +// tryBeginTransfer returns true if the transfer can be started, and false +// otherwise. If the transfer can be started, it updates the state of the +// forwarder to indicate that a transfer is in progress. +func (f *forwarder) tryBeginTransfer() bool { + f.mu.Lock() + defer f.mu.Unlock() + + // Transfer is already in progress. No concurrent transfers are allowed. + if f.mu.isTransferring { + return false + } + + if !isSafeTransferPoint(f.mu.request, f.mu.response) { + return false + } + + f.mu.isTransferring = true + return true +} + +type transferContext struct { + context.Context + mu struct { + syncutil.Mutex + recoverableConn bool + } +} + +func newTransferContext(backgroundCtx context.Context) (*transferContext, context.CancelFunc) { + transferCtx, cancel := context.WithTimeout(backgroundCtx, defaultTransferTimeout) // nolint:context + ctx := &transferContext{ + Context: transferCtx, + } + ctx.mu.recoverableConn = true + return ctx, cancel +} + +func (t *transferContext) markRecoverable(r bool) { + t.mu.Lock() + defer t.mu.Unlock() + t.mu.recoverableConn = r +} + +func (t *transferContext) isRecoverable() bool { + t.mu.Lock() + defer t.mu.Unlock() + return t.mu.recoverableConn +} + +func (f *forwarder) runTransfer() (retErr error) { + // A previous non-recoverable transfer would have closed the forwarder, so + // return right away. + if f.ctx.Err() != nil { + return f.ctx.Err() + } + + if !f.tryBeginTransfer() { + return errTransferCannotStart + } + defer func() { + f.mu.Lock() + defer f.mu.Unlock() + f.mu.isTransferring = false + }() + + f.metrics.ConnMigrationAttemptedCount.Inc(1) + + // Create a transfer context, and timeout handler which gets triggered + // whenever the context expires. We have to close the forwarder because + // the transfer may be blocked on I/O, and the only way for now is to close + // the connections. 
This then allows runTransfer to return and clean up. + ctx, cancel := newTransferContext(f.ctx) + defer cancel() + + go func() { + <-ctx.Done() + if !ctx.isRecoverable() { + f.Close() + } + }() + + // Use a separate context for logging because f.ctx will be closed whenever + // the connection is non-recoverable. + logCtx := logtags.WithTags(context.Background(), logtags.FromContext(f.ctx)) + defer func() { + if !ctx.isRecoverable() { + log.Infof(logCtx, "transfer failed: connection closed, err=%v", retErr) + f.metrics.ConnMigrationErrorFatalCount.Inc(1) + } else { + // Transfer was successful. + if retErr == nil { + log.Infof(logCtx, "transfer successful") + f.metrics.ConnMigrationSuccessCount.Inc(1) + } else { + log.Infof(logCtx, "transfer failed: connection recovered, err=%v", retErr) + f.metrics.ConnMigrationErrorRecoverableCount.Inc(1) + } + f.resumeProcessors() + } + }() + + f.mu.Lock() + defer f.mu.Unlock() + + // Suspend both processors before starting the transfer. + f.mu.request.suspend(ctx) + f.mu.response.suspend(ctx) + + // Transfer the connection. + newServerConn, err := transferConnection(ctx, f.connector, f.mu.clientConn, f.mu.serverConn) + if err != nil { + return errors.Wrap(err, "transferring connection") + } + + // Transfer was successful. + clockFn := makeLogicalClockFn() + f.mu.serverConn.Close() + f.mu.serverConn = newServerConn + f.mu.request = newProcessor(clockFn, f.mu.clientConn, f.mu.serverConn) + f.mu.response = newProcessor(clockFn, f.mu.serverConn, f.mu.clientConn) + return nil +} + +// transferConnection performs the transfer operation for the current server +// connection, and returns a new connection to the server that the +// connection got transferred to. +func transferConnection( + ctx *transferContext, connector *connector, clientConn, serverConn *interceptor.PGConn, +) (_ *interceptor.PGConn, retErr error) { + ctx.markRecoverable(true) + + // Context was cancelled. 
+ if ctx.Err() != nil { + return nil, ctx.Err() + } + + transferKey := uuid.MakeV4().String() + + // Send the SHOW TRANSFER STATE statement. At this point, connection is + // non-recoverable because the message has already been sent to the server. + ctx.markRecoverable(false) + if err := runShowTransferState(serverConn, transferKey); err != nil { + return nil, errors.Wrap(err, "sending transfer request") + } + + transferErr, state, revivalToken, err := waitForShowTransferState( + ctx, serverConn.ToFrontendConn(), clientConn, transferKey) + if err != nil { + return nil, errors.Wrap(err, "waiting for transfer state") + } + + // Failures after this point are recoverable, and connections should not be + // terminated. + ctx.markRecoverable(true) + + // If we consumed until ReadyForQuery without errors, but the transfer state + // response returns an error, we could still resume the connection, but the + // transfer process will need to be aborted. + // + // This case may happen pretty frequently (e.g. open transactions, temporary + // tables, etc.). + if transferErr != "" { + return nil, errors.Newf("%s", transferErr) + } + + // Connect to a new SQL pod. + // + // TODO(jaylim-crl): There is a possibility where the same pod will get + // selected. Some ideas to solve this: pass in the remote address of + // serverConn to avoid choosing that pod, or maybe a filter callback? + // We can also consider adding a target pod as an argument to RequestTransfer. + // That way a central component gets to choose where the connections go. + netConn, err := connector.OpenTenantConnWithToken(ctx, revivalToken) + if err != nil { + return nil, errors.Wrap(err, "opening connection") + } + defer func() { + if retErr != nil { + netConn.Close() + } + }() + newServerConn := interceptor.NewPGConn(netConn) + + // Deserialize session state within the new SQL pod. 
+ if err := runAndWaitForDeserializeSession( + ctx, newServerConn.ToFrontendConn(), state, + ); err != nil { + return nil, errors.Wrap(err, "deserializing session") + } + + return newServerConn, nil +} + +// isSafeTransferPoint returns true if we're at a point where we're safe to +// transfer, and false otherwise. +var isSafeTransferPoint = func(request *processor, response *processor) bool { + request.mu.Lock() + response.mu.Lock() + defer request.mu.Unlock() + defer response.mu.Unlock() + + // Three conditions when evaluating a safe transfer point: + // 1. The last message sent to the SQL pod was a Sync(S) or SimpleQuery(Q), + // and a ReadyForQuery(Z) has been received after. + // 2. The last message sent to the SQL pod was a CopyDone(c), and a + // ReadyForQuery(Z) has been received after. + // 3. The last message sent to the SQL pod was a CopyFail(f), and a + // ReadyForQuery(Z) has been received after. + + // The conditions above are not possible if this is true. They cannot be + // equal since the same logical clock is used. + if request.mu.lastMessageTransferredAt > response.mu.lastMessageTransferredAt { + return false + } + + switch pgwirebase.ClientMessageType(request.mu.lastMessageType) { + case pgwirebase.ClientMessageType(0), + pgwirebase.ClientMsgSync, + pgwirebase.ClientMsgSimpleQuery, + pgwirebase.ClientMsgCopyDone, + pgwirebase.ClientMsgCopyFail: + return pgwirebase.ServerMessageType(response.mu.lastMessageType) == pgwirebase.ServerMsgReady + default: + return false + } +} + // runShowTransferState sends a SHOW TRANSFER STATE query with the input // transferKey to the given writer. 
The transferKey will be used to uniquely // identify the request when parsing the response messages in diff --git a/pkg/ccl/sqlproxyccl/conn_migration_test.go b/pkg/ccl/sqlproxyccl/conn_migration_test.go index d8c1d70da4a7..699244dd243f 100644 --- a/pkg/ccl/sqlproxyccl/conn_migration_test.go +++ b/pkg/ccl/sqlproxyccl/conn_migration_test.go @@ -23,6 +23,11 @@ import ( "github.com/stretchr/testify/require" ) +func TestIsSafeTransferPoint(t *testing.T) { + defer leaktest.AfterTest(t)() + // TODO(jaylim-crl): Tests. +} + func TestRunShowTransferState(t *testing.T) { defer leaktest.AfterTest(t)() diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go index 2e47523b65e7..c69f0663f01e 100644 --- a/pkg/ccl/sqlproxyccl/forwarder.go +++ b/pkg/ccl/sqlproxyccl/forwarder.go @@ -42,34 +42,48 @@ type forwarder struct { // the same as the metrics field in the proxyHandler instance. metrics *metrics - // clientConn and serverConn provide a convenient way to read and forward - // Postgres messages, while minimizing IO reads and memory allocations. - // - // clientConn is set once during initialization, and stays the same - // throughout the lifetime of the forwarder. - // - // serverConn is set during initialization, which happens after the - // authentication phase, and will be replaced if a connection migration - // occurs. During a connection migration, serverConn is only replaced once - // the session has successfully been deserialized, and the old connection - // will be closed. - // - // All reads from these connections must go through the PG interceptors. - // It is not safe to call Read directly as the interceptors may have - // buffered data. - clientConn *interceptor.PGConn // client <-> proxy - serverConn *interceptor.PGConn // proxy <-> server - - // request and response both represent the processors used to handle - // client-to-server and server-to-client messages. 
- request *processor // client -> server - response *processor // server -> client - // errCh is a buffered channel that contains the first forwarder error. // This channel may receive nil errors. When an error is written to this // channel, it is guaranteed that the forwarder and all connections will // be closed. errCh chan error + + // While not all of these fields may need to be guarded by a mutex, we do + // so for consistency. Fields like clientConn and serverConn need them + // because Close can be invoked anytime from a different goroutine while + // the connection migration is in progress. On the other hand, the processor + // fields will only be updated during connection migration, and we can + // guarantee that processors will be suspended, so we don't need mutexes + // for them. + mu struct { + syncutil.Mutex + + // isTransferring indicates that a connection migration is in progress. + isTransferring bool + + // clientConn and serverConn provide a convenient way to read and forward + // Postgres messages, while minimizing IO reads and memory allocations. + // + // clientConn is set once during initialization, and stays the same + // throughout the lifetime of the forwarder. + // + // serverConn is set during initialization, which happens after the + // authentication phase, and will be replaced if a connection migration + // occurs. During a connection migration, serverConn is only replaced once + // the session has successfully been deserialized, and the old connection + // will be closed. + // + // All reads from these connections must go through the PG interceptors. + // It is not safe to call Read directly as the interceptors may have + // buffered data. + clientConn *interceptor.PGConn // client <-> proxy + serverConn *interceptor.PGConn // proxy <-> server + + // request and response both represent the processors used to handle + // client-to-server and server-to-client messages. 
+ request *processor // client -> server + response *processor // server -> client + } } // forward returns a new instance of forwarder, and starts forwarding messages @@ -89,17 +103,18 @@ func forward( ) (*forwarder, error) { ctx, cancelFn := context.WithCancel(ctx) f := &forwarder{ - ctx: ctx, - ctxCancel: cancelFn, - errCh: make(chan error, 1), - connector: connector, - metrics: metrics, - clientConn: interceptor.NewPGConn(clientConn), - serverConn: interceptor.NewPGConn(serverConn), + ctx: ctx, + ctxCancel: cancelFn, + errCh: make(chan error, 1), + connector: connector, + metrics: metrics, } + f.mu.clientConn = interceptor.NewPGConn(clientConn) + f.mu.serverConn = interceptor.NewPGConn(serverConn) + clockFn := makeLogicalClockFn() - f.request = newProcessor(clockFn, f.clientConn, f.serverConn) // client -> server - f.response = newProcessor(clockFn, f.serverConn, f.clientConn) // server -> client + f.mu.request = newProcessor(clockFn, f.mu.clientConn, f.mu.serverConn) // client -> server + f.mu.response = newProcessor(clockFn, f.mu.serverConn, f.mu.clientConn) // server -> client if err := f.resumeProcessors(); err != nil { return nil, err } @@ -124,8 +139,18 @@ func (f *forwarder) Close() { // Since Close is idempotent, we'll ignore the error from Close calls in // case they have already been closed. - f.clientConn.Close() - f.serverConn.Close() + f.mu.Lock() + defer f.mu.Unlock() + f.mu.clientConn.Close() + f.mu.serverConn.Close() +} + +// RequestTransfer requests that the forwarder performs a best-effort connection +// migration whenever it can. It is best-effort because this will be a no-op if +// the forwarder is not in a state that is eligible for a connection migration. +// If a transfer is already in progress, or has been requested, this is a no-op. 
+func (f *forwarder) RequestTransfer() { + go f.runTransfer() } // resumeProcessors starts both the request and response processors @@ -133,20 +158,25 @@ func (f *forwarder) Close() { // return an error while resuming. This is idempotent as resume() will return // nil if the processor has already been started. func (f *forwarder) resumeProcessors() error { + f.mu.Lock() + defer f.mu.Unlock() + + requestProc := f.mu.request + responseProc := f.mu.response go func() { - if err := f.request.resume(f.ctx); err != nil { + if err := requestProc.resume(f.ctx); err != nil { f.tryReportError(wrapClientToServerError(err)) } }() go func() { - if err := f.response.resume(f.ctx); err != nil { + if err := responseProc.resume(f.ctx); err != nil { f.tryReportError(wrapServerToClientError(err)) } }() - if err := f.request.waitResumed(f.ctx); err != nil { + if err := requestProc.waitResumed(f.ctx); err != nil { return err } - if err := f.response.waitResumed(f.ctx); err != nil { + if err := responseProc.waitResumed(f.ctx); err != nil { return err } return nil @@ -200,7 +230,8 @@ func wrapServerToClientError(err error) error { // This implementation could overflow in theory, but it doesn't matter for the // forwarder since the worst that could happen is that we are unable to transfer // for an extremely short period of time until all the processors have wrapped -// around. That said, this situation is rare since uint64 is a huge number. +// around. That said, this situation is rare since uint64 is a huge number, and +// we restart the clock on each transfer. 
func makeLogicalClockFn() func() uint64 { var counter uint64 return func() uint64 { diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go index c8ab9ec444cb..eb340a77cb04 100644 --- a/pkg/ccl/sqlproxyccl/forwarder_test.go +++ b/pkg/ccl/sqlproxyccl/forwarder_test.go @@ -66,7 +66,10 @@ func TestForward(t *testing.T) { require.NoError(t, err) require.Nil(t, f.ctx.Err()) - requestProc := f.request + f.mu.Lock() + requestProc := f.mu.request + f.mu.Unlock() + initialClock := requestProc.logicalClockFn() barrier := make(chan struct{}) requestProc.testingKnobs.beforeForwardMsg = func() { @@ -162,7 +165,10 @@ func TestForward(t *testing.T) { require.NoError(t, err) require.Nil(t, f.ctx.Err()) - responseProc := f.response + f.mu.Lock() + responseProc := f.mu.response + f.mu.Unlock() + initialClock := responseProc.logicalClockFn() barrier := make(chan struct{}) responseProc.testingKnobs.beforeForwardMsg = func() { diff --git a/pkg/ccl/sqlproxyccl/interceptor/pg_conn.go b/pkg/ccl/sqlproxyccl/interceptor/pg_conn.go index 1155e93dcda6..dad8b542a068 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/pg_conn.go +++ b/pkg/ccl/sqlproxyccl/interceptor/pg_conn.go @@ -24,3 +24,17 @@ func NewPGConn(conn net.Conn) *PGConn { pgInterceptor: newPgInterceptor(conn, defaultBufferSize), } } + +// ToFrontendConn converts a PGConn to a FrontendConn. Callers should be aware +// of the underlying type of net.Conn before calling this, or else there will be +// an error during parsing. +func (c *PGConn) ToFrontendConn() *FrontendConn { + return &FrontendConn{Conn: c.Conn, interceptor: c.pgInterceptor} +} + +// ToBackendConn converts a PGConn to a BackendConn. Callers should be aware +// of the underlying type of net.Conn before calling this, or else there will be +// an error during parsing. 
+func (c *PGConn) ToBackendConn() *BackendConn { + return &BackendConn{Conn: c.Conn, interceptor: c.pgInterceptor} +} diff --git a/pkg/ccl/sqlproxyccl/interceptor/pg_conn_test.go b/pkg/ccl/sqlproxyccl/interceptor/pg_conn_test.go index 8e0105733ddf..ab1754b94bca 100644 --- a/pkg/ccl/sqlproxyccl/interceptor/pg_conn_test.go +++ b/pkg/ccl/sqlproxyccl/interceptor/pg_conn_test.go @@ -70,3 +70,13 @@ func TestPGConn(t *testing.T) { require.Nil(t, err) }) } + +func TestPGConn_ToFrontendConn(t *testing.T) { + defer leaktest.AfterTest(t)() + // TODO(jaylim-crl): Tests. +} + +func TestPGConn_ToBackendConn(t *testing.T) { + defer leaktest.AfterTest(t)() + // TODO(jaylim-crl): Tests. +} diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index 7b5afdf2c012..dd12e852001a 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -15,16 +15,20 @@ import ( // metrics contains pointers to the metrics for monitoring proxy operations. type metrics struct { - BackendDisconnectCount *metric.Counter - IdleDisconnectCount *metric.Counter - BackendDownCount *metric.Counter - ClientDisconnectCount *metric.Counter - CurConnCount *metric.Gauge - RoutingErrCount *metric.Counter - RefusedConnCount *metric.Counter - SuccessfulConnCount *metric.Counter - AuthFailedCount *metric.Counter - ExpiredClientConnCount *metric.Counter + BackendDisconnectCount *metric.Counter + IdleDisconnectCount *metric.Counter + BackendDownCount *metric.Counter + ClientDisconnectCount *metric.Counter + CurConnCount *metric.Gauge + RoutingErrCount *metric.Counter + RefusedConnCount *metric.Counter + SuccessfulConnCount *metric.Counter + AuthFailedCount *metric.Counter + ExpiredClientConnCount *metric.Counter + ConnMigrationSuccessCount *metric.Counter + ConnMigrationErrorFatalCount *metric.Counter + ConnMigrationErrorRecoverableCount *metric.Counter + ConnMigrationAttemptedCount *metric.Counter } // MetricStruct implements the metrics.Struct interface. 
@@ -93,6 +97,35 @@ var ( Measurement: "Expired Client Connections", Unit: metric.Unit_COUNT, } + // Connection migration metrics. + // + // attempted = success + error_fatal + error_recoverable + metaConnMigrationAttemptedCount = metric.Metadata{ + Name: "proxy.conn_migration.attempted", + Help: "Number of attempted connection migrations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationSuccessCount = metric.Metadata{ + Name: "proxy.conn_migration.success", + Help: "Number of successful connection migrations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationErrorFatalCount = metric.Metadata{ + // When connection migrations errored out, connections will be closed. + Name: "proxy.conn_migration.error_fatal", + Help: "Number of failed connection migrations which resulted in terminations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationErrorRecoverableCount = metric.Metadata{ + // Connections are recoverable, so they won't be closed. + Name: "proxy.conn_migration.error_recoverable", + Help: "Number of failed connection migrations that were recoverable", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } ) // makeProxyMetrics instantiates the metrics holder for proxy monitoring. @@ -108,6 +141,11 @@ func makeProxyMetrics() metrics { SuccessfulConnCount: metric.NewCounter(metaSuccessfulConnCount), AuthFailedCount: metric.NewCounter(metaAuthFailedCount), ExpiredClientConnCount: metric.NewCounter(metaExpiredClientConnCount), + // Connection migration metrics. 
+ ConnMigrationSuccessCount: metric.NewCounter(metaConnMigrationSuccessCount), + ConnMigrationErrorFatalCount: metric.NewCounter(metaConnMigrationErrorFatalCount), + ConnMigrationErrorRecoverableCount: metric.NewCounter(metaConnMigrationErrorRecoverableCount), + ConnMigrationAttemptedCount: metric.NewCounter(metaConnMigrationAttemptedCount), } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 7ed39617724a..a84012588f4d 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -96,6 +96,11 @@ type ProxyOptions struct { // ThrottleBaseDelay is the initial exponential backoff triggered in // response to the first connection failure. ThrottleBaseDelay time.Duration + + // Used for testing. + testingKnobs struct { + afterForward func(*forwarder) error + } } // proxyHandler is the default implementation of a proxy handler. @@ -341,6 +346,15 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn return err } + if handler.testingKnobs.afterForward != nil { + if err := handler.testingKnobs.afterForward(f); err != nil { + select { + case errConnection <- err: /* error reported */ + default: /* the channel already contains an error */ + } + } + } + // Block until an error is received, or when the stopper starts quiescing, // whichever that happens first. 
// diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index a439d5321177..eabf4d39c508 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -11,6 +11,7 @@ package sqlproxyccl import ( "context" "crypto/tls" + gosql "database/sql" "fmt" "io/ioutil" "net" @@ -20,6 +21,7 @@ import ( "testing" "time" + "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/denylist" @@ -30,6 +32,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/pgwire" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/tests" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" @@ -755,6 +758,376 @@ func TestDirectoryConnect(t *testing.T) { }) } +func TestConnectionMigration(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + defer log.Scope(t).Close(t) + + params, _ := tests.CreateTestServerParams() + s, mainDB, _ := serverutils.StartServer(t, params) + defer s.Stopper().Stop(ctx) + tenantID := serverutils.TestTenantID() + + // TODO(rafi): use ALTER TENANT ALL when available. + _, err := mainDB.Exec(`INSERT INTO system.tenant_settings (tenant_id, name, value, value_type) VALUES + (0, 'server.user_login.session_revival_token.enabled', 'true', 'b')`) + require.NoError(t, err) + + // Start first SQL pod. + tenant1, tenantDB1 := serverutils.StartTenant(t, s, tests.CreateTestTenantParams(tenantID)) + tenant1.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + defer tenant1.Stopper().Stop(ctx) + defer tenantDB1.Close() + + // Start second SQL pod. 
+ params2 := tests.CreateTestTenantParams(tenantID) + params2.Existing = true + tenant2, tenantDB2 := serverutils.StartTenant(t, s, params2) + tenant2.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + defer tenant2.Stopper().Stop(ctx) + defer tenantDB2.Close() + + _, err = tenantDB1.Exec("CREATE USER testuser WITH PASSWORD 'hunter2'") + require.NoError(t, err) + _, err = tenantDB1.Exec("GRANT admin TO testuser") + require.NoError(t, err) + + // Create a proxy server without using a directory. The directory is very + // difficult to work with, and there isn't a way to easily stub out fake + // loads. For this test, we will stub out lookupAddr in the connector. We + // will alternate between tenant1 and tenant2, starting with tenant1. + forwarderCh := make(chan *forwarder) + opts := &ProxyOptions{SkipVerify: true, RoutingRule: tenant1.SQLAddr()} + opts.testingKnobs.afterForward = func(f *forwarder) error { + select { + case forwarderCh <- f: + case <-time.After(10 * time.Second): + return errors.New("no receivers for forwarder") + } + return nil + } + _, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + + // The tenant ID does not matter here since we stubbed RoutingRule. + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-28", addr) + + type queryer interface { + QueryRowContext(context.Context, string, ...interface{}) *gosql.Row + } + // queryAddr queries the SQL node that `db` is connected to for its address. 
+ queryAddr := func(t *testing.T, ctx context.Context, db queryer) string { + t.Helper() + var host, port string + require.NoError(t, db.QueryRowContext(ctx, ` + SELECT + a.value AS "host", b.value AS "port" + FROM crdb_internal.node_runtime_info a, crdb_internal.node_runtime_info b + WHERE a.component = 'DB' AND a.field = 'Host' + AND b.component = 'DB' AND b.field = 'Port' + `).Scan(&host, &port)) + return fmt.Sprintf("%s:%s", host, port) + } + + // Test that connection transfers are successful. Note that if one sub-test + // fails, the remaining will fail as well since they all use the same + // forwarder instance. + t.Run("successful", func(t *testing.T) { + tCtx, cancel := context.WithCancel(ctx) + defer cancel() + + db, err := gosql.Open("postgres", connectionString) + db.SetMaxOpenConns(1) + defer db.Close() + require.NoError(t, err) + + // Spin up a goroutine to trigger the initial connection. + go func() { + _ = db.PingContext(tCtx) + }() + + var f *forwarder + select { + case f = <-forwarderCh: + case <-time.After(10 * time.Second): + t.Fatal("no connection") + } + + // Set up forwarder hooks. + prevTenant1 := true + var lookupAddrDelayDuration time.Duration + f.connector.testingKnobs.lookupAddr = func(ctx context.Context) (string, error) { + if lookupAddrDelayDuration != 0 { + select { + case <-ctx.Done(): + return "", errors.Wrap(ctx.Err(), "injected delays") + case <-time.After(lookupAddrDelayDuration): + } + } + if prevTenant1 { + prevTenant1 = false + return tenant2.SQLAddr(), nil + } + prevTenant1 = true + return tenant1.SQLAddr(), nil + } + + t.Run("normal_transfer", func(t *testing.T) { + require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db)) + + _, err = db.Exec("SET application_name = 'foo'") + require.NoError(t, err) + + // Show that we get alternating SQL pods when we transfer. 
+ f.RequestTransfer() + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationSuccessCount.Count() == 1 + }, 20*time.Second, 25*time.Millisecond) + require.Equal(t, tenant2.SQLAddr(), queryAddr(t, tCtx, db)) + + var name string + require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name)) + require.Equal(t, "foo", name) + + _, err = db.Exec("SET application_name = 'bar'") + require.NoError(t, err) + + f.RequestTransfer() + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationSuccessCount.Count() == 2 + }, 20*time.Second, 25*time.Millisecond) + require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db)) + + require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name)) + require.Equal(t, "bar", name) + + // Now attempt a transfer concurrently with requests. + closerCh := make(chan struct{}) + go func() { + for i := 0; i < 10 && tCtx.Err() == nil; i++ { + f.RequestTransfer() + time.Sleep(500 * time.Millisecond) + } + closerCh <- struct{}{} + }() + + // This test runs for 5 seconds. + var tenant1Addr, tenant2Addr int + for i := 0; i < 100; i++ { + addr := queryAddr(t, tCtx, db) + if addr == tenant1.SQLAddr() { + tenant1Addr++ + } else { + require.Equal(t, tenant2.SQLAddr(), addr) + tenant2Addr++ + } + time.Sleep(50 * time.Millisecond) + } + + // In 5s, we should have at least 10 successful transfers. Just do + // an approximation here. + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationSuccessCount.Count() >= 5 + }, 20*time.Second, 25*time.Millisecond) + require.True(t, tenant1Addr > 2) + require.True(t, tenant2Addr > 2) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorRecoverableCount.Count()) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorFatalCount.Count()) + + // Ensure that the goroutine terminates so other subtests are not + // affected. + <-closerCh + + // There's a chance that we still have an in-progress transfer, so + // attempt to wait. 
+ require.Eventually(t, func() bool { + f.mu.Lock() + defer f.mu.Unlock() + return !f.mu.isTransferring + }, 10*time.Second, 25*time.Millisecond) + + require.Equal(t, f.metrics.ConnMigrationAttemptedCount.Count(), + f.metrics.ConnMigrationSuccessCount.Count(), + ) + }) + + // Transfers should fail if there is an open transaction. These failed + // transfers should not close the connection. + t.Run("failed_transfers_with_tx", func(t *testing.T) { + initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count() + initAddr := queryAddr(t, tCtx, db) + + err = crdb.ExecuteTx(tCtx, db, nil /* txopts */, func(tx *gosql.Tx) error { + for i := 0; i < 10; i++ { + f.RequestTransfer() + addr := queryAddr(t, tCtx, tx) + if initAddr != addr { + return errors.Newf( + "address does not match, expected %s, found %s", + initAddr, + addr, + ) + } + time.Sleep(50 * time.Millisecond) + } + return nil + }) + require.NoError(t, err) + + // Make sure there are no pending transfers. + func() { + f.mu.Lock() + defer f.mu.Unlock() + require.False(t, f.mu.isTransferring) + }() + + // Just check that we have half of what we requested since we cannot + // guarantee that the transfer will run within 50ms. + require.True(t, f.metrics.ConnMigrationErrorRecoverableCount.Count() >= 5) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorFatalCount.Count()) + require.Equal(t, initSuccessCount, f.metrics.ConnMigrationSuccessCount.Count()) + prevErrorRecoverableCount := f.metrics.ConnMigrationErrorRecoverableCount.Count() + + // Once the transaction is closed, transfers should work. 
+ f.RequestTransfer() + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationSuccessCount.Count() == initSuccessCount+1 + }, 20*time.Second, 25*time.Millisecond) + require.NotEqual(t, initAddr, queryAddr(t, tCtx, db)) + require.Equal(t, prevErrorRecoverableCount, f.metrics.ConnMigrationErrorRecoverableCount.Count()) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorFatalCount.Count()) + + // We have already asserted metrics above, so transfer must have + // been completed. + f.mu.Lock() + defer f.mu.Unlock() + require.False(t, f.mu.isTransferring) + }) + + // Transfer timeout caused by dial issues should not close the session. + // We will test this by introducing delays when connecting to the SQL + // pod. + t.Run("failed_transfers_with_dial_issues", func(t *testing.T) { + initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count() + initErrorRecoverableCount := f.metrics.ConnMigrationErrorRecoverableCount.Count() + initAddr := queryAddr(t, tCtx, db) + + // Set the delay longer than the timeout. + lookupAddrDelayDuration = 10 * time.Second + defer testutils.TestingHook(&defaultTransferTimeout, 3*time.Second)() + + f.RequestTransfer() + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationErrorRecoverableCount.Count() == initErrorRecoverableCount+1 + }, 20*time.Second, 25*time.Millisecond) + require.Equal(t, initAddr, queryAddr(t, tCtx, db)) + require.Equal(t, initSuccessCount, f.metrics.ConnMigrationSuccessCount.Count()) + require.Equal(t, int64(0), f.metrics.ConnMigrationErrorFatalCount.Count()) + + // We have already asserted metrics above, so transfer must have + // been completed. + f.mu.Lock() + defer f.mu.Unlock() + require.False(t, f.mu.isTransferring) + }) + }) + + // Test transfer timeouts caused by waiting for a transfer state response. + // In reality, this can only be caused by pipelined queries. Consider the + // following: + // 1. short-running simple query + // 2. long-running simple query + // 3.
SHOW TRANSFER STATE + // When (1) returns a response, the forwarder will see that we're in a + // safe transfer point, and initiate (3). But (2) may block until we hit + // a timeout. + // + // There's no easy way to simulate pipelined queries. pgtest (that allows + // us to send individual pgwire messages) does not support authentication, + // which is what the proxy needs, so we will stub isSafeTransferPoint + // instead. + t.Run("transfer_timeout_in_response", func(t *testing.T) { + tCtx, cancel := context.WithCancel(ctx) + defer cancel() + + db, err := gosql.Open("postgres", connectionString) + db.SetMaxOpenConns(1) + defer db.Close() + require.NoError(t, err) + + // Use a single connection so that we don't reopen when the connection + // is closed. + conn, err := db.Conn(tCtx) + require.NoError(t, err) + + // Spin up a goroutine to trigger the initial connection. + go func() { + _ = conn.PingContext(tCtx) + }() + + var f *forwarder + select { + case f = <-forwarderCh: + case <-time.After(10 * time.Second): + t.Fatal("no connection") + } + + // Set up forwarder hooks. + prevTenant1 := true + f.connector.testingKnobs.lookupAddr = func(ctx context.Context) (string, error) { + if prevTenant1 { + prevTenant1 = false + return tenant2.SQLAddr(), nil + } + prevTenant1 = true + return tenant1.SQLAddr(), nil + } + defer testutils.TestingHook(&isSafeTransferPoint, func(req *processor, res *processor) bool { + return true + })() + // Transfer timeout is 3s, and we'll run pg_sleep for 10s. + defer testutils.TestingHook(&defaultTransferTimeout, 3*time.Second)() + + goCh := make(chan struct{}, 1) + errCh := make(chan error, 1) + go func() { + goCh <- struct{}{} + _, err = conn.ExecContext(tCtx, "SELECT pg_sleep(10)") + errCh <- err + }() + + // Block until goroutine is started. We want to make sure we run the + // transfer request *after* sending the query. This doesn't guarantee, + // but is the best that we can do. We also added a sleep call here. 
+ // + // Alternatively, we could open another connection, and query the server + // to make sure pg_sleep is running, but that seems unnecessary for just + // one test. + <-goCh + time.Sleep(250 * time.Millisecond) + f.RequestTransfer() + + // Connection should be closed because this is a non-recoverable error, + // i.e. timeout after sending the request, but before fully receiving + // its response. + require.Eventually(t, func() bool { + err := conn.PingContext(tCtx) + return err != nil && strings.Contains(err.Error(), "bad connection") + }, 20*time.Second, 25*time.Millisecond) + + select { + case <-time.After(10 * time.Second): + t.Fatalf("require that pg_sleep query terminates") + case err = <-errCh: + require.NotNil(t, err) + require.Regexp(t, "bad connection", err.Error()) + } + require.Eventually(t, func() bool { + return f.metrics.ConnMigrationErrorFatalCount.Count() == 1 + }, 30*time.Second, 25*time.Millisecond) + require.NotNil(t, f.ctx.Err()) + }) +} + func TestClusterNameAndTenantFromParams(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t)