Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SCIM list provider users #3405

Merged
merged 6 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions api/scim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package api

import "github.com/infrahq/infra/internal/validate"

type SCIMUserName struct {
GivenName string `json:"givenName"`
FamilyName string `json:"familyName"`
}

type SCIMUserEmail struct {
Primary bool `json:"primary"`
Value string `json:"value"`
}

func (r SCIMUserEmail) ValidationRules() []validate.ValidationRule {
return []validate.ValidationRule{
validate.Required("value", r.Value),
validate.Email("value", r.Value),
}
}

const UserSchema = "urn:ietf:params:scim:schemas:core:2.0:User"

type SCIMMetadata struct {
ResourceType string `json:"resourceType"`
}

type SCIMUser struct {
Schemas []string `json:"schemas"`
ID string `json:"id"`
UserName string `json:"userName"`
Name SCIMUserName `json:"name"`
Emails []SCIMUserEmail `json:"emails"`
Active bool `json:"active"`
Meta SCIMMetadata `json:"meta"`
}

type SCIMParametersRequest struct {
// these pagination parameters must conform to the SCIM spec, rather than our standard pagination
StartIndex int `form:"startIndex"`
Count int `form:"count"`
}

func (r SCIMParametersRequest) ValidationRules() []validate.ValidationRule {
return []validate.ValidationRule{
validate.IntRule{
Name: "startIndex",
Value: r.StartIndex,
Min: validate.Int(0),
},
validate.IntRule{
Name: "count",
Value: r.Count,
Min: validate.Int(0),
},
}
}

const ListResponseSchema = "urn:ietf:params:scim:api:messages:2.0:ListResponse"

type ListProviderUsersResponse struct {
Schemas []string `json:"schemas"`
TotalResults int `json:"totalResults"`
Resources []SCIMUser `json:"Resources"` // intentionally capitalized
StartIndex int `json:"startIndex"`
ItemsPerPage int `json:"itemsPerPage"`
}
22 changes: 22 additions & 0 deletions internal/access/scim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package access

import (
"fmt"

"github.com/gin-gonic/gin"

"github.com/infrahq/infra/internal/server/data"
"github.com/infrahq/infra/internal/server/models"
)

func ListProviderUsers(c *gin.Context, p *data.SCIMParameters) ([]models.ProviderUser, error) {
ctx := GetRequestContext(c)
// IssuedFor will match no providers if called with a regular access key. When called with
// a SCIM access key it will be the provider ID. This effectively restricts this endpoint to
// only SCIM access keys.
users, err := data.ListProviderUsers(ctx.DBTxn, ctx.Authenticated.AccessKey.IssuedFor, p)
if err != nil {
return nil, fmt.Errorf("list provider users: %w", err)
}
return users, nil
}
19 changes: 19 additions & 0 deletions internal/server/data/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func migrations() []*migrator.Migration {
removeDotFromDestinationName(),
destinationNameUnique(),
removeDeletedIdentityProviderUsers(),
addProviderUserSCIMFields(),
// next one here
}
}
Expand Down Expand Up @@ -777,3 +778,21 @@ func removeDeletedIdentityProviderUsers() *migrator.Migration {
},
}
}

func addProviderUserSCIMFields() *migrator.Migration {
return &migrator.Migration{
ID: "2022-09-28T13:00",
Migrate: func(tx migrator.DB) error {
stmt := `
ALTER TABLE provider_users
ADD COLUMN IF NOT EXISTS given_name text DEFAULT '',
ADD COLUMN IF NOT EXISTS family_name text DEFAULT '',
ADD COLUMN IF NOT EXISTS active boolean DEFAULT true;

CREATE UNIQUE INDEX IF NOT EXISTS idx_emails_providers ON provider_users (email, provider_id);
`
_, err := tx.Exec(stmt)
return err
},
}
}
6 changes: 6 additions & 0 deletions internal/server/data/migrations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,12 @@ DELETE FROM settings WHERE id=24567;
assert.Equal(t, count, 0)
},
},
{
label: testCaseLine("2022-09-28T13:00"),
expected: func(t *testing.T, db WriteTxn) {
// schema changes are tested with schema comparison
},
},
}

