Skip to content

Commit

Permalink
Merge pull request from GHSA-qwrj-9hmp-gpxh
Browse files Browse the repository at this point in the history
* Fix claims verification for access tokens in external IdP setup

Signed-off-by: Haytham Abuelfutuh <haytham@afutuh.com>

* Add another test case for no signature

Signed-off-by: Haytham Abuelfutuh <haytham@afutuh.com>
  • Loading branch information
EngHabu authored Jul 13, 2022
1 parent 639386d commit 0021ea2
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 23 deletions.
3 changes: 0 additions & 3 deletions flyteadmin/auth/authzserver/authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ func TestAuthEndpoint(t *testing.T) {
})
}

// #nosec
const sampleIDToken = `eyJraWQiOiJaNmRtWl9UWGhkdXctalVCWjZ1RUV6dm5oLWpoTk8wWWhlbUI3cWFfTE9jIiwiYWxnIjoiUlMyNTYifQ.eyJzdWIiOiIwMHVra2k0OHBzSDhMaWtZVjVkNiIsIm5hbWUiOiJIYXl0aGFtIEFidWVsZnV0dWgiLCJ2ZXIiOjEsImlzcyI6Imh0dHBzOi8vZGV2LTE0MTg2NDIyLm9rdGEuY29tL29hdXRoMi9hdXNrbmdubjd1QlZpUXE2YjVkNiIsImF1ZCI6IjBvYWtraGV0ZU5qQ01FUnN0NWQ2IiwiaWF0IjoxNjE4NDUzNjc5LCJleHAiOjE2MTg0NTcyNzksImp0aSI6IklELmE0YXpLdUphVFM2YzNTeHdpWWdTMHhPbTM2bVFnVlVVN0I4V2dEdk80dFkiLCJhbXIiOlsicHdkIl0sImlkcCI6IjBvYWtrbTFjaTFVZVBwTlUwNWQ2IiwicHJlZmVycmVkX3VzZXJuYW1lIjoiaGF5dGhhbUB1bmlvbi5haSIsImF1dGhfdGltZSI6MTYxODQ0NjI0NywiYXRfaGFzaCI6Ikg5Q0FweWlrQkpGYXJ4d1FUbnB6ZFEifQ.SJ3BTD_MFcrYvTnql181Ddeb_mOm81z_S7ZKQ6P8mMgWqn94LZ2nG8k8-_odaaNAAT-M1nAFKWqZAQGvliwS1_TsD8_j0cen5zYnGcz2Uu5fFlvoHwuPgy5JYYNOXkXYgPnIb3kNkgXKbkdjS9hdbMfvnPd9rr8v0yzqf0AQBnUe-cPrzY-ZJjvh80IWDZgSjoP244tTYppPkx8UtedJLJZ4tzB7aXlEyoRV-DpmOLfJkAmblRm4OsO1qjwmx3HSIy_T-0PANn-g4AS07rpoMYHRcqncdgcAsVfGxjyWiOg3kbymLqpGlkIZgzmev-TmpoDp0QkUVPOntuiB57GZ6g`

//func TestAuthCallbackEndpoint(t *testing.T) {
// originalURL := "http://localhost:8088/oauth2/authorize?client_id=my-client&redirect_uri=http%3A%2F%2Flocalhost%3A3846%2Fcallback&response_type=code&scope=photos+openid+offline&state=some-random-state-foobar&nonce=some-random-nonce&code_challenge=p0v_UR0KrXl4--BpxM2BQa7qIW5k3k4WauBhjmkVQw8&code_challenge_method=S256"
// req := httptest.NewRequest(http.MethodGet, originalURL, nil)
Expand Down
15 changes: 10 additions & 5 deletions flyteadmin/auth/authzserver/resource_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
jwtgo "github.com/golang-jwt/jwt/v4"
"io/ioutil"
"mime"
"net/http"
Expand All @@ -28,17 +29,21 @@ type ResourceServer struct {
}

func (r ResourceServer) ValidateAccessToken(ctx context.Context, expectedAudience, tokenStr string) (interfaces.IdentityContext, error) {
raw, err := r.signatureVerifier.VerifySignature(ctx, tokenStr)
_, err := r.signatureVerifier.VerifySignature(ctx, tokenStr)
if err != nil {
return nil, err
}

claimsRaw := map[string]interface{}{}
if err = json.Unmarshal(raw, &claimsRaw); err != nil {
return nil, fmt.Errorf("failed to unmarshal user info claim into UserInfo type. Error: %w", err)
t, _, err := jwtgo.NewParser().ParseUnverified(tokenStr, jwtgo.MapClaims{})
if err != nil {
return nil, fmt.Errorf("failed to parse token: %v", err)
}

if err = t.Claims.Valid(); err != nil {
return nil, fmt.Errorf("failed to validate token: %v", err)
}

return verifyClaims(sets.NewString(append(r.allowedAudience, expectedAudience)...), claimsRaw)
return verifyClaims(sets.NewString(append(r.allowedAudience, expectedAudience)...), t.Claims.(jwtgo.MapClaims))
}

func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
Expand Down
134 changes: 119 additions & 15 deletions flyteadmin/auth/authzserver/resource_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package authzserver

import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"io"
"net/http"
Expand All @@ -10,6 +12,9 @@ import (
"reflect"
"strings"
"testing"
"time"

"github.com/golang-jwt/jwt/v4"

"github.com/stretchr/testify/assert"

Expand All @@ -21,20 +26,20 @@ import (
stdlibConfig "github.com/flyteorg/flytestdlib/config"
)

func newMockResourceServer(t testing.TB) ResourceServer {
func newMockResourceServer(t testing.TB, publicKey rsa.PublicKey) (resourceServer ResourceServer, closer func()) {
ctx := context.Background()
dummy := ""
serverURL := &dummy
hf := func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/.well-known/oauth-authorization-server" {
w.Header().Set("Content-Type", "application/json")
_, err := io.WriteString(w, strings.ReplaceAll(`{
"issuer": "https://dev-14186422.okta.com",
"issuer": "https://whatever.okta.com",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"jwks_uri": "URL/keys",
"jwks_uri": "{URL}/keys",
"id_token_signing_alg_values_supported": ["RS256"]
}`, "URL", *serverURL))
}`, "{URL}", *serverURL))

if !assert.NoError(t, err) {
t.FailNow()
Expand All @@ -43,6 +48,14 @@ func newMockResourceServer(t testing.TB) ResourceServer {
return
} else if r.URL.Path == "/keys" {
keys := jwk.NewSet()
key := jwk.NewRSAPublicKey()
err := key.FromRaw(&publicKey)
if err != nil {
http.Error(w, err.Error(), 400)
return
}

keys.Add(key)
raw, err := json.Marshal(keys)
if err != nil {
http.Error(w, err.Error(), 400)
Expand All @@ -55,36 +68,127 @@ func newMockResourceServer(t testing.TB) ResourceServer {
if !assert.NoError(t, err) {
t.FailNow()
}

return
}

http.NotFound(w, r)
}

s := httptest.NewServer(http.HandlerFunc(hf))
defer s.Close()

*serverURL = s.URL

http.DefaultClient = s.Client()

r, err := NewOAuth2ResourceServer(ctx, authConfig.ExternalAuthorizationServer{
BaseURL: stdlibConfig.URL{URL: *config.MustParseURL(s.URL)},
BaseURL: stdlibConfig.URL{URL: *config.MustParseURL(s.URL)},
AllowedAudience: []string{"https://localhost"},
}, stdlibConfig.URL{})
if !assert.NoError(t, err) {
t.FailNow()
}

return r
}

func TestNewOAuth2ResourceServer(t *testing.T) {
newMockResourceServer(t)
return r, func() {
s.Close()
}
}

func TestResourceServer_ValidateAccessToken(t *testing.T) {
r := newMockResourceServer(t)
_, err := r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
assert.Error(t, err)
sampleRSAKey, err := rsa.GenerateKey(rand.Reader, 2048)
if !assert.NoError(t, err) {
t.FailNow()
}

r, closer := newMockResourceServer(t, sampleRSAKey.PublicKey)
defer closer()

t.Run("No signature", func(t *testing.T) {
sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.RegisteredClaims{
Audience: r.allowedAudience,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: "localhost",
Subject: "someone",
}).SignedString(sampleRSAKey)
if !assert.NoError(t, err) {
t.FailNow()
}

parts := strings.Split(sampleIDToken, ".")
sampleIDToken = strings.Join(parts[:len(parts)-1], ".") + "."

_, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
if !assert.Error(t, err) {
t.FailNow()
}

assert.Contains(t, err.Error(), "failed to verify id token signature")
})

t.Run("Invalid signature", func(t *testing.T) {
sampleRSAKey, err := rsa.GenerateKey(rand.Reader, 2048)
if !assert.NoError(t, err) {
t.FailNow()
}

sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.RegisteredClaims{
Audience: r.allowedAudience,
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: "localhost",
Subject: "someone",
}).SignedString(sampleRSAKey)
if !assert.NoError(t, err) {
t.FailNow()
}

_, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
if !assert.Error(t, err) {
t.FailNow()
}

assert.Contains(t, err.Error(), "failed to verify id token signature")
})

t.Run("Invalid audience", func(t *testing.T) {
sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.RegisteredClaims{
Audience: []string{"https://hello world"},
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: "localhost",
Subject: "someone",
}).SignedString(sampleRSAKey)
if !assert.NoError(t, err) {
t.FailNow()
}

_, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
if !assert.Error(t, err) {
t.FailNow()
}

assert.Contains(t, err.Error(), "invalid audience")
})

t.Run("Expired token", func(t *testing.T) {
sampleIDToken, err := jwt.NewWithClaims(jwt.SigningMethodRS512, jwt.StandardClaims{
Audience: r.allowedAudience[0],
ExpiresAt: time.Now().Add(-time.Hour).Unix(),
IssuedAt: time.Now().Add(-2 * time.Hour).Unix(),
Issuer: "localhost",
Subject: "someone",
}).SignedString(sampleRSAKey)
if !assert.NoError(t, err) {
t.FailNow()
}

_, err = r.ValidateAccessToken(context.Background(), "myserver", sampleIDToken)
if !assert.Error(t, err) {
t.FailNow()
}

assert.Contains(t, err.Error(), "failed to validate token: Token is expired")
})
}

func Test_doRequest(t *testing.T) {
Expand Down

0 comments on commit 0021ea2

Please sign in to comment.