diff --git a/internal/broker/broker.go b/internal/broker/broker.go index faca648b..8aff04e6 100644 --- a/internal/broker/broker.go +++ b/internal/broker/broker.go @@ -20,6 +20,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/google/uuid" + "github.com/ubuntu/authd-oidc-brokers/internal/consts" "github.com/ubuntu/authd-oidc-brokers/internal/providers" "github.com/ubuntu/authd-oidc-brokers/internal/providers/info" "github.com/ubuntu/decorate" @@ -199,7 +200,7 @@ func (b *Broker) connectToProvider(ctx context.Context) (authCfg authConfig, err oauthCfg := oauth2.Config{ ClientID: b.oidcCfg.ClientID, Endpoint: provider.Endpoint(), - Scopes: append([]string{oidc.ScopeOpenID, "profile", "email"}, b.providerInfo.AdditionalScopes()...), + Scopes: append(consts.DefaultScopes, b.providerInfo.AdditionalScopes()...), } return authConfig{provider: provider, oauth: oauthCfg}, nil @@ -453,6 +454,10 @@ func (b *Broker) handleIsAuthenticated(ctx context.Context, session *sessionInfo return AuthRetry, errorMessage{Message: "could not authenticate user remotely"} } + if err = b.providerInfo.CheckTokenScopes(t); err != nil { + slog.Warn(err.Error()) + } + rawIDToken, ok := t.Extra("id_token").(string) if !ok { slog.Error("could not get ID token") diff --git a/internal/broker/broker_test.go b/internal/broker/broker_test.go index 012c3da1..51c85be5 100644 --- a/internal/broker/broker_test.go +++ b/internal/broker/broker_test.go @@ -360,6 +360,7 @@ func TestIsAuthenticated(t *testing.T) { badFirstKey bool customHandlers map[string]testutils.ProviderHandler + address string wantSecondCall bool secondChallenge string @@ -388,6 +389,14 @@ func TestIsAuthenticated(t *testing.T) { "/.well-known/openid-configuration": testutils.UnavailableHandler(), }, }, + "Authenticating still allowed if token is missing scopes": { + firstChallenge: "-", + wantSecondCall: true, + customHandlers: map[string]testutils.ProviderHandler{ + "/token": testutils.DefaultTokenHandler("http://127.0.0.1:31313", []string{}), + }, + address: "127.0.0.1:31313", + }, "Error when authentication data is invalid": {invalidAuthData: true}, "Error when challenge can not be decrypted": {firstMode: "password", badFirstKey: true}, @@ -466,7 +475,7 @@ func TestIsAuthenticated(t *testing.T) { for path, handler := range tc.customHandlers { opts = append(opts, testutils.WithHandler(path, handler)) } - p, cleanup := testutils.StartMockProvider("", opts...) + p, cleanup := testutils.StartMockProvider(tc.address, opts...) t.Cleanup(cleanup) provider = p } diff --git a/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/cache/provider_url/test-user@email.com.cache b/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/cache/provider_url/test-user@email.com.cache new file mode 100644 index 00000000..80ab7838 --- /dev/null +++ b/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/cache/provider_url/test-user@email.com.cache @@ -0,0 +1 @@ +Definitely an encrypted token \ No newline at end of file diff --git a/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/first_call b/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/first_call new file mode 100644 index 00000000..d0887a13 --- /dev/null +++ b/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/first_call @@ -0,0 +1,3 @@ +access: next +data: '{}' +err: diff --git a/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/second_call b/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/second_call new file mode 100644 index 00000000..3c4c96bc --- /dev/null +++ b/internal/broker/testdata/TestIsAuthenticated/golden/authenticating_still_allowed_if_token_is_missing_scopes/second_call @@ -0,0 +1,3 @@ +access: granted +data: '{"userinfo":{"name":"test-user@email.com","uuid":"test-user-id","dir":"/home/test-user@email.com","shell":"/usr/bin/bash","gecos":"test-user@email.com","groups":[{"name":"remote-group","ugid":"12345"},{"name":"linux-local-group","ugid":""}]}}' +err: diff --git a/internal/consts/oidc.go b/internal/consts/oidc.go new file mode 100644 index 00000000..c9cf3d0f --- /dev/null +++ b/internal/consts/oidc.go @@ -0,0 +1,9 @@ +package consts + +import "github.com/coreos/go-oidc/v3/oidc" + +var ( + // DefaultScopes contains the OIDC scopes that we require for all providers. + // Provider implementations can append additional scopes. + DefaultScopes = []string{oidc.ScopeOpenID, "profile", "email"} +) diff --git a/internal/providers/default.go b/internal/providers/default.go index 21514ecf..f3b406dc 100644 --- a/internal/providers/default.go +++ b/internal/providers/default.go @@ -8,5 +8,5 @@ import ( // CurrentProviderInfo returns a generic oidc provider implementation. func CurrentProviderInfo() ProviderInfoer { - return noprovider.NoProvider{} + return noprovider.New() } diff --git a/internal/providers/msentraid/export_test.go b/internal/providers/msentraid/export_test.go new file mode 100644 index 00000000..8b28309c --- /dev/null +++ b/internal/providers/msentraid/export_test.go @@ -0,0 +1,8 @@ +package msentraid + +import "strings" + +// AllExpectedScopes returns all the default expected scopes for a new provider. +func AllExpectedScopes() string { + return strings.Join(New().expectedScopes, " ") +} diff --git a/internal/providers/msentraid/msentraid.go b/internal/providers/msentraid/msentraid.go index 9ac5a568..d8a7aa8f 100644 --- a/internal/providers/msentraid/msentraid.go +++ b/internal/providers/msentraid/msentraid.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "slices" "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore" @@ -16,6 +17,7 @@ import ( msgraphauth "github.com/microsoftgraph/msgraph-sdk-go-core/authentication" msgraphgroups "github.com/microsoftgraph/msgraph-sdk-go/groups" msgraphmodels "github.com/microsoftgraph/msgraph-sdk-go/models" + "github.com/ubuntu/authd-oidc-brokers/internal/consts" "github.com/ubuntu/authd-oidc-brokers/internal/providers/info" "golang.org/x/oauth2" ) @@ -27,7 +29,16 @@ func init() { const localGroupPrefix = "linux-" // Provider is the Microsoft Entra ID provider implementation. -type Provider struct{} +type Provider struct { + expectedScopes []string +} + +// New returns a new MSEntraID provider. +func New() Provider { + return Provider{ + expectedScopes: append(consts.DefaultScopes, "GroupMember.Read.All", "User.Read"), + } +} // AdditionalScopes returns the generic scopes required by the EntraID provider. func (p Provider) AdditionalScopes() []string { @@ -39,6 +50,26 @@ func (p Provider) AuthOptions() []oauth2.AuthCodeOption { return []oauth2.AuthCodeOption{} } +// CheckTokenScopes checks if the token has the required scopes. +func (p Provider) CheckTokenScopes(token *oauth2.Token) error { + scopesStr, ok := token.Extra("scope").(string) + if !ok { + return fmt.Errorf("failed to cast token scopes to string: %v", token.Extra("scope")) + } + + scopes := strings.Split(scopesStr, " ") + var missingScopes []string + for _, s := range p.expectedScopes { + if !slices.Contains(scopes, s) { + missingScopes = append(missingScopes, s) + } + } + if len(missingScopes) > 0 { + return fmt.Errorf("missing required scopes: %s", strings.Join(missingScopes, ", ")) + } + return nil +} + // GetUserInfo is a no-op when no specific provider is in use. func (p Provider) GetUserInfo(ctx context.Context, accessToken *oauth2.Token, idToken *oidc.IDToken) (info.User, error) { userClaims, err := p.userClaims(idToken) diff --git a/internal/providers/msentraid/msentraid_test.go b/internal/providers/msentraid/msentraid_test.go new file mode 100644 index 00000000..8d5e80d0 --- /dev/null +++ b/internal/providers/msentraid/msentraid_test.go @@ -0,0 +1,53 @@ +package msentraid_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/ubuntu/authd-oidc-brokers/internal/providers/msentraid" + "golang.org/x/oauth2" +) + +func TestNew(t *testing.T) { + p := msentraid.New() + + require.NotEmpty(t, p, "New should return a non-empty provider") +} + +func TestCheckTokenScopes(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + scopes string + noExtraScopeField bool + + wantErr bool + }{ + "success when checking all scopes are present": {scopes: msentraid.AllExpectedScopes()}, + "success even if getting more scopes than requested": {scopes: msentraid.AllExpectedScopes() + " extra-scope"}, + + "error with missing scopes": {scopes: "profile email", wantErr: true}, + "error without extra scope field": {noExtraScopeField: true, wantErr: true}, + "error with empty scopes": {scopes: "", wantErr: true}, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + p := msentraid.New() + + token := &oauth2.Token{} + if !tc.noExtraScopeField { + token = token.WithExtra(map[string]interface{}{"scope": any(tc.scopes)}) + } + + err := p.CheckTokenScopes(token) + if tc.wantErr { + require.Error(t, err, "CheckTokenScopes should return an error") + return + } + + require.NoError(t, err, "CheckTokenScopes should not return an error") + }) + } +} diff --git a/internal/providers/noprovider/noprovider.go b/internal/providers/noprovider/noprovider.go index 6f2ac9c5..1c64079b 100644 --- a/internal/providers/noprovider/noprovider.go +++ b/internal/providers/noprovider/noprovider.go @@ -14,6 +14,17 @@ import ( // NoProvider is a generic OIDC provider. type NoProvider struct{} +// New returns a new NoProvider. +func New() NoProvider { + return NoProvider{} +} + +// CheckTokenScopes should check the token scopes, but we're not sure +// if there is a generic way to do this, so for now it's a no-op. +func (p NoProvider) CheckTokenScopes(token *oauth2.Token) error { + return nil +} + // AdditionalScopes returns the generic scopes required by the provider. func (p NoProvider) AdditionalScopes() []string { return []string{oidc.ScopeOfflineAccess} diff --git a/internal/providers/providers.go b/internal/providers/providers.go index 94dc58c5..e2b29fd7 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -13,6 +13,7 @@ import ( type ProviderInfoer interface { AdditionalScopes() []string AuthOptions() []oauth2.AuthCodeOption + CheckTokenScopes(token *oauth2.Token) error CurrentAuthenticationModesOffered( sessionMode string, supportedAuthModes map[string]string, diff --git a/internal/providers/withmsentraid.go b/internal/providers/withmsentraid.go index a45dcc8b..3d29d7e9 100644 --- a/internal/providers/withmsentraid.go +++ b/internal/providers/withmsentraid.go @@ -8,5 +8,5 @@ import ( // CurrentProviderInfo returns a Microsoft Entra ID provider implementation. func CurrentProviderInfo() ProviderInfoer { - return msentraid.Provider{} + return msentraid.New() } diff --git a/internal/testutils/provider.go b/internal/testutils/provider.go index 221a49aa..be9d370b 100644 --- a/internal/testutils/provider.go +++ b/internal/testutils/provider.go @@ -8,11 +8,13 @@ import ( "net" "net/http" "net/http/httptest" + "slices" "strconv" "strings" "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/ubuntu/authd-oidc-brokers/internal/consts" "github.com/ubuntu/authd-oidc-brokers/internal/providers/info" "golang.org/x/oauth2" ) @@ -52,7 +54,7 @@ func StartMockProvider(address string, args ...OptionProvider) (*httptest.Server handlers: map[string]ProviderHandler{ "/.well-known/openid-configuration": DefaultOpenIDHandler(server.URL), "/device_auth": DefaultDeviceAuthHandler(), - "/token": DefaultTokenHandler(server.URL), + "/token": DefaultTokenHandler(server.URL, consts.DefaultScopes), }, } for _, arg := range args { @@ -128,7 +130,7 @@ func DefaultDeviceAuthHandler() ProviderHandler { } // DefaultTokenHandler returns a handler that returns a default token response. -func DefaultTokenHandler(serverURL string) ProviderHandler { +func DefaultTokenHandler(serverURL string, scopes []string) ProviderHandler { return func(w http.ResponseWriter, r *http.Request) { // Mimics user going through auth process time.Sleep(2 * time.Second) @@ -151,10 +153,10 @@ func DefaultTokenHandler(serverURL string) ProviderHandler { "access_token": "accesstoken", "refresh_token": "refreshtoken", "token_type": "Bearer", - "scope": "offline_access openid profile", + "scope": "%s", "expires_in": 3600, "id_token": "%s" - }`, rawToken) + }`, strings.Join(scopes, " "), rawToken) w.Header().Add("Content-Type", "application/json") _, err := w.Write([]byte(response)) @@ -225,6 +227,26 @@ type MockProviderInfoer struct { GroupsErr bool } +// CheckTokenScopes checks if the token has the required scopes. +func (p *MockProviderInfoer) CheckTokenScopes(token *oauth2.Token) error { + scopesStr, ok := token.Extra("scope").(string) + if !ok { + return fmt.Errorf("failed to cast token scopes to string: %v", token.Extra("scope")) + } + + scopes := strings.Split(scopesStr, " ") + var missingScopes []string + for _, s := range consts.DefaultScopes { + if !slices.Contains(scopes, s) { + missingScopes = append(missingScopes, s) + } + } + if len(missingScopes) > 0 { + return fmt.Errorf("missing required scopes: %s", strings.Join(missingScopes, ", ")) + } + return nil +} + // AdditionalScopes returns the additional scopes required by the provider. func (p *MockProviderInfoer) AdditionalScopes() []string { if p.Scopes != nil {