forked from flyteorg/flyte
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added type assertion check on scp claims which are returned as string…
… in Azure AD (flyteorg#471) * Added type assertion check on scp claims which are returned as string in Azure AD Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com> * Moved claim_verfier and added a unit test for string scope Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com> * added another check for string type for scp claims and fail otherwise Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com> Signed-off-by: Prafulla Mahindrakar <prafulla.mahindrakar@gmail.com>
- Loading branch information
1 parent
098939d
commit 4534fa0
Showing
4 changed files
with
168 additions
and
121 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package authzserver | ||
|
||
import ( | ||
"encoding/json" | ||
"fmt" | ||
|
||
"github.com/ory/x/jwtx" | ||
"k8s.io/apimachinery/pkg/util/sets" | ||
|
||
"github.com/flyteorg/flyteadmin/auth" | ||
"github.com/flyteorg/flyteadmin/auth/interfaces" | ||
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/service" | ||
) | ||
|
||
func verifyClaims(expectedAudience sets.String, claimsRaw map[string]interface{}) (interfaces.IdentityContext, error) { | ||
claims := jwtx.ParseMapStringInterfaceClaims(claimsRaw) | ||
|
||
foundAudIndex := -1 | ||
for audIndex, aud := range claims.Audience { | ||
if expectedAudience.Has(aud) { | ||
foundAudIndex = audIndex | ||
break | ||
} | ||
} | ||
|
||
if foundAudIndex < 0 { | ||
return nil, fmt.Errorf("invalid audience [%v]", claims) | ||
} | ||
|
||
userInfo := &service.UserInfoResponse{} | ||
if userInfoClaim, found := claimsRaw[UserIDClaim]; found && userInfoClaim != nil { | ||
userInfoRaw := userInfoClaim.(map[string]interface{}) | ||
raw, err := json.Marshal(userInfoRaw) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if err = json.Unmarshal(raw, userInfo); err != nil { | ||
return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err) | ||
} | ||
} | ||
|
||
clientID := "" | ||
if clientIDClaim, found := claimsRaw[ClientIDClaim]; found { | ||
clientID = clientIDClaim.(string) | ||
} | ||
|
||
scopes := sets.NewString() | ||
if scopesClaim, found := claimsRaw[ScopeClaim]; found { | ||
|
||
switch sct := scopesClaim.(type) { | ||
case []interface{}: | ||
scopes = sets.NewString(interfaceSliceToStringSlice(sct)...) | ||
case string: | ||
sets.NewString(fmt.Sprintf("%v", scopesClaim)) | ||
default: | ||
return nil, fmt.Errorf("failed getting scope claims due to unknown type %T with value %v", sct, sct) | ||
} | ||
} | ||
|
||
// If this is a user-only access token with no scopes defined then add `all` scope by default because it's equivalent | ||
// to having a user's login cookie or an ID Token as means of accessing the service. | ||
if len(clientID) == 0 && scopes.Len() == 0 { | ||
scopes.Insert(auth.ScopeAll) | ||
} | ||
|
||
return auth.NewIdentityContext(claims.Audience[foundAudIndex], claims.Subject, clientID, claims.IssuedAt, scopes, userInfo, claimsRaw), nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
package authzserver | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"k8s.io/apimachinery/pkg/util/sets" | ||
) | ||
|
||
func Test_verifyClaims(t *testing.T) { | ||
t.Run("Empty claims, fail", func(t *testing.T) { | ||
_, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{}) | ||
assert.Error(t, err) | ||
}) | ||
|
||
t.Run("All filled", func(t *testing.T) { | ||
identityCtx, err := verifyClaims(sets.NewString("https://myserver"), map[string]interface{}{ | ||
"aud": []string{"https://myserver"}, | ||
"user_info": map[string]interface{}{ | ||
"preferred_name": "John Doe", | ||
}, | ||
"sub": "123", | ||
"client_id": "my-client", | ||
"scp": []interface{}{"all", "offline"}, | ||
}) | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, sets.NewString("all", "offline"), identityCtx.Scopes()) | ||
assert.Equal(t, "my-client", identityCtx.AppID()) | ||
assert.Equal(t, "123", identityCtx.UserID()) | ||
assert.Equal(t, "https://myserver", identityCtx.Audience()) | ||
}) | ||
|
||
t.Run("Multiple audience", func(t *testing.T) { | ||
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), | ||
map[string]interface{}{ | ||
"aud": []string{"https://myserver"}, | ||
"user_info": map[string]interface{}{ | ||
"preferred_name": "John Doe", | ||
}, | ||
"sub": "123", | ||
"client_id": "my-client", | ||
"scp": []interface{}{"all", "offline"}, | ||
}) | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, "https://myserver", identityCtx.Audience()) | ||
}) | ||
|
||
t.Run("No matching audience", func(t *testing.T) { | ||
_, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), | ||
map[string]interface{}{ | ||
"aud": []string{"https://myserver3"}, | ||
}) | ||
|
||
assert.Error(t, err) | ||
assert.Contains(t, err.Error(), "invalid audience") | ||
}) | ||
|
||
t.Run("Use first matching audience", func(t *testing.T) { | ||
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2", "https://myserver3"), | ||
map[string]interface{}{ | ||
"aud": []string{"https://myserver", "https://myserver2"}, | ||
"user_info": map[string]interface{}{ | ||
"preferred_name": "John Doe", | ||
}, | ||
"sub": "123", | ||
"client_id": "my-client", | ||
"scp": []interface{}{"all", "offline"}, | ||
}) | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, "https://myserver", identityCtx.Audience()) | ||
}) | ||
|
||
t.Run("String scope", func(t *testing.T) { | ||
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), | ||
map[string]interface{}{ | ||
"aud": []string{"https://myserver"}, | ||
"scp": "all", | ||
}) | ||
|
||
assert.NoError(t, err) | ||
assert.Equal(t, "https://myserver", identityCtx.Audience()) | ||
assert.Equal(t, sets.NewString("all"), identityCtx.Scopes()) | ||
}) | ||
t.Run("unknown scope", func(t *testing.T) { | ||
identityCtx, err := verifyClaims(sets.NewString("https://myserver", "https://myserver2"), | ||
map[string]interface{}{ | ||
"aud": []string{"https://myserver"}, | ||
"scp": 1, | ||
}) | ||
|
||
assert.Error(t, err) | ||
assert.Nil(t, identityCtx) | ||
assert.Equal(t, "failed getting scope claims due to unknown type int with value 1", err.Error()) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters