diff --git a/lib/services/local/statichostuser.go b/lib/services/local/statichostuser.go index a072398e5c858..7f42e677bc03a 100644 --- a/lib/services/local/statichostuser.go +++ b/lib/services/local/statichostuser.go @@ -73,18 +73,27 @@ func (s *StaticHostUserService) GetStaticHostUser(ctx context.Context, name stri // CreateStaticHostUser creates a static host user. func (s *StaticHostUserService) CreateStaticHostUser(ctx context.Context, in *userprovisioningpb.StaticHostUser) (*userprovisioningpb.StaticHostUser, error) { + if err := services.ValidateStaticHostUser(in); err != nil { + return nil, trace.Wrap(err) + } out, err := s.svc.CreateResource(ctx, in) return out, trace.Wrap(err) } // UpdateStaticHostUser updates a static host user. func (s *StaticHostUserService) UpdateStaticHostUser(ctx context.Context, in *userprovisioningpb.StaticHostUser) (*userprovisioningpb.StaticHostUser, error) { - out, err := s.svc.UpdateResource(ctx, in) + if err := services.ValidateStaticHostUser(in); err != nil { + return nil, trace.Wrap(err) + } + out, err := s.svc.ConditionalUpdateResource(ctx, in) return out, trace.Wrap(err) } // UpsertStaticHostUser upserts a static host user. func (s *StaticHostUserService) UpsertStaticHostUser(ctx context.Context, in *userprovisioningpb.StaticHostUser) (*userprovisioningpb.StaticHostUser, error) { + if err := services.ValidateStaticHostUser(in); err != nil { + return nil, trace.Wrap(err) + } out, err := s.svc.UpsertResource(ctx, in) return out, trace.Wrap(err) } diff --git a/lib/services/local/statichostuser_test.go b/lib/services/local/statichostuser_test.go new file mode 100644 index 0000000000000..b7500d4f2fefc --- /dev/null +++ b/lib/services/local/statichostuser_test.go @@ -0,0 +1,275 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package local + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/jonboulle/clockwork" + "github.com/mailgun/holster/v3/clock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/timestamppb" + + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + userprovisioningpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/userprovisioning/v1" + "github.com/gravitational/teleport/api/types/userprovisioning" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/services" +) + +func TestCreateStaticHostUser(t *testing.T) { + t.Parallel() + + ctx := context.Background() + service := getStaticHostUserService(t) + + obj := getStaticHostUser(0) + + // first attempt should succeed + objOut, err := service.CreateStaticHostUser(ctx, obj) + require.NoError(t, err) + require.Equal(t, obj, objOut) + + // second attempt should fail, object already exists + _, err = service.CreateStaticHostUser(ctx, obj) + require.Error(t, err) +} + +func TestUpsertStaticHostUser(t *testing.T) { + t.Parallel() + + ctx := context.Background() + service := getStaticHostUserService(t) + + obj := getStaticHostUser(0) + + // first attempt should succeed + objOut, err := service.UpsertStaticHostUser(ctx, obj) + require.NoError(t, err) + require.Equal(t, obj, objOut) + + // second attempt should also succeed + objOut, err = service.UpsertStaticHostUser(ctx, obj) + require.NoError(t, err) + require.Equal(t, obj, objOut) +} + +func TestGetStaticHostUser(t *testing.T) { + t.Parallel() + + ctx := context.Background() + service := getStaticHostUserService(t) + prepopulateStaticHostUsers(t, service, 1) + + tests := []struct { + name string + key string + assertErr assert.ErrorAssertionFunc + wantObj *userprovisioningpb.StaticHostUser + }{ + { + name: "object does not exist", + key: "dummy", + assertErr: assert.Error, + }, + { + name: "success", + key: getStaticHostUser(0).GetMetadata().GetName(), + assertErr: assert.NoError, + wantObj: getStaticHostUser(0), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + obj, err := service.GetStaticHostUser(ctx, tc.key) + tc.assertErr(t, err) + if tc.wantObj == nil { + assert.Nil(t, obj) + } else { + cmpOpts := []cmp.Option{ + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + protocmp.Transform(), + } + require.Equal(t, "", cmp.Diff(tc.wantObj, obj, cmpOpts...)) + } + }) + } +} + +func TestUpdateStaticHostUser(t *testing.T) { + t.Parallel() + + ctx := context.Background() + service := getStaticHostUserService(t) + prepopulateStaticHostUsers(t, service, 1) + + expiry := timestamppb.New(clock.Now().Add(30 * time.Minute)) + + // Fetch the object from the backend so the revision is populated. + key := getStaticHostUser(0).GetMetadata().GetName() + obj, err := service.GetStaticHostUser(ctx, key) + require.NoError(t, err) + obj.Metadata.Expires = expiry + + objUpdated, err := service.UpdateStaticHostUser(ctx, obj) + require.NoError(t, err) + require.Equal(t, expiry, objUpdated.Metadata.Expires) + + objFresh, err := service.GetStaticHostUser(ctx, key) + require.NoError(t, err) + require.Equal(t, expiry, objFresh.Metadata.Expires) +} + +func TestUpdateStaticHostUserMissingRevision(t *testing.T) { + t.Parallel() + + ctx := context.Background() + service := getStaticHostUserService(t) + prepopulateStaticHostUsers(t, service, 1) + + expiry := timestamppb.New(clock.Now().Add(30 * time.Minute)) + + obj := getStaticHostUser(0) + obj.Metadata.Expires = expiry + + // Update should be rejected as the revision is missing. + _, err := service.UpdateStaticHostUser(ctx, obj) + require.Error(t, err) +} + +func TestDeleteStaticHostUser(t *testing.T) { + t.Parallel() + + ctx := context.Background() + service := getStaticHostUserService(t) + prepopulateStaticHostUsers(t, service, 1) + + tests := []struct { + name string + key string + assertErr require.ErrorAssertionFunc + }{ + { + name: "object does not exist", + key: "dummy", + assertErr: require.Error, + }, + { + name: "success", + key: getStaticHostUser(0).GetMetadata().GetName(), + assertErr: require.NoError, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := service.DeleteStaticHostUser(ctx, tc.key) + tc.assertErr(t, err) + }) + } +} + +func TestListStaticHostUsers(t *testing.T) { + t.Parallel() + ctx := context.Background() + counts := []int{0, 1, 5, 10} + + for _, count := range counts { + t.Run(fmt.Sprintf("count=%v", count), func(t *testing.T) { + service := getStaticHostUserService(t) + prepopulateStaticHostUsers(t, service, count) + + t.Run("one page", func(t *testing.T) { + // Fetch all objects. + elements, nextToken, err := service.ListStaticHostUsers(ctx, 200, "") + require.NoError(t, err) + require.Empty(t, nextToken) + require.Len(t, elements, count) + + for i := 0; i < count; i++ { + cmpOpts := []cmp.Option{ + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + protocmp.Transform(), + } + require.Equal(t, "", cmp.Diff(getStaticHostUser(i), elements[i], cmpOpts...)) + } + }) + + t.Run("paginated", func(t *testing.T) { + // Fetch a paginated list of objects + elements := make([]*userprovisioningpb.StaticHostUser, 0) + nextToken := "" + for { + out, token, err := service.ListStaticHostUsers(ctx, 2, nextToken) + require.NoError(t, err) + nextToken = token + + elements = append(elements, out...) + if nextToken == "" { + break + } + } + + for i := 0; i < count; i++ { + cmpOpts := []cmp.Option{ + protocmp.IgnoreFields(&headerv1.Metadata{}, "revision"), + protocmp.Transform(), + } + require.Equal(t, "", cmp.Diff(getStaticHostUser(i), elements[i], cmpOpts...)) + } + }) + }) + } +} + +func getStaticHostUserService(t *testing.T) services.StaticHostUser { + backend, err := memory.New(memory.Config{ + Context: context.Background(), + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) + + service, err := NewStaticHostUserService(backend) + require.NoError(t, err) + return service +} + +func getStaticHostUser(index int) *userprovisioningpb.StaticHostUser { + name := fmt.Sprintf("obj%v", index) + return userprovisioning.NewStaticHostUser(name, &userprovisioningpb.StaticHostUserSpec{ + Login: "alice", + Groups: []string{"foo", "bar"}, + Uid: "1234", + Gid: "1234", + }) +} + +func prepopulateStaticHostUsers(t *testing.T, service services.StaticHostUser, count int) { + for i := 0; i < count; i++ { + _, err := service.CreateStaticHostUser(context.Background(), getStaticHostUser(i)) + require.NoError(t, err) + } +} diff --git a/lib/services/statichostuser.go b/lib/services/statichostuser.go index 1fe2a67660a62..4214fc48665db 100644 --- a/lib/services/statichostuser.go +++ b/lib/services/statichostuser.go @@ -47,22 +47,13 @@ type StaticHostUser interface { // MarshalStaticHostUser marshals a StaticHostUser resource to JSON. func MarshalStaticHostUser(in *userprovisioningpb.StaticHostUser, opts ...MarshalOption) ([]byte, error) { - if err := ValidateStaticHostUser(in); err != nil { - return nil, trace.Wrap(err) - } return MarshalProtoResource(in, opts...) } // UnmarshalStaticHostUser unmarshals a StaticHostUser resource from JSON. func UnmarshalStaticHostUser(data []byte, opts ...MarshalOption) (*userprovisioningpb.StaticHostUser, error) { out, err := UnmarshalProtoResource[*userprovisioningpb.StaticHostUser](data, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - if err := ValidateStaticHostUser(out); err != nil { - return nil, trace.Wrap(err) - } - return out, nil + return out, trace.Wrap(err) } func isValidUidOrGid(s string) bool {