diff --git a/executor/coprocessor.go b/executor/coprocessor.go index 6eb438d5aaeb5..3811475fa9212 100644 --- a/executor/coprocessor.go +++ b/executor/coprocessor.go @@ -144,8 +144,8 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (Exec Username: dagReq.User.UserName, Hostname: dagReq.User.UserHost, } - authName, authHost, success := pm.GetAuthWithoutVerification(dagReq.User.UserName, dagReq.User.UserHost) - if success { + authName, authHost, success := pm.MatchIdentity(dagReq.User.UserName, dagReq.User.UserHost, false) + if success && pm.GetAuthWithoutVerification(authName, authHost) { h.sctx.GetSessionVars().User.AuthUsername = authName h.sctx.GetSessionVars().User.AuthHostname = authHost h.sctx.GetSessionVars().ActiveRoles = pm.GetDefaultRoles(authName, authHost) diff --git a/privilege/privilege.go b/privilege/privilege.go index e0b9d41f41b1d..af5ff9924ffe9 100644 --- a/privilege/privilege.go +++ b/privilege/privilege.go @@ -59,10 +59,15 @@ type Manager interface { RequestDynamicVerificationWithUser(privName string, grantable bool, user *auth.UserIdentity) bool // ConnectionVerification verifies user privilege for connection. - ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool) + // Requires exact match on user name and host name. + ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) bool // GetAuthWithoutVerification uses to get auth name without verification. - GetAuthWithoutVerification(user, host string) (string, string, bool) + // Requires exact match on user name and host name. + GetAuthWithoutVerification(user, host string) bool + + // MatchIdentity matches an identity + MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool) // DBIsVisible returns true is the database is visible to current user. DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool diff --git a/privilege/privileges/cache.go b/privilege/privileges/cache.go index dc90170a500ef..4d388e073b205 100644 --- a/privilege/privileges/cache.go +++ b/privilege/privileges/cache.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util" "github.com/pingcap/tidb/util/chunk" @@ -848,6 +849,9 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType { // See https://dev.mysql.com/doc/refman/5.7/en/account-names.html func (record *baseRecord) hostMatch(s string) bool { if record.hostIPNet == nil { + if record.Host == "localhost" && net.ParseIP(s).IsLoopback() { + return true + } return false } ip := net.ParseIP(s).To4() @@ -890,14 +894,54 @@ func patternMatch(str string, patChars, patTypes []byte) bool { return stringutil.DoMatchBytes(str, patChars, patTypes) } -// connectionVerification verifies the connection have access to TiDB server. -func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { +// matchIdentity finds an identity to match a user + host +// using the correct rules according to MySQL. +func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord { for i := 0; i < len(p.User); i++ { record := &p.User[i] if record.match(user, host) { return record } } + + // If skip-name resolve is not enabled, and the host is not localhost + // we can fallback and try to resolve with all addrs that match. + // TODO: this is imported from previous code in session.Auth(), and can be improved in future. + if !skipNameResolve && host != variable.DefHostname { + addrs, err := net.LookupAddr(host) + if err != nil { + logutil.BgLogger().Warn( + "net.LookupAddr returned an error during auth check", + zap.String("host", host), + zap.Error(err), + ) + return nil + } + for _, addr := range addrs { + for i := 0; i < len(p.User); i++ { + record := &p.User[i] + if record.match(user, addr) { + return record + } + } + } + } + return nil +} + +// connectionVerification verifies the username + hostname according to exact +// match from the mysql.user privilege table. call matchIdentity() first if you +// do not have an exact match yet. +func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord { + records, exists := p.UserMap[user] + if exists { + for i := 0; i < len(records); i++ { + record := &records[i] + if record.Host == host { // exact match + return record + } + } + } return nil } diff --git a/privilege/privileges/privileges.go b/privilege/privileges/privileges.go index 104c2c3782387..fea22acef641c 100644 --- a/privilege/privileges/privileges.go +++ b/privilege/privileges/privileges.go @@ -256,8 +256,21 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) { return "", errors.New("Failed to get plugin for user") } +// MatchIdentity implements the Manager interface. +func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) (u string, h string, success bool) { + if SkipWithGrant { + return user, host, true + } + mysqlPriv := p.Handle.Get() + record := mysqlPriv.matchIdentity(user, host, skipNameResolve) + if record != nil { + return record.User, record.Host, true + } + return "", "", false +} + // GetAuthWithoutVerification implements the Manager interface. -func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) { +func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (success bool) { if SkipWithGrant { p.user = user p.host = host @@ -273,16 +286,14 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string return } - u = record.User - h = record.Host p.user = user - p.host = h + p.host = record.Host success = true return } // ConnectionVerification implements the Manager interface. -func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) { +func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (success bool) { if SkipWithGrant { p.user = user p.host = host @@ -298,9 +309,6 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio return } - u = record.User - h = record.Host - globalPriv := mysqlPriv.matchGlobalPriv(user, host) if globalPriv != nil { if !p.checkSSL(globalPriv, tlsState) { @@ -328,7 +336,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio // empty password if len(pwd) == 0 && len(authentication) == 0 { p.user = user - p.host = h + p.host = record.Host success = true return } @@ -371,7 +379,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio } p.user = user - p.host = h + p.host = record.Host success = true return } diff --git a/server/conn.go b/server/conn.go index 49a42fe54bb94..54e1c3728b672 100644 --- a/server/conn.go +++ b/server/conn.go @@ -211,6 +211,9 @@ func (cc *clientConn) String() string { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest // https://bugs.mysql.com/bug.php?id=93044 func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) { + failpoint.Inject("FakeAuthSwitch", func() { + failpoint.Return([]byte(plugin), nil) + }) enclen := 1 + len(plugin) + 1 + len(cc.salt) + 1 data := cc.alloc.AllocWithLen(4, enclen) data = append(data, mysql.AuthSwitchRequest) // switch request @@ -708,9 +711,10 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeResponse41) error { if resp.Capability&mysql.ClientPluginAuth > 0 { - newAuth, err := cc.checkAuthPlugin(ctx, &resp.AuthPlugin) + newAuth, err := cc.checkAuthPlugin(ctx, resp) if err != nil { logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err)) + return err } if len(newAuth) > 0 { resp.Auth = newAuth @@ -718,30 +722,18 @@ func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeRespo switch resp.AuthPlugin { case mysql.AuthCachingSha2Password: - resp.Auth, err = cc.authSha(ctx) - if err != nil { - return err - } case mysql.AuthNativePassword: case mysql.AuthSocket: default: logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin)) } } else { + // MySQL 5.1 and older clients don't support authentication plugins. logutil.Logger(ctx).Warn("Client without Auth Plugin support; Please upgrade client") - if cc.ctx == nil { - err := cc.openSession() - if err != nil { - return err - } - } - userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost}) + _, err := cc.checkAuthPlugin(ctx, resp) if err != nil { return err } - if userplugin != mysql.AuthNativePassword && userplugin != "" { - return errNotSupportedAuthMode - } resp.AuthPlugin = mysql.AuthNativePassword } return nil @@ -845,7 +837,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e } // Check if the Authentication Plugin of the server, client and user configuration matches -func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ([]byte, error) { +func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) { // Open a context unless this was done before. if cc.ctx == nil { err := cc.openSession() @@ -854,12 +846,34 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( } } - userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost}) + authData := resp.Auth + hasPassword := "YES" + if len(authData) == 0 { + hasPassword = "NO" + } + host, _, err := cc.PeerHost(hasPassword) if err != nil { return nil, err } + // Find the identity of the user based on username and peer host. + identity, err := cc.ctx.MatchIdentity(cc.user, host) + if err != nil { + return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + // Get the plugin for the identity. + userplugin, err := cc.ctx.AuthPluginForUser(identity) + if err != nil { + logutil.Logger(ctx).Warn("Failed to get authentication method for user", + zap.String("user", cc.user), zap.String("host", host)) + } + failpoint.Inject("FakeUser", func(val failpoint.Value) { + userplugin = val.(string) + }) if userplugin == mysql.AuthSocket { - *authPlugin = mysql.AuthSocket + if !cc.isUnixSocket { + return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + resp.AuthPlugin = mysql.AuthSocket user, err := user.LookupId(fmt.Sprint(cc.socketCredUID)) if err != nil { return nil, err @@ -867,9 +881,19 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( return []byte(user.Username), nil } if len(userplugin) == 0 { - logutil.Logger(ctx).Warn("No user plugin set, assuming MySQL Native Password", - zap.String("user", cc.user), zap.String("host", cc.peerHost)) - *authPlugin = mysql.AuthNativePassword + // No user plugin set, assuming MySQL Native Password + // This happens if the account doesn't exist or if the account doesn't have + // a password set. + if resp.AuthPlugin != mysql.AuthNativePassword { + if resp.Capability&mysql.ClientPluginAuth > 0 { + resp.AuthPlugin = mysql.AuthNativePassword + authData, err := cc.authSwitchRequest(ctx, mysql.AuthNativePassword) + if err != nil { + return nil, err + } + return authData, nil + } + } return nil, nil } @@ -878,13 +902,18 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ( // or if the authentication method send by the server doesn't match the authentication // method send by the client (*authPlugin) then we need to switch the authentication // method to match the one configured for that specific user. - if (cc.authPlugin != userplugin) || (cc.authPlugin != *authPlugin) { - authData, err := cc.authSwitchRequest(ctx, userplugin) - if err != nil { - return nil, err + if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) { + if resp.Capability&mysql.ClientPluginAuth > 0 { + authData, err := cc.authSwitchRequest(ctx, userplugin) + if err != nil { + return nil, err + } + resp.AuthPlugin = userplugin + return authData, nil + } else if userplugin != mysql.AuthNativePassword { + // MySQL 5.1 and older don't support authentication plugins yet + return nil, errNotSupportedAuthMode } - *authPlugin = userplugin - return authData, nil } return nil, nil diff --git a/server/conn_test.go b/server/conn_test.go index dc50900e41624..95aeaf7af41d9 100644 --- a/server/conn_test.go +++ b/server/conn_test.go @@ -894,8 +894,6 @@ func TestShowErrors(t *testing.T) { } func TestHandleAuthPlugin(t *testing.T) { - t.Parallel() - store, clean := testkit.CreateMockStore(t) defer clean() @@ -905,25 +903,202 @@ func TestHandleAuthPlugin(t *testing.T) { drv := NewTiDBDriver(store) srv, err := NewServer(cfg, drv) require.NoError(t, err) + ctx := context.Background() + + tk := testkit.NewTestKit(t, store) + tk.MustExec("CREATE USER unativepassword") + defer func() { + tk.MustExec("DROP USER unativepassword") + }() + // 5.7 or newer client trying to authenticate with mysql_native_password cc := &clientConn{ connectionID: 1, alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", pkt: &packetIO{ bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), }, - collation: mysql.DefaultCollationID, - server: srv, - user: "root", + server: srv, + user: "unativepassword", } - ctx := context.Background() resp := handshakeResponse41{ Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, resp.Auth, []byte(mysql.AuthNativePassword)) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + + // === Target account has mysql_native_password === + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"mysql_native_password\")")) + + // 5.7 or newer client trying to authenticate with mysql_native_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthNativePassword), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.NoError(t, err) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser")) + + // === Target account has caching_sha2_password === + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeUser", "return(\"caching_sha2_password\")")) + + // 5.7 or newer client trying to authenticate with mysql_native_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthNativePassword, } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) - resp.Capability = mysql.ClientProtocol41 + // 8.0 or newer client trying to authenticate with caching_sha2_password + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)")) + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth, + AuthPlugin: mysql.AuthCachingSha2Password, + } err = cc.handleAuthPlugin(ctx, &resp) require.NoError(t, err) + require.Equal(t, []byte(mysql.AuthCachingSha2Password), resp.Auth) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch")) + + // MySQL 5.1 or older client, without authplugin support + cc = &clientConn{ + connectionID: 1, + alloc: arena.NewAllocator(1024), + collation: mysql.DefaultCollationID, + peerHost: "localhost", + pkt: &packetIO{ + bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)), + }, + server: srv, + user: "unativepassword", + } + resp = handshakeResponse41{ + Capability: mysql.ClientProtocol41, + } + err = cc.handleAuthPlugin(ctx, &resp) + require.Error(t, err) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeUser")) } diff --git a/server/http_handler_test.go b/server/http_handler_test.go index d78642b1651a0..ca17cfef8f29c 100644 --- a/server/http_handler_test.go +++ b/server/http_handler_test.go @@ -483,6 +483,7 @@ func (ts *basicHTTPHandlerTestSuite) startServer(c *C) { cfg.Port = 0 cfg.Status.StatusPort = 0 cfg.Status.ReportStatus = true + cfg.Socket = "" server, err := NewServer(cfg, ts.tidbdrv) c.Assert(err, IsNil) diff --git a/server/plan_replayer_test.go b/server/plan_replayer_test.go index 903f771463ee8..a007e1f5d3c9a 100644 --- a/server/plan_replayer_test.go +++ b/server/plan_replayer_test.go @@ -39,6 +39,7 @@ func TestDumpPlanReplayerAPI(t *testing.T) { client := newTestServerClient() cfg := newTestConfig() cfg.Port = client.port + cfg.Socket = "" cfg.Status.StatusPort = client.statusPort cfg.Status.ReportStatus = true diff --git a/server/statistics_handler_serial_test.go b/server/statistics_handler_serial_test.go index 7c56cf2186831..9d81d36dcb083 100644 --- a/server/statistics_handler_serial_test.go +++ b/server/statistics_handler_serial_test.go @@ -38,6 +38,7 @@ func TestDumpStatsAPI(t *testing.T) { client := newTestServerClient() cfg := newTestConfig() cfg.Port = client.port + cfg.Socket = "" cfg.Status.StatusPort = client.statusPort cfg.Status.ReportStatus = true diff --git a/session/session.go b/session/session.go index aa08d554c9f13..913faa3591aa9 100644 --- a/session/session.go +++ b/session/session.go @@ -24,7 +24,6 @@ import ( "crypto/tls" "encoding/json" "fmt" - "net" "runtime/pprof" "runtime/trace" "strconv" @@ -146,6 +145,7 @@ type Session interface { Auth(user *auth.UserIdentity, auth []byte, salt []byte) bool AuthWithoutVerification(user *auth.UserIdentity) bool AuthPluginForUser(user *auth.UserIdentity) (string, error) + MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) ShowProcess() *util.ProcessInfo // Return the information of the txn current running TxnInfo() *txninfo.TxnInfo @@ -2211,91 +2211,61 @@ func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { return authplugin, nil } +// Auth validates a user using an authentication string and salt. +// If the password fails, it will keep trying other users until exhausted. +// This means it can not be refactored to use MatchIdentity yet. func (s *session) Auth(user *auth.UserIdentity, authentication []byte, salt []byte) bool { pm := privilege.GetPrivilegeManager(s) - - // Check IP or localhost. - var success bool - user.AuthUsername, user.AuthHostname, success = pm.ConnectionVerification(user.Username, user.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) - if success { + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.ConnectionVerification(authUser.Username, authUser.Hostname, authentication, salt, s.sessionVars.TLSConnectionState) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true - } else if user.Hostname == variable.DefHostname { - return false } + return false +} - // Check Hostname. - for _, addr := range s.getHostByIP(user.Hostname) { - u, h, success := pm.ConnectionVerification(user.Username, addr, authentication, salt, s.sessionVars.TLSConnectionState) - if success { - s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, - AuthUsername: u, - AuthHostname: h, - } - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) - return true - } +// MatchIdentity finds the matching username + password in the MySQL privilege tables +// for a username + hostname, since MySQL can have wildcards. +func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) { + pm := privilege.GetPrivilegeManager(s) + var success bool + var skipNameResolve bool + var user = &auth.UserIdentity{} + varVal, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) + if err == nil && variable.TiDBOptOn(varVal) { + skipNameResolve = true } - return false + user.Username, user.Hostname, success = pm.MatchIdentity(username, remoteHost, skipNameResolve) + if success { + return user, nil + } + // This error will not be returned to the user, access denied will be instead + return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost) } // AuthWithoutVerification is required by the ResetConnection RPC func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { pm := privilege.GetPrivilegeManager(s) - - // Check IP or localhost. - var success bool - user.AuthUsername, user.AuthHostname, success = pm.GetAuthWithoutVerification(user.Username, user.Hostname) - if success { + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.GetAuthWithoutVerification(authUser.Username, authUser.Hostname) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname s.sessionVars.User = user s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) return true - } else if user.Hostname == variable.DefHostname { - return false - } - - // Check Hostname. - for _, addr := range s.getHostByIP(user.Hostname) { - u, h, success := pm.GetAuthWithoutVerification(user.Username, addr) - if success { - s.sessionVars.User = &auth.UserIdentity{ - Username: user.Username, - Hostname: addr, - AuthUsername: u, - AuthHostname: h, - } - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(u, h) - return true - } } return false } -func (s *session) getHostByIP(ip string) []string { - if ip == "127.0.0.1" { - return []string{variable.DefHostname} - } - skipNameResolve, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) - if err == nil && variable.TiDBOptOn(skipNameResolve) { - return []string{ip} // user wants to skip name resolution - } - addrs, err := net.LookupAddr(ip) - if err != nil { - // These messages can be noisy. - // See: https://github.com/pingcap/tidb/pull/13989 - logutil.BgLogger().Debug( - "net.LookupAddr returned an error during auth check", - zap.String("ip", ip), - zap.Error(err), - ) - return []string{ip} - } - return addrs -} - // RefreshVars implements the sessionctx.Context interface. func (s *session) RefreshVars(ctx context.Context) error { pruneMode, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBPartitionPruneMode) diff --git a/session/session_test.go b/session/session_test.go index b2b386fefbd61..b4aadd0306bc1 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -18,6 +18,7 @@ import ( "context" "flag" "fmt" + "net" "os" "path" "runtime" @@ -691,6 +692,50 @@ func (s *testSessionSuite) TestGlobalVarAccessor(c *C) { c.Assert(terror.ErrorEqual(err, variable.ErrUnknownTimeZone), IsTrue) } +func (s *testSessionSuite) TestMatchIdentity(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("CREATE USER `useridentity`@`%`") + tk.MustExec("CREATE USER `useridentity`@`localhost`") + tk.MustExec("CREATE USER `useridentity`@`192.168.1.1`") + tk.MustExec("CREATE USER `useridentity`@`example.com`") + + // The MySQL matching rule is most specific to least specific. + // So if I log in from 192.168.1.1 I should match that entry always. + identity, err := tk.Se.MatchIdentity("useridentity", "192.168.1.1") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "192.168.1.1") + + // If I log in from localhost, I should match localhost + identity, err = tk.Se.MatchIdentity("useridentity", "localhost") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "localhost") + + // If I log in from 192.168.1.2 I should match wildcard. + identity, err = tk.Se.MatchIdentity("useridentity", "192.168.1.2") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "%") + + identity, err = tk.Se.MatchIdentity("useridentity", "127.0.0.1") + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + c.Assert(identity.Hostname, Equals, "localhost") + + // This uses the lookup of example.com to get an IP address. + // We then login with that IP address, but expect it to match the example.com + // entry in the privileges table (by reverse lookup). + ips, err := net.LookupHost("example.com") + c.Assert(err, IsNil) + identity, err = tk.Se.MatchIdentity("useridentity", ips[0]) + c.Assert(err, IsNil) + c.Assert(identity.Username, Equals, "useridentity") + // FIXME: we *should* match example.com instead + // as long as skip-name-resolve is not set (DEFAULT) + c.Assert(identity.Hostname, Equals, "%") +} + func (s *testSessionSuite) TestGetSysVariables(c *C) { tk := testkit.NewTestKitWithInit(c, s.store)