diff --git a/lib/srv/usermgmt.go b/lib/srv/usermgmt.go index 8a1a7bc697f9d..9ae82f186ca80 100644 --- a/lib/srv/usermgmt.go +++ b/lib/srv/usermgmt.go @@ -23,6 +23,7 @@ import ( "errors" "fmt" "io" + "maps" "os/user" "regexp" "strings" @@ -229,20 +230,20 @@ func (u *HostUserManagement) UpsertUser(name string, ui *services.HostUsersInfo) return nil, trace.BadParameter("Mode is a required argument to CreateUser") } - groups := make([]string, 0, len(ui.Groups)) + groupsToAdd := make([]string, 0, len(ui.Groups)) for _, group := range ui.Groups { if group == name { // this causes an error as useradd expects the group with the same name as the user to be available log.Debugf("Skipping group creation with name the same as login user (%q, %q).", name, group) continue } - groups = append(groups, group) + groupsToAdd = append(groupsToAdd, group) } if ui.Mode == types.CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP { - groups = append(groups, types.TeleportServiceGroup) + groupsToAdd = append(groupsToAdd, types.TeleportServiceGroup) } var errs []error - for _, group := range groups { + for _, group := range groupsToAdd { if err := u.createGroupIfNotExist(group); err != nil { errs = append(errs, err) continue @@ -259,13 +260,12 @@ func (u *HostUserManagement) UpsertUser(name string, ui *services.HostUsersInfo) if tempUser != nil { // Collect actions that need to be done together under a lock on the user. - actionsUnderLock := []func() error{ - func() error { - // If the user exists, set user groups again as they might have changed. - return trace.Wrap(u.backend.SetUserGroups(name, groups)) - }, - } + actionsUnderLock := make([]func() error, 0, 2) doWithUserLock := func() error { + if len(actionsUnderLock) == 0 { + return nil + } + return trace.Wrap(u.doWithUserLock(func(_ types.SemaphoreLease) error { for _, action := range actionsUnderLock { if err := action(); err != nil { @@ -276,6 +276,38 @@ func (u *HostUserManagement) UpsertUser(name string, ui *services.HostUsersInfo) })) } + // Get the user's current groups. + currentGroups := make(map[string]struct{}, len(groupsToAdd)) + groupIds, err := u.backend.UserGIDs(tempUser) + if err != nil { + return nil, trace.Wrap(err) + } + for _, groupId := range groupIds { + group, err := u.backend.LookupGroupByID(groupId) + if err != nil { + return nil, trace.Wrap(err) + } + currentGroups[group.Name] = struct{}{} + } + + // Get the groups that the user should end up with, including the primary group. + finalGroups := make(map[string]struct{}, len(groupsToAdd)+1) + for _, group := range groupsToAdd { + finalGroups[group] = struct{}{} + } + primaryGroup, err := u.backend.LookupGroupByID(tempUser.Gid) + if err != nil { + return nil, trace.Wrap(err) + } + finalGroups[primaryGroup.Name] = struct{}{} + + // Check if the user's groups need to be updated. + if !maps.Equal(currentGroups, finalGroups) { + actionsUnderLock = append(actionsUnderLock, func() error { + return trace.Wrap(u.backend.SetUserGroups(name, groupsToAdd)) + }) + } + systemGroup, err := u.backend.LookupGroup(types.TeleportServiceGroup) if err != nil { if isUnknownGroupError(err, types.TeleportServiceGroup) { @@ -339,7 +371,7 @@ func (u *HostUserManagement) UpsertUser(name string, ui *services.HostUsersInfo) } } - err = u.backend.CreateUser(name, groups, home, ui.UID, ui.GID) + err = u.backend.CreateUser(name, groupsToAdd, home, ui.UID, ui.GID) if err != nil && !trace.IsAlreadyExists(err) { return trace.WrapWithMessage(err, "error while creating user") } diff --git a/lib/srv/usermgmt_test.go b/lib/srv/usermgmt_test.go index 7003237105e2a..4725a9acc9070 100644 --- a/lib/srv/usermgmt_test.go +++ b/lib/srv/usermgmt_test.go @@ -27,6 +27,7 @@ import ( "testing" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -46,6 +47,8 @@ type testHostUserBackend struct { userUID map[string]string // userGID: user -> gid userGID map[string]string + + setUserGroupsCalls int } func newTestUserMgmt() *testHostUserBackend { @@ -68,28 +71,40 @@ func (tm *testHostUserBackend) GetAllUsers() ([]string, error) { func (tm *testHostUserBackend) Lookup(username string) (*user.User, error) { if _, ok := tm.users[username]; !ok { - return nil, nil + return nil, user.UnknownUserError(username) } return &user.User{ Username: username, + Uid: tm.userUID[username], + Gid: tm.userGID[username], }, nil } func (tm *testHostUserBackend) LookupGroup(groupname string) (*user.Group, error) { + gid, ok := tm.groups[groupname] + if !ok { + return nil, user.UnknownGroupError(groupname) + } return &user.Group{ - Gid: tm.groups[groupname], + Gid: gid, Name: groupname, }, nil } func (tm *testHostUserBackend) LookupGroupByID(gid string) (*user.Group, error) { - return &user.Group{ - Gid: tm.groups[gid], - Name: gid, - }, nil + for groupName, groupGid := range tm.groups { + if groupGid == gid { + return &user.Group{ + Gid: gid, + Name: groupName, + }, nil + } + } + return nil, user.UnknownGroupIdError(gid) } func (tm *testHostUserBackend) SetUserGroups(name string, groups []string) error { + tm.setUserGroupsCalls++ if _, ok := tm.users[name]; !ok { return trace.NotFound("User %q doesn't exist", name) } @@ -98,10 +113,12 @@ func (tm *testHostUserBackend) SetUserGroups(name string, groups []string) error } func (tm *testHostUserBackend) UserGIDs(u *user.User) ([]string, error) { - ids := make([]string, 0, len(tm.users[u.Username])) + ids := make([]string, 0, len(tm.users[u.Username])+1) for _, id := range tm.users[u.Username] { ids = append(ids, tm.groups[id]) } + // Include primary group. + ids = append(ids, u.Gid) return ids, nil } @@ -110,7 +127,10 @@ func (tm *testHostUserBackend) CreateGroup(group, gid string) error { if ok { return trace.AlreadyExists("Group %q, already exists", group) } - tm.groups[group] = fmt.Sprint(len(tm.groups) + 1) + if gid == "" { + gid = fmt.Sprint(len(tm.groups) + 1) + } + tm.groups[group] = gid return nil } @@ -119,6 +139,14 @@ func (tm *testHostUserBackend) CreateUser(user string, groups []string, home, ui if ok { return trace.AlreadyExists("Group %q, already exists", user) } + if uid == "" { + uid = fmt.Sprint(len(tm.users) + 1) + } + if gid == "" { + gid = fmt.Sprint(len(tm.groups) + 1) + } + // Ensure that the user has a primary group. It's OK if it already exists. + _ = tm.CreateGroup(user, gid) tm.users[user] = groups tm.userUID[user] = uid tm.userGID[user] = gid @@ -358,3 +386,47 @@ func TestIsUnknownGroupError(t *testing.T) { require.Equal(t, tc.isUnknownGroupError, isUnknownGroupError(tc.err, unknownGroupName)) } } + +func TestUpdateUserGroups(t *testing.T) { + t.Parallel() + + backend := newTestUserMgmt() + bk, err := memory.New(memory.Config{}) + require.NoError(t, err) + pres := local.NewPresenceService(bk) + users := HostUserManagement{ + backend: backend, + storage: pres, + } + + allGroups := []string{"foo", "bar", "baz", "quux"} + for _, group := range allGroups { + require.NoError(t, backend.CreateGroup(group, "")) + } + + userinfo := &services.HostUsersInfo{ + Groups: allGroups[:2], + Mode: types.CreateHostUserMode_HOST_USER_MODE_KEEP, + } + // Create a user with some groups. + closer, err := users.UpsertUser("alice", userinfo) + assert.NoError(t, err) + assert.Nil(t, closer) + assert.Zero(t, backend.setUserGroupsCalls) + assert.ElementsMatch(t, userinfo.Groups, backend.users["alice"]) + + // Update user with new groups. + userinfo.Groups = allGroups[2:] + closer, err = users.UpsertUser("alice", userinfo) + assert.NoError(t, err) + assert.Nil(t, closer) + assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.ElementsMatch(t, userinfo.Groups, backend.users["alice"]) + + // Upsert again with same groups should not call SetUserGroups. + closer, err = users.UpsertUser("alice", userinfo) + assert.NoError(t, err) + assert.Nil(t, closer) + assert.Equal(t, 1, backend.setUserGroupsCalls) + assert.ElementsMatch(t, userinfo.Groups, backend.users["alice"]) +}