Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

privilege, session, server: consistently map user login to identity (#30204) #30450

Merged
merged 3 commits into from
Feb 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions executor/coprocessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ func (h *CoprocessorDAGHandler) buildDAGExecutor(req *coprocessor.Request) (Exec
Username: dagReq.User.UserName,
Hostname: dagReq.User.UserHost,
}
authName, authHost, success := pm.GetAuthWithoutVerification(dagReq.User.UserName, dagReq.User.UserHost)
if success {
authName, authHost, success := pm.MatchIdentity(dagReq.User.UserName, dagReq.User.UserHost, false)
if success && pm.GetAuthWithoutVerification(authName, authHost) {
h.sctx.GetSessionVars().User.AuthUsername = authName
h.sctx.GetSessionVars().User.AuthHostname = authHost
h.sctx.GetSessionVars().ActiveRoles = pm.GetDefaultRoles(authName, authHost)
Expand Down
9 changes: 7 additions & 2 deletions privilege/privilege.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,15 @@ type Manager interface {
RequestDynamicVerificationWithUser(privName string, grantable bool, user *auth.UserIdentity) bool

// ConnectionVerification verifies user privilege for connection.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) (string, string, bool)
// Requires exact match on user name and host name.
ConnectionVerification(user, host string, auth, salt []byte, tlsState *tls.ConnectionState) bool

// GetAuthWithoutVerification uses to get auth name without verification.
GetAuthWithoutVerification(user, host string) (string, string, bool)
// Requires exact match on user name and host name.
GetAuthWithoutVerification(user, host string) bool

// MatchIdentity matches an identity
MatchIdentity(user, host string, skipNameResolve bool) (string, string, bool)

// DBIsVisible returns true is the database is visible to current user.
DBIsVisible(activeRole []*auth.RoleIdentity, db string) bool
Expand Down
48 changes: 46 additions & 2 deletions privilege/privileges/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -848,6 +849,9 @@ func decodeSetToPrivilege(s types.Set) mysql.PrivilegeType {
// See https://dev.mysql.com/doc/refman/5.7/en/account-names.html
func (record *baseRecord) hostMatch(s string) bool {
if record.hostIPNet == nil {
if record.Host == "localhost" && net.ParseIP(s).IsLoopback() {
return true
}
return false
}
ip := net.ParseIP(s).To4()
Expand Down Expand Up @@ -890,14 +894,54 @@ func patternMatch(str string, patChars, patTypes []byte) bool {
return stringutil.DoMatchBytes(str, patChars, patTypes)
}

// connectionVerification verifies the connection have access to TiDB server.
func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord {
// matchIdentity finds an identity to match a user + host
// using the correct rules according to MySQL.
func (p *MySQLPrivilege) matchIdentity(user, host string, skipNameResolve bool) *UserRecord {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, host) {
return record
}
}

// If skip-name resolve is not enabled, and the host is not localhost
// we can fallback and try to resolve with all addrs that match.
// TODO: this is imported from previous code in session.Auth(), and can be improved in future.
if !skipNameResolve && host != variable.DefHostname {
addrs, err := net.LookupAddr(host)
if err != nil {
logutil.BgLogger().Warn(
"net.LookupAddr returned an error during auth check",
zap.String("host", host),
zap.Error(err),
)
return nil
}
for _, addr := range addrs {
for i := 0; i < len(p.User); i++ {
record := &p.User[i]
if record.match(user, addr) {
return record
}
}
}
}
return nil
}

// connectionVerification verifies the username + hostname according to exact
// match from the mysql.user privilege table. call matchIdentity() first if you
// do not have an exact match yet.
func (p *MySQLPrivilege) connectionVerification(user, host string) *UserRecord {
records, exists := p.UserMap[user]
if exists {
for i := 0; i < len(records); i++ {
record := &records[i]
if record.Host == host { // exact match
return record
}
}
}
return nil
}

Expand Down
28 changes: 18 additions & 10 deletions privilege/privileges/privileges.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,21 @@ func (p *UserPrivileges) GetAuthPlugin(user, host string) (string, error) {
return "", errors.New("Failed to get plugin for user")
}

// MatchIdentity implements the Manager interface.
func (p *UserPrivileges) MatchIdentity(user, host string, skipNameResolve bool) (u string, h string, success bool) {
if SkipWithGrant {
return user, host, true
}
mysqlPriv := p.Handle.Get()
record := mysqlPriv.matchIdentity(user, host, skipNameResolve)
if record != nil {
return record.User, record.Host, true
}
return "", "", false
}

// GetAuthWithoutVerification implements the Manager interface.
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string, h string, success bool) {
func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (success bool) {
if SkipWithGrant {
p.user = user
p.host = host
Expand All @@ -273,16 +286,14 @@ func (p *UserPrivileges) GetAuthWithoutVerification(user, host string) (u string
return
}

u = record.User
h = record.Host
p.user = user
p.host = h
p.host = record.Host
success = true
return
}

// ConnectionVerification implements the Manager interface.
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (u string, h string, success bool) {
func (p *UserPrivileges) ConnectionVerification(user, host string, authentication, salt []byte, tlsState *tls.ConnectionState) (success bool) {
if SkipWithGrant {
p.user = user
p.host = host
Expand All @@ -298,9 +309,6 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
return
}

u = record.User
h = record.Host

globalPriv := mysqlPriv.matchGlobalPriv(user, host)
if globalPriv != nil {
if !p.checkSSL(globalPriv, tlsState) {
Expand Down Expand Up @@ -328,7 +336,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
// empty password
if len(pwd) == 0 && len(authentication) == 0 {
p.user = user
p.host = h
p.host = record.Host
success = true
return
}
Expand Down Expand Up @@ -371,7 +379,7 @@ func (p *UserPrivileges) ConnectionVerification(user, host string, authenticatio
}

p.user = user
p.host = h
p.host = record.Host
success = true
return
}
Expand Down
83 changes: 56 additions & 27 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ func (cc *clientConn) String() string {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
// https://bugs.mysql.com/bug.php?id=93044
func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) {
failpoint.Inject("FakeAuthSwitch", func() {
failpoint.Return([]byte(plugin), nil)
})
enclen := 1 + len(plugin) + 1 + len(cc.salt) + 1
data := cc.alloc.AllocWithLen(4, enclen)
data = append(data, mysql.AuthSwitchRequest) // switch request
Expand Down Expand Up @@ -708,40 +711,29 @@ func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Con

func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshakeResponse41) error {
if resp.Capability&mysql.ClientPluginAuth > 0 {
newAuth, err := cc.checkAuthPlugin(ctx, &resp.AuthPlugin)
newAuth, err := cc.checkAuthPlugin(ctx, resp)
if err != nil {
logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err))
return err
}
if len(newAuth) > 0 {
resp.Auth = newAuth
}

switch resp.AuthPlugin {
case mysql.AuthCachingSha2Password:
resp.Auth, err = cc.authSha(ctx)
if err != nil {
return err
}
case mysql.AuthNativePassword:
Comment on lines 724 to 725
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional. It was removed from master, I believe because this is handled in cc.checkAuthPlugin instead. The code could be cleaned up slightly, but that's for another PR.

case mysql.AuthSocket:
default:
logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin))
}
} else {
// MySQL 5.1 and older clients don't support authentication plugins.
logutil.Logger(ctx).Warn("Client without Auth Plugin support; Please upgrade client")
if cc.ctx == nil {
err := cc.openSession()
if err != nil {
return err
}
}
userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost})
_, err := cc.checkAuthPlugin(ctx, resp)
if err != nil {
return err
}
if userplugin != mysql.AuthNativePassword && userplugin != "" {
return errNotSupportedAuthMode
}
Comment on lines -732 to -744
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes are also based on master, and checkAuthPlugin does these checks.

resp.AuthPlugin = mysql.AuthNativePassword
}
return nil
Expand Down Expand Up @@ -845,7 +837,7 @@ func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string) e
}

// Check if the Authentication Plugin of the server, client and user configuration matches
func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) ([]byte, error) {
func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshakeResponse41) ([]byte, error) {
// Open a context unless this was done before.
if cc.ctx == nil {
err := cc.openSession()
Expand All @@ -854,22 +846,54 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) (
}
}

userplugin, err := cc.ctx.AuthPluginForUser(&auth.UserIdentity{Username: cc.user, Hostname: cc.peerHost})
authData := resp.Auth
hasPassword := "YES"
if len(authData) == 0 {
hasPassword = "NO"
}
host, _, err := cc.PeerHost(hasPassword)
if err != nil {
return nil, err
}
// Find the identity of the user based on username and peer host.
identity, err := cc.ctx.MatchIdentity(cc.user, host)
if err != nil {
return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
// Get the plugin for the identity.
userplugin, err := cc.ctx.AuthPluginForUser(identity)
if err != nil {
logutil.Logger(ctx).Warn("Failed to get authentication method for user",
zap.String("user", cc.user), zap.String("host", host))
}
failpoint.Inject("FakeUser", func(val failpoint.Value) {
userplugin = val.(string)
})
if userplugin == mysql.AuthSocket {
*authPlugin = mysql.AuthSocket
if !cc.isUnixSocket {
return nil, errAccessDenied.FastGenByArgs(cc.user, host, hasPassword)
}
resp.AuthPlugin = mysql.AuthSocket
user, err := user.LookupId(fmt.Sprint(cc.socketCredUID))
if err != nil {
return nil, err
}
return []byte(user.Username), nil
}
if len(userplugin) == 0 {
logutil.Logger(ctx).Warn("No user plugin set, assuming MySQL Native Password",
zap.String("user", cc.user), zap.String("host", cc.peerHost))
*authPlugin = mysql.AuthNativePassword
// No user plugin set, assuming MySQL Native Password
// This happens if the account doesn't exist or if the account doesn't have
// a password set.
if resp.AuthPlugin != mysql.AuthNativePassword {
if resp.Capability&mysql.ClientPluginAuth > 0 {
resp.AuthPlugin = mysql.AuthNativePassword
authData, err := cc.authSwitchRequest(ctx, mysql.AuthNativePassword)
if err != nil {
return nil, err
}
return authData, nil
}
}
return nil, nil
}

Expand All @@ -878,13 +902,18 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, authPlugin *string) (
// or if the authentication method send by the server doesn't match the authentication
// method send by the client (*authPlugin) then we need to switch the authentication
// method to match the one configured for that specific user.
if (cc.authPlugin != userplugin) || (cc.authPlugin != *authPlugin) {
authData, err := cc.authSwitchRequest(ctx, userplugin)
if err != nil {
return nil, err
if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) {
if resp.Capability&mysql.ClientPluginAuth > 0 {
authData, err := cc.authSwitchRequest(ctx, userplugin)
if err != nil {
return nil, err
}
resp.AuthPlugin = userplugin
return authData, nil
} else if userplugin != mysql.AuthNativePassword {
// MySQL 5.1 and older don't support authentication plugins yet
return nil, errNotSupportedAuthMode
}
*authPlugin = userplugin
return authData, nil
}

return nil, nil
Expand Down
Loading