ids := make(map[string]struct{}, len(testCases))
Expand Down
2 changes: 1 addition & 1 deletion internal/server/data/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func DeleteProviders(db GormTxn, selectors ...SelectorFunc) error {
for _, p := range toDelete {
ids = append(ids, p.ID)

providerUsers, err := listProviderUsers(db, p.ID)
providerUsers, err := ListProviderUsers(db, p.ID, nil)
if err != nil {
return fmt.Errorf("listing provider users: %w", err)
}
Expand Down
51 changes: 45 additions & 6 deletions internal/server/data/provideruser.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ func (p providerUserTable) Table() string {
}

func (p providerUserTable) Columns() []string {
return []string{"identity_id", "provider_id", "email", "groups", "last_update", "redirect_url", "access_token", "refresh_token", "expires_at"}
return []string{"identity_id", "provider_id", "email", "groups", "last_update", "redirect_url", "access_token", "refresh_token", "expires_at", "given_name", "family_name", "active"}
}

func (p providerUserTable) Values() []any {
return []any{p.IdentityID, p.ProviderID, p.Email, p.Groups, p.LastUpdate, p.RedirectURL, p.AccessToken, p.RefreshToken, p.ExpiresAt}
return []any{p.IdentityID, p.ProviderID, p.Email, p.Groups, p.LastUpdate, p.RedirectURL, p.AccessToken, p.RefreshToken, p.ExpiresAt, p.GivenName, p.FamilyName, p.Active}
}

func (p *providerUserTable) ScanFields() []any {
return []any{&p.IdentityID, &p.ProviderID, &p.Email, &p.Groups, &p.LastUpdate, &p.RedirectURL, &p.AccessToken, &p.RefreshToken, &p.ExpiresAt}
return []any{&p.IdentityID, &p.ProviderID, &p.Email, &p.Groups, &p.LastUpdate, &p.RedirectURL, &p.AccessToken, &p.RefreshToken, &p.ExpiresAt, &p.GivenName, &p.FamilyName, &p.Active}
}

func (p *providerUserTable) OnInsert() error {
Expand Down Expand Up @@ -66,6 +66,7 @@ func CreateProviderUser(db GormTxn, provider *models.Provider, ident *models.Ide
IdentityID: ident.ID,
Email: ident.Name,
LastUpdate: time.Now().UTC(),
Active: true,
}
if err := validateProviderUser(pu); err != nil {
return nil, err
Expand Down Expand Up @@ -93,20 +94,51 @@ func UpdateProviderUser(tx WriteTxn, providerUser *models.ProviderUser) error {
return handleError(err)
}

func listProviderUsers(tx ReadTxn, providerID uid.ID) ([]models.ProviderUser, error) {
func ListProviderUsers(tx ReadTxn, providerID uid.ID, p *SCIMParameters) ([]models.ProviderUser, error) {
table := &providerUserTable{}
query := querybuilder.New("SELECT")
query.B(columnsForSelect(table))
if p != nil {
query.B(", count(*) OVER()")
}
query.B("FROM")
query.B(table.Table())
query.B("INNER JOIN providers ON provider_users.provider_id = providers.id AND providers.organization_id = ?", tx.OrganizationID())
query.B("WHERE provider_id = ?", providerID)

query.B("ORDER BY email ASC")

if p != nil {
// apply scim parameters
if p.Count != 0 {
query.B("LIMIT ?", p.Count)
}
if p.StartIndex > 0 {
offset := p.StartIndex - 1 // start index begins at 1, not 0
query.B("OFFSET ?", offset)
}
}

rows, err := tx.Query(query.String(), query.Args...)
if err != nil {
return nil, err
}
return scanRows(rows, func(pu *models.ProviderUser) []any {
return (*providerUserTable)(pu).ScanFields()
result, err := scanRows(rows, func(pu *models.ProviderUser) []any {
fields := (*providerUserTable)(pu).ScanFields()
if p != nil {
fields = append(fields, &p.TotalCount)
}
return fields
})
if err != nil {
return nil, fmt.Errorf("scan provider users: %w", err)
}

if p != nil && p.Count == 0 {
p.Count = p.TotalCount
}

return result, nil
}

type DeleteProviderUsersOptions struct {
Expand Down Expand Up @@ -179,3 +211,10 @@ func SyncProviderUser(ctx context.Context, tx GormTxn, user *models.Identity, pr

return nil
}

type SCIMParameters struct {
Count int // the number of items to return
StartIndex int // the offset to start counting from
TotalCount int // the total number of items that match the query
// TODO: filter query param
}
119 changes: 119 additions & 0 deletions internal/server/data/provideruser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/infrahq/infra/internal/server/models"
"github.com/infrahq/infra/internal/server/providers"
"github.com/infrahq/infra/uid"
)

// mockOIDC is a mock oidc identity provider
Expand Down Expand Up @@ -106,6 +107,7 @@ func TestSyncProviderUser(t *testing.T) {
AccessToken: "any-access-token",
ExpiresAt: time.Now().Add(time.Hour).UTC(),
LastUpdate: time.Now().UTC(),
Active: true,
}

cmpProviderUser := cmp.Options{
Expand Down Expand Up @@ -166,6 +168,7 @@ func TestSyncProviderUser(t *testing.T) {
AccessToken: "any-access-token",
ExpiresAt: time.Now().Add(5 * time.Minute).UTC(),
LastUpdate: time.Now().UTC(),
Active: true,
}

cmpProviderUser := cmp.Options{
Expand Down Expand Up @@ -254,3 +257,119 @@ func TestDeleteProviderUser(t *testing.T) {
assert.NilError(t, err)
})
}

func TestListProviderUsers(t *testing.T) {
type testCase struct {
name string
setup func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int)
}

testCases := []testCase{
{
name: "list all provider users",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

pu := createTestProviderUser(t, tx, provider, "david@example.com")
return provider.ID, nil, []models.ProviderUser{pu}, 0
},
},
{
name: "list all provider users invalid provider ID",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

_ = createTestProviderUser(t, tx, provider, "david@example.com")
return 1234, nil, nil, 0
},
},
{
name: "limit less than total",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

pu := createTestProviderUser(t, tx, provider, "david@example.com")
_ = createTestProviderUser(t, tx, provider, "lucy@example.com")
return provider.ID, &SCIMParameters{Count: 1}, []models.ProviderUser{pu}, 2
},
},
{
name: "offset from start",
setup: func(t *testing.T, tx *Transaction) (providerID uid.ID, p *SCIMParameters, expected []models.ProviderUser, totalCount int) {
provider := &models.Provider{
Name: "mockta",
Kind: models.ProviderKindOkta,
}

err := CreateProvider(tx, provider)
assert.NilError(t, err)

pu1 := createTestProviderUser(t, tx, provider, "david@example.com")
pu2 := createTestProviderUser(t, tx, provider, "lucy@example.com")
return provider.ID, &SCIMParameters{StartIndex: 1}, []models.ProviderUser{pu1, pu2}, 2
},
},
}

runDBTests(t, func(t *testing.T, db *DB) {
org := &models.Organization{Name: "something", Domain: "example.com"}
assert.NilError(t, CreateOrganization(db, org))

// create some dummy data for another org to test multi-tenancy
stmt := `
INSERT INTO provider_users(identity_id, provider_id, email)
VALUES (?, ?, ?);
`
_, err := db.Exec(stmt, 123, 123, "otherorg@example.com")
assert.NilError(t, err)

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
tx := txnForTestCase(t, db, org.ID)

providerID, p, expected, totalCount := tc.setup(t, tx)

result, err := ListProviderUsers(tx, providerID, p)

assert.NilError(t, err)
assert.DeepEqual(t, result, expected, cmpTimeWithDBPrecision)
if p != nil {
assert.Equal(t, p.TotalCount, totalCount)
}
})
}
})
}

func createTestProviderUser(t *testing.T, tx *Transaction, provider *models.Provider, userName string) models.ProviderUser {
user := &models.Identity{
Name: userName,
}
err := CreateIdentity(tx, user)
assert.NilError(t, err)

pu, err := CreateProviderUser(tx, provider, user)
assert.NilError(t, err)

pu.Groups = models.CommaSeparatedStrings{}

return *pu
}
Loading