Skip to content

Commit

Permalink
[v16] Host user creation - Only update groups if needed (#45162)
Browse files Browse the repository at this point in the history
* Only update groups when needed

* Update test

* Rename groups
  • Loading branch information
atburke authored Aug 7, 2024
1 parent e780095 commit 6904d9e
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 19 deletions.
54 changes: 43 additions & 11 deletions lib/srv/usermgmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"errors"
"fmt"
"io"
"maps"
"os/user"
"regexp"
"strings"
Expand Down Expand Up @@ -229,20 +230,20 @@ func (u *HostUserManagement) UpsertUser(name string, ui *services.HostUsersInfo)
return nil, trace.BadParameter("Mode is a required argument to CreateUser")
}

groups := make([]string, 0, len(ui.Groups))
groupsToAdd := make([]string, 0, len(ui.Groups))
for _, group := range ui.Groups {
if group == name {
// this causes an error as useradd expects the group with the same name as the user to be available
log.Debugf("Skipping group creation with name the same as login user (%q, %q).", name, group)
continue
}
groups = append(groups, group)
groupsToAdd = append(groupsToAdd, group)
}
if ui.Mode == types.CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP {
groups = append(groups, types.TeleportServiceGroup)
groupsToAdd = append(groupsToAdd, types.TeleportServiceGroup)
}
var errs []error
for _, group := range groups {
for _, group := range groupsToAdd {
if err := u.createGroupIfNotExist(group); err != nil {
errs = append(errs, err)
continue
Expand All @@ -259,13 +260,12 @@ func (u *HostUserManagement) UpsertUser(name string, ui *services.HostUsersInfo)

if tempUser != nil {
// Collect actions that need to be done together under a lock on the user.
actionsUnderLock := []func() error{
func() error {
// If the user exists, set user groups again as they might have changed.
return trace.Wrap(u.backend.SetUserGroups(name, groups))
},
}
actionsUnderLock := make([]func() error, 0, 2)
doWithUserLock := func() error {
if len(actionsUnderLock) == 0 {
return nil
}

return trace.Wrap(u.doWithUserLock(func(_ types.SemaphoreLease) error {
for _, action := range actionsUnderLock {
if err := action(); err != nil {
Expand All @@ -276,6 +276,38 @@ func (u *HostUserManagement) UpsertUser(name string, ui *services.HostUsersInfo)
}))
}

// Get the user's current groups.
currentGroups := make(map[string]struct{}, len(groupsToAdd))
groupIds, err := u.backend.UserGIDs(tempUser)
if err != nil {
return nil, trace.Wrap(err)
}
for _, groupId := range groupIds {
group, err := u.backend.LookupGroupByID(groupId)
if err != nil {
return nil, trace.Wrap(err)
}
currentGroups[group.Name] = struct{}{}
}

// Get the groups that the user should end up with, including the primary group.
finalGroups := make(map[string]struct{}, len(groupsToAdd)+1)
for _, group := range groupsToAdd {
finalGroups[group] = struct{}{}
}
primaryGroup, err := u.backend.LookupGroupByID(tempUser.Gid)
if err != nil {
return nil, trace.Wrap(err)
}
finalGroups[primaryGroup.Name] = struct{}{}

// Check if the user's groups need to be updated.
if !maps.Equal(currentGroups, finalGroups) {
actionsUnderLock = append(actionsUnderLock, func() error {
return trace.Wrap(u.backend.SetUserGroups(name, groupsToAdd))
})
}

systemGroup, err := u.backend.LookupGroup(types.TeleportServiceGroup)
if err != nil {
if isUnknownGroupError(err, types.TeleportServiceGroup) {
Expand Down Expand Up @@ -339,7 +371,7 @@ func (u *HostUserManagement) UpsertUser(name string, ui *services.HostUsersInfo)
}
}

err = u.backend.CreateUser(name, groups, home, ui.UID, ui.GID)
err = u.backend.CreateUser(name, groupsToAdd, home, ui.UID, ui.GID)
if err != nil && !trace.IsAlreadyExists(err) {
return trace.WrapWithMessage(err, "error while creating user")
}
Expand Down
88 changes: 80 additions & 8 deletions lib/srv/usermgmt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
Expand All @@ -46,6 +47,8 @@ type testHostUserBackend struct {
userUID map[string]string
// userGID: user -> gid
userGID map[string]string

setUserGroupsCalls int
}

func newTestUserMgmt() *testHostUserBackend {
Expand All @@ -68,28 +71,40 @@ func (tm *testHostUserBackend) GetAllUsers() ([]string, error) {

func (tm *testHostUserBackend) Lookup(username string) (*user.User, error) {
if _, ok := tm.users[username]; !ok {
return nil, nil
return nil, user.UnknownUserError(username)
}
return &user.User{
Username: username,
Uid: tm.userUID[username],
Gid: tm.userGID[username],
}, nil
}

func (tm *testHostUserBackend) LookupGroup(groupname string) (*user.Group, error) {
gid, ok := tm.groups[groupname]
if !ok {
return nil, user.UnknownGroupError(groupname)
}
return &user.Group{
Gid: tm.groups[groupname],
Gid: gid,
Name: groupname,
}, nil
}

func (tm *testHostUserBackend) LookupGroupByID(gid string) (*user.Group, error) {
return &user.Group{
Gid: tm.groups[gid],
Name: gid,
}, nil
for groupName, groupGid := range tm.groups {
if groupGid == gid {
return &user.Group{
Gid: gid,
Name: groupName,
}, nil
}
}
return nil, user.UnknownGroupIdError(gid)
}

func (tm *testHostUserBackend) SetUserGroups(name string, groups []string) error {
tm.setUserGroupsCalls++
if _, ok := tm.users[name]; !ok {
return trace.NotFound("User %q doesn't exist", name)
}
Expand All @@ -98,10 +113,12 @@ func (tm *testHostUserBackend) SetUserGroups(name string, groups []string) error
}

func (tm *testHostUserBackend) UserGIDs(u *user.User) ([]string, error) {
ids := make([]string, 0, len(tm.users[u.Username]))
ids := make([]string, 0, len(tm.users[u.Username])+1)
for _, id := range tm.users[u.Username] {
ids = append(ids, tm.groups[id])
}
// Include primary group.
ids = append(ids, u.Gid)
return ids, nil
}

Expand All @@ -110,7 +127,10 @@ func (tm *testHostUserBackend) CreateGroup(group, gid string) error {
if ok {
return trace.AlreadyExists("Group %q, already exists", group)
}
tm.groups[group] = fmt.Sprint(len(tm.groups) + 1)
if gid == "" {
gid = fmt.Sprint(len(tm.groups) + 1)
}
tm.groups[group] = gid
return nil
}

Expand All @@ -119,6 +139,14 @@ func (tm *testHostUserBackend) CreateUser(user string, groups []string, home, ui
if ok {
return trace.AlreadyExists("Group %q, already exists", user)
}
if uid == "" {
uid = fmt.Sprint(len(tm.users) + 1)
}
if gid == "" {
gid = fmt.Sprint(len(tm.groups) + 1)
}
// Ensure that the user has a primary group. It's OK if it already exists.
_ = tm.CreateGroup(user, gid)
tm.users[user] = groups
tm.userUID[user] = uid
tm.userGID[user] = gid
Expand Down Expand Up @@ -358,3 +386,47 @@ func TestIsUnknownGroupError(t *testing.T) {
require.Equal(t, tc.isUnknownGroupError, isUnknownGroupError(tc.err, unknownGroupName))
}
}

func TestUpdateUserGroups(t *testing.T) {
t.Parallel()

backend := newTestUserMgmt()
bk, err := memory.New(memory.Config{})
require.NoError(t, err)
pres := local.NewPresenceService(bk)
users := HostUserManagement{
backend: backend,
storage: pres,
}

allGroups := []string{"foo", "bar", "baz", "quux"}
for _, group := range allGroups {
require.NoError(t, backend.CreateGroup(group, ""))
}

userinfo := &services.HostUsersInfo{
Groups: allGroups[:2],
Mode: types.CreateHostUserMode_HOST_USER_MODE_KEEP,
}
// Create a user with some groups.
closer, err := users.UpsertUser("alice", userinfo)
assert.NoError(t, err)
assert.Nil(t, closer)
assert.Zero(t, backend.setUserGroupsCalls)
assert.ElementsMatch(t, userinfo.Groups, backend.users["alice"])

// Update user with new groups.
userinfo.Groups = allGroups[2:]
closer, err = users.UpsertUser("alice", userinfo)
assert.NoError(t, err)
assert.Nil(t, closer)
assert.Equal(t, 1, backend.setUserGroupsCalls)
assert.ElementsMatch(t, userinfo.Groups, backend.users["alice"])

// Upsert again with same groups should not call SetUserGroups.
closer, err = users.UpsertUser("alice", userinfo)
assert.NoError(t, err)
assert.Nil(t, closer)
assert.Equal(t, 1, backend.setUserGroupsCalls)
assert.ElementsMatch(t, userinfo.Groups, backend.users["alice"])
}

0 comments on commit 6904d9e

Please sign in to comment.