diff --git a/flyteadmin/auth/authzserver/claims_verifier.go b/flyteadmin/auth/authzserver/claims_verifier.go new file mode 100644 index 0000000000..7887ee084d --- /dev/null +++ b/flyteadmin/auth/authzserver/claims_verifier.go @@ -0,0 +1,68 @@ +package authzserver + +import ( + "encoding/json" + "fmt" + + "github.com/ory/x/jwtx" + "k8s.io/apimachinery/pkg/util/sets" + + "github.com/flyteorg/flyteadmin/auth" + "github.com/flyteorg/flyteadmin/auth/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" +) + +func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}) (interfaces.IdentityContext, error) { + claims := jwtx.ParseMapStringInterfaceClaims(claimsRaw) + + foundAudIndex := -1 + for audIndex, aud := range claims.Audience { + if expectedAudience.Has(aud) { + foundAudIndex = audIndex + break + } + } + + if foundAudIndex < 0 { + return nil, fmt.Errorf("invalid audience [%v]", claims) + } + + userInfo := &service.UserInfoResponse{} + if userInfoClaim, found := claimsRaw[UserIDClaim]; found && userInfoClaim != nil { + userInfoRaw := userInfoClaim.(map[string]interface{}) + raw, err := json.Marshal(userInfoRaw) + if err != nil { + return nil, err + } + + if err = json.Unmarshal(raw, userInfo); err != nil { + return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err) + } + } + + clientID := "" + if clientIDClaim, found := claimsRaw[ClientIDClaim]; found { + clientID = clientIDClaim.(string) + } + + scopes := sets.NewString() + if scopesClaim, found := claimsRaw[ScopeClaim]; found { + + switch sct := scopesClaim.(type) { + case []interface{}: + scopes = sets.NewString(interfaceSliceToStringSlice(sct)...) + case string: + sets.NewString(fmt.Sprintf("%v", scopesClaim)) + default: + return nil, fmt.Errorf("failed getting scope claims due to unknown type %T with value %v", sct, sct) + } + } + + // If this is a user-only access token with no scopes defined then add `all` scope by default because it's equivalent + // to having a user's login cookie or an ID Token as means of accessing the service. + if len(clientID) == 0 && scopes.Len() == 0 { + scopes.Insert(auth.ScopeAll) + } + + return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil +} diff --git a/flyteadmin/auth/authzserver/claims_verifier_test.go b/flyteadmin/auth/authzserver/claims_verifier_test.go new file mode 100644 index 0000000000..db5b5aee7d --- /dev/null +++ b/flyteadmin/auth/authzserver/claims_verifier_test.go @@ -0,0 +1,98 @@ +package authzserver + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/util/sets" +) + +func Test_verifyClaims(t *testing.T) { + t.Run("Empty claims, fail", func(t *testing.T) { + _, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{}) + assert.Error(t, err) + }) + + t.Run("All filled", func(t *testing.T) { + identityCtx, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{ + "aud": []string{"https://myserver"}, + "user_info": map[string]interface{}{ + "preferred_name": "John Doe", + }, + "sub": "123", + "client_id": "my-client", + "scp": []interface{}{"all", "offline"}, + }) + + assert.NoError(t, err) + assert.Equal(t, sets.NewString("all", "offline"), identityCtx.Scopes()) + assert.Equal(t, "my-client", identityCtx.AppID()) + assert.Equal(t, "123", identityCtx.UserID()) + assert.Equal(t, "https://myserver", identityCtx.Audience()) + }) + + t.Run("Multiple audience", func(t *testing.T) { + identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), + map[string]interface{}{ + "aud": []string{"https://myserver"}, + "user_info": map[string]interface{}{ + "preferred_name": "John Doe", + }, + "sub": "123", + "client_id": "my-client", + "scp": []interface{}{"all", "offline"}, + }) + + assert.NoError(t, err) + assert.Equal(t, "https://myserver", identityCtx.Audience()) + }) + + t.Run("No matching audience", func(t *testing.T) { + _, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), + map[string]interface{}{ + "aud": []string{"https://myserver3"}, + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid audience") + }) + + t.Run("Use first matching audience", func(t *testing.T) { + identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2", "https://myserver3"), + map[string]interface{}{ + "aud": []string{"https://myserver", "https://myserver2"}, + "user_info": map[string]interface{}{ + "preferred_name": "John Doe", + }, + "sub": "123", + "client_id": "my-client", + "scp": []interface{}{"all", "offline"}, + }) + + assert.NoError(t, err) + assert.Equal(t, "https://myserver", identityCtx.Audience()) + }) + + t.Run("String scope", func(t *testing.T) { + identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), + map[string]interface{}{ + "aud": []string{"https://myserver"}, + "scp": "all", + }) + + assert.NoError(t, err) + assert.Equal(t, "https://myserver", identityCtx.Audience()) + assert.Equal(t, sets.NewString("all"), identityCtx.Scopes()) + }) + t.Run("unknown scope", func(t *testing.T) { + identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), + map[string]interface{}{ + "aud": []string{"https://myserver"}, + "scp": 1, + }) + + assert.Error(t, err) + assert.Nil(t, identityCtx) + assert.Equal(t, "failed getting scope claims due to unknown type int with value 1", err.Error()) + }) +} diff --git a/flyteadmin/auth/authzserver/provider.go b/flyteadmin/auth/authzserver/provider.go index 2695897bf1..f20ebcbe06 100644 --- a/flyteadmin/auth/authzserver/provider.go +++ b/flyteadmin/auth/authzserver/provider.go @@ -5,7 +5,6 @@ import ( "crypto/rsa" "crypto/x509" "encoding/base64" - "encoding/json" "encoding/pem" "fmt" "time" @@ -17,8 +16,6 @@ import ( "github.com/lestrrat-go/jwx/jwk" - "github.com/ory/x/jwtx" - "github.com/flyteorg/flyteadmin/auth/interfaces" "github.com/flyteorg/flyteadmin/auth" @@ -131,53 +128,6 @@ func (p Provider) ValidateAccessToken(ctx context.Context, expectedAudience, tok return verifyClaims(sets.NewString(expectedAudience), claimsRaw) } -func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}) (interfaces.IdentityContext, error) { - claims := jwtx.ParseMapStringInterfaceClaims(claimsRaw) - - foundAudIndex := -1 - for audIndex, aud := range claims.Audience { - if expectedAudience.Has(aud) { - foundAudIndex = audIndex - break - } - } - - if foundAudIndex < 0 { - return nil, fmt.Errorf("invalid audience [%v]", claims) - } - - userInfo := &service.UserInfoResponse{} - if userInfoClaim, found := claimsRaw[UserIDClaim]; found && userInfoClaim != nil { - userInfoRaw := userInfoClaim.(map[string]interface{}) - raw, err := json.Marshal(userInfoRaw) - if err != nil { - return nil, err - } - - if err = json.Unmarshal(raw, userInfo); err != nil { - return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err) - } - } - - clientID := "" - if clientIDClaim, found := claimsRaw[ClientIDClaim]; found { - clientID = clientIDClaim.(string) - } - - scopes := sets.NewString() - if scopesClaim, found := claimsRaw[ScopeClaim]; found { - scopes = sets.NewString(interfaceSliceToStringSlice(scopesClaim.([]interface{}))...) - } - - // If this is a user-only access token with no scopes defined then add `all` scope by default because it's equivalent - // to having a user's login cookie or an ID Token as means of accessing the service. - if len(clientID) == 0 && scopes.Len() == 0 { - scopes.Insert(auth.ScopeAll) - } - - return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil -} - // NewProvider creates a new OAuth2 Provider that is able to do OAuth 2-legged and 3-legged flows. It'll lookup // config.SecretNameClaimSymmetricKey and config.SecretNameTokenSigningRSAKey secrets from the secret manager to use to // sign and generate hashes for tokens. The RSA Private key is expected to be in PEM format with the public key embedded. diff --git a/flyteadmin/auth/authzserver/provider_test.go b/flyteadmin/auth/authzserver/provider_test.go index 9ebfd32a83..98806b98a2 100644 --- a/flyteadmin/auth/authzserver/provider_test.go +++ b/flyteadmin/auth/authzserver/provider_test.go @@ -11,16 +11,15 @@ import ( "testing" "time" - "k8s.io/apimachinery/pkg/util/sets" - jwtgo "github.com/golang-jwt/jwt/v4" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" + "github.com/stretchr/testify/assert" + "github.com/flyteorg/flyteadmin/auth" "github.com/flyteorg/flyteadmin/auth/config" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks" - "github.com/stretchr/testify/assert" ) func newMockProvider(t testing.TB) (Provider, auth.SecretsSet) { @@ -213,71 +212,3 @@ func TestProvider_ValidateAccessToken(t *testing.T) { assert.False(t, identity.IsEmpty()) }) } - -func Test_verifyClaims(t *testing.T) { - t.Run("Empty claims, fail", func(t *testing.T) { - _, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{}) - assert.Error(t, err) - }) - - t.Run("All filled", func(t *testing.T) { - identityCtx, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{ - "aud": []string{"https://myserver"}, - "user_info": map[string]interface{}{ - "preferred_name": "John Doe", - }, - "sub": "123", - "client_id": "my-client", - "scp": []interface{}{"all", "offline"}, - }) - - assert.NoError(t, err) - assert.Equal(t, sets.NewString("all", "offline"), identityCtx.Scopes()) - assert.Equal(t, "my-client", identityCtx.AppID()) - assert.Equal(t, "123", identityCtx.UserID()) - assert.Equal(t, "https://myserver", identityCtx.Audience()) - }) - - t.Run("Multiple audience", func(t *testing.T) { - identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), - map[string]interface{}{ - "aud": []string{"https://myserver"}, - "user_info": map[string]interface{}{ - "preferred_name": "John Doe", - }, - "sub": "123", - "client_id": "my-client", - "scp": []interface{}{"all", "offline"}, - }) - - assert.NoError(t, err) - assert.Equal(t, "https://myserver", identityCtx.Audience()) - }) - - t.Run("No matching audience", func(t *testing.T) { - _, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), - map[string]interface{}{ - "aud": []string{"https://myserver3"}, - }) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid audience") - }) - - t.Run("Use first matching audience", func(t *testing.T) { - identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2", "https://myserver3"), - map[string]interface{}{ - "aud": []string{"https://myserver", "https://myserver2"}, - "user_info": map[string]interface{}{ - "preferred_name": "John Doe", - }, - "sub": "123", - "client_id": "my-client", - "scp": []interface{}{"all", "offline"}, - }) - - assert.NoError(t, err) - assert.Equal(t, "https://myserver", identityCtx.Audience()) - }) - -}