Skip to content

Commit

Permalink
Added type assertion check on scp claims which are returned as string…
Browse files Browse the repository at this point in the history
… in Azure AD (flyteorg#471)

* Added type assertion check on scp claims which are returned as string in Azure AD

Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com>

* Moved claim_verfier and added a unit test for string scope

Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com>

* added another check for string type for scp claims and fail otherwise

Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com>

Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com>
  • Loading branch information
pmahindrakar-oss authored Sep 15, 2022
1 parent 098939d commit 4534fa0
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 121 deletions.
68 changes: 68 additions & 0 deletions flyteadmin/auth/authzserver/claims_verifier.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package authzserver

import (
"encoding/json"
"fmt"

"github.com/ory/x/jwtx"
"k8s.io/apimachinery/pkg/util/sets"

"github.com/flyteorg/flyteadmin/auth"
"github.com/flyteorg/flyteadmin/auth/interfaces"
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"
)

func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}) (interfaces.IdentityContext, error) {
claims := jwtx.ParseMapStringInterfaceClaims(claimsRaw)

foundAudIndex := -1
for audIndex, aud := range claims.Audience {
if expectedAudience.Has(aud) {
foundAudIndex = audIndex
break
}
}

if foundAudIndex < 0 {
return nil, fmt.Errorf("invalid audience [%v]", claims)
}

userInfo := &service.UserInfoResponse{}
if userInfoClaim, found := claimsRaw[UserIDClaim]; found && userInfoClaim != nil {
userInfoRaw := userInfoClaim.(map[string]interface{})
raw, err := json.Marshal(userInfoRaw)
if err != nil {
return nil, err
}

if err = json.Unmarshal(raw, userInfo); err != nil {
return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err)
}
}

clientID := ""
if clientIDClaim, found := claimsRaw[ClientIDClaim]; found {
clientID = clientIDClaim.(string)
}

scopes := sets.NewString()
if scopesClaim, found := claimsRaw[ScopeClaim]; found {

switch sct := scopesClaim.(type) {
case []interface{}:
scopes = sets.NewString(interfaceSliceToStringSlice(sct)...)
case string:
sets.NewString(fmt.Sprintf("%v", scopesClaim))
default:
return nil, fmt.Errorf("failed getting scope claims due to unknown type %T with value %v", sct, sct)
}
}

// If this is a user-only access token with no scopes defined then add `all` scope by default because it's equivalent
// to having a user's login cookie or an ID Token as means of accessing the service.
if len(clientID) == 0 && scopes.Len() == 0 {
scopes.Insert(auth.ScopeAll)
}

return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil
}
98 changes: 98 additions & 0 deletions flyteadmin/auth/authzserver/claims_verifier_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package authzserver

import (
"testing"

"github.com/stretchr/testify/assert"
"k8s.io/apimachinery/pkg/util/sets"
)

func Test_verifyClaims(t *testing.T) {
t.Run("Empty claims, fail", func(t *testing.T) {
_, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{})
assert.Error(t, err)
})

t.Run("All filled", func(t *testing.T) {
identityCtx, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{
"aud": []string{"https://myserver"},
"user_info": map[string]interface{}{
"preferred_name": "John Doe",
},
"sub": "123",
"client_id": "my-client",
"scp": []interface{}{"all", "offline"},
})

assert.NoError(t, err)
assert.Equal(t, sets.NewString("all", "offline"), identityCtx.Scopes())
assert.Equal(t, "my-client", identityCtx.AppID())
assert.Equal(t, "123", identityCtx.UserID())
assert.Equal(t, "https://myserver", identityCtx.Audience())
})

t.Run("Multiple audience", func(t *testing.T) {
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"),
map[string]interface{}{
"aud": []string{"https://myserver"},
"user_info": map[string]interface{}{
"preferred_name": "John Doe",
},
"sub": "123",
"client_id": "my-client",
"scp": []interface{}{"all", "offline"},
})

assert.NoError(t, err)
assert.Equal(t, "https://myserver", identityCtx.Audience())
})

t.Run("No matching audience", func(t *testing.T) {
_, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"),
map[string]interface{}{
"aud": []string{"https://myserver3"},
})

assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid audience")
})

t.Run("Use first matching audience", func(t *testing.T) {
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2", "https://myserver3"),
map[string]interface{}{
"aud": []string{"https://myserver", "https://myserver2"},
"user_info": map[string]interface{}{
"preferred_name": "John Doe",
},
"sub": "123",
"client_id": "my-client",
"scp": []interface{}{"all", "offline"},
})

assert.NoError(t, err)
assert.Equal(t, "https://myserver", identityCtx.Audience())
})

t.Run("String scope", func(t *testing.T) {
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"),
map[string]interface{}{
"aud": []string{"https://myserver"},
"scp": "all",
})

assert.NoError(t, err)
assert.Equal(t, "https://myserver", identityCtx.Audience())
assert.Equal(t, sets.NewString("all"), identityCtx.Scopes())
})
t.Run("unknown scope", func(t *testing.T) {
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"),
map[string]interface{}{
"aud": []string{"https://myserver"},
"scp": 1,
})

assert.Error(t, err)
assert.Nil(t, identityCtx)
assert.Equal(t, "failed getting scope claims due to unknown type int with value 1", err.Error())
})
}
50 changes: 0 additions & 50 deletions flyteadmin/auth/authzserver/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"time"
Expand All @@ -17,8 +16,6 @@ import (

"github.com/lestrrat-go/jwx/jwk"

"github.com/ory/x/jwtx"

"github.com/flyteorg/flyteadmin/auth/interfaces"

"github.com/flyteorg/flyteadmin/auth"
Expand Down Expand Up @@ -131,53 +128,6 @@ func (p Provider) ValidateAccessToken(ctx context.Context, expectedAudience, tok
return verifyClaims(sets.NewString(expectedAudience), claimsRaw)
}

func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}) (interfaces.IdentityContext, error) {
claims := jwtx.ParseMapStringInterfaceClaims(claimsRaw)

foundAudIndex := -1
for audIndex, aud := range claims.Audience {
if expectedAudience.Has(aud) {
foundAudIndex = audIndex
break
}
}

if foundAudIndex < 0 {
return nil, fmt.Errorf("invalid audience [%v]", claims)
}

userInfo := &service.UserInfoResponse{}
if userInfoClaim, found := claimsRaw[UserIDClaim]; found && userInfoClaim != nil {
userInfoRaw := userInfoClaim.(map[string]interface{})
raw, err := json.Marshal(userInfoRaw)
if err != nil {
return nil, err
}

if err = json.Unmarshal(raw, userInfo); err != nil {
return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err)
}
}

clientID := ""
if clientIDClaim, found := claimsRaw[ClientIDClaim]; found {
clientID = clientIDClaim.(string)
}

scopes := sets.NewString()
if scopesClaim, found := claimsRaw[ScopeClaim]; found {
scopes = sets.NewString(interfaceSliceToStringSlice(scopesClaim.([]interface{}))...)
}

// If this is a user-only access token with no scopes defined then add `all` scope by default because it's equivalent
// to having a user's login cookie or an ID Token as means of accessing the service.
if len(clientID) == 0 && scopes.Len() == 0 {
scopes.Insert(auth.ScopeAll)
}

return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil
}

// NewProvider creates a new OAuth2 Provider that is able to do OAuth 2-legged and 3-legged flows. It'll lookup
// config.SecretNameClaimSymmetricKey and config.SecretNameTokenSigningRSAKey secrets from the secret manager to use to
// sign and generate hashes for tokens. The RSA Private key is expected to be in PEM format with the public key embedded.
Expand Down
73 changes: 2 additions & 71 deletions flyteadmin/auth/authzserver/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@ import (
"testing"
"time"

"k8s.io/apimachinery/pkg/util/sets"

jwtgo "github.com/golang-jwt/jwt/v4"

"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service"

"github.com/stretchr/testify/assert"

"github.com/flyteorg/flyteadmin/auth"
"github.com/flyteorg/flyteadmin/auth/config"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/mocks"
"github.com/stretchr/testify/assert"
)

func newMockProvider(t testing.TB) (Provider, auth.SecretsSet) {
Expand Down Expand Up @@ -213,71 +212,3 @@ func TestProvider_ValidateAccessToken(t *testing.T) {
assert.False(t, identity.IsEmpty())
})
}

func Test_verifyClaims(t *testing.T) {
t.Run("Empty claims, fail", func(t *testing.T) {
_, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{})
assert.Error(t, err)
})

t.Run("All filled", func(t *testing.T) {
identityCtx, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{
"aud": []string{"https://myserver"},
"user_info": map[string]interface{}{
"preferred_name": "John Doe",
},
"sub": "123",
"client_id": "my-client",
"scp": []interface{}{"all", "offline"},
})

assert.NoError(t, err)
assert.Equal(t, sets.NewString("all", "offline"), identityCtx.Scopes())
assert.Equal(t, "my-client", identityCtx.AppID())
assert.Equal(t, "123", identityCtx.UserID())
assert.Equal(t, "https://myserver", identityCtx.Audience())
})

t.Run("Multiple audience", func(t *testing.T) {
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"),
map[string]interface{}{
"aud": []string{"https://myserver"},
"user_info": map[string]interface{}{
"preferred_name": "John Doe",
},
"sub": "123",
"client_id": "my-client",
"scp": []interface{}{"all", "offline"},
})

assert.NoError(t, err)
assert.Equal(t, "https://myserver", identityCtx.Audience())
})

t.Run("No matching audience", func(t *testing.T) {
_, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"),
map[string]interface{}{
"aud": []string{"https://myserver3"},
})

assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid audience")
})

t.Run("Use first matching audience", func(t *testing.T) {
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2", "https://myserver3"),
map[string]interface{}{
"aud": []string{"https://myserver", "https://myserver2"},
"user_info": map[string]interface{}{
"preferred_name": "John Doe",
},
"sub": "123",
"client_id": "my-client",
"scp": []interface{}{"all", "offline"},
})

assert.NoError(t, err)
assert.Equal(t, "https://myserver", identityCtx.Audience())
})

}

0 comments on commit 4534fa0

Please sign in to comment.