Skip to content

Commit

Permalink
fix: export custom claims translation (dexidp#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Palesandro committed Nov 9, 2023
1 parent 579bf92 commit d14abf1
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 23 deletions.
7 changes: 0 additions & 7 deletions cmd/dex/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,6 @@ type Config struct {
// database.
StaticPasswords []password `json:"staticPasswords"`

// If enabled, the server does not initialize connectors when it starts, but
// initializes connectors when it handles requests. This allows the server to
// start even if some connectors cannot be initialized, for example, because
// because it is misconfigured, or because initialization requires a network
// service that is unavailable.
LazyInitConnectors bool `json:"lazyInitConnectors"`

// TokenClaimsHooks is a list of hooks that can be used to mutate the claims of a token
TokenClaimsHooks config.TokenClaimsHooks `json:"tokenClaimsHooks"`

Expand Down
1 change: 0 additions & 1 deletion cmd/dex/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ func runServe(options serveOptions) error {
Now: now,
PrometheusRegistry: prometheusRegistry,
HealthChecker: healthChecker,
LazyInitConnectors: c.LazyInitConnectors,
TokenClaimsHooks: c.TokenClaimsHooks,
ConnectorFilterHooks: c.ConnectorFilterHooks,
}
Expand Down
36 changes: 36 additions & 0 deletions pkg/webhook/claims/idtoken.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package claims

import (
"encoding/json"
"fmt"

"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)

func getProtectedClaims() []string {
return []string{"iss", "sub", "aud", "exp", "iat", "azp", "nonce", "at_hash", "c_hash"}
}

func generateIDClaims(baseIDClaims map[string]interface{}, customClaims map[string]interface{}) map[string]interface{} {
finalClaims := map[string]interface{}{}
maps.Copy(finalClaims, baseIDClaims)
// Adding the immutable claims to the token
protectedClaims := getProtectedClaims()
for claim := range customClaims {
if !slices.Contains(protectedClaims, claim) {
finalClaims[claim] = customClaims[claim]
}
}
return finalClaims
}

func GenerateTokenFromTemplate(baseIDClaims map[string]interface{}, customClaims map[string]interface{}) ([]byte,
error,
) {
payload, err := json.Marshal(generateIDClaims(baseIDClaims, customClaims))
if err != nil {
return []byte{}, fmt.Errorf("could not serialize claims: %v", err)
}
return payload, nil
}
54 changes: 54 additions & 0 deletions pkg/webhook/claims/idtoken_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package claims

import (
"testing"

"github.com/stretchr/testify/assert"
)

func Test_GroupTranslation(t *testing.T) {
baseIDToken := map[string]interface{}{
"groups": []string{"group1", "group2"},
}
receivedIDToken := map[string]interface{}{
"groups": []string{"test2:group1", "test2:group2"},
}
res, err := GenerateTokenFromTemplate(baseIDToken, receivedIDToken)
assert.NoError(t, err)
assert.Equal(t, res, []byte(`{"groups":["test2:group1","test2:group2"]}`))
}

func Test_EmptyInput(t *testing.T) {
res, err := GenerateTokenFromTemplate(map[string]interface{}{}, map[string]interface{}{})
assert.NoError(t, err)
assert.Equal(t, res, []byte(`{}`))
}

func Test_ProtectedClaims(t *testing.T) {
baseIDToken := map[string]interface{}{
"iss": "iss",
"sub": "sub",
"aud": "aud",
"groups": []string{"group1", "group2"},
}
receivedIDToken := map[string]interface{}{
"iss": "iss2",
"groups": []string{"test2:group1", "test2:group2"},
}
res, err := GenerateTokenFromTemplate(baseIDToken, receivedIDToken)
assert.NoError(t, err)
assert.Equal(t, res, []byte(`{"aud":"aud","groups":["test2:group1","test2:group2"],"iss":"iss","sub":"sub"}`))
}

func Test_NotStandardClaims(t *testing.T) {
baseIDToken := map[string]interface{}{
"groups": []string{"group1", "group2"},
}
receivedIDToken := map[string]interface{}{
"groups": []string{"test2:group1", "test2:group2"},
"custom": "custom",
}
res, err := GenerateTokenFromTemplate(baseIDToken, receivedIDToken)
assert.NoError(t, err)
assert.Equal(t, res, []byte(`{"custom":"custom","groups":["test2:group1","test2:group2"]}`))
}
20 changes: 5 additions & 15 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ import (
"strings"
"time"

"golang.org/x/exp/maps"
jose "gopkg.in/square/go-jose.v2"

"github.com/dexidp/dex/connector"
claimsWebhook "github.com/dexidp/dex/pkg/webhook/claims"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
)
Expand Down Expand Up @@ -423,7 +423,7 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str
}
}

payload, err := generateTokenFromTemplate(tok, res)
payload, err := claimsWebhook.GenerateTokenFromTemplate(convertToMap(tok), res)
if err != nil {
return "", expiry, fmt.Errorf("could not serialize claims: %v", err)
}
Expand Down Expand Up @@ -721,10 +721,8 @@ func forgeMap(tok idTokenClaims) map[string]interface{} {
}
}

func generateTokenFromTemplate(baseIDClaims idTokenClaims, customClaims map[string]interface{}) ([]byte, error) {
finalClaims := map[string]interface{}{}
// Adding the immutable claims to the token
rawClaims := map[string]interface{}{
func convertToMap(baseIDClaims idTokenClaims) map[string]interface{} {
return map[string]interface{}{
"iss": baseIDClaims.Issuer,
"sub": baseIDClaims.Subject,
"aud": baseIDClaims.Audience,
Expand All @@ -735,18 +733,10 @@ func generateTokenFromTemplate(baseIDClaims idTokenClaims, customClaims map[stri
"at_hash": baseIDClaims.AccessTokenHash,
"c_hash": baseIDClaims.CodeHash,
"email": baseIDClaims.Email,
"email_verified": true,
"email_verified": baseIDClaims.EmailVerified,
"groups": baseIDClaims.Groups,
"name": baseIDClaims.Name,
"preferred_username": baseIDClaims.PreferredUsername,
"federated_claims": baseIDClaims.FederatedIDClaims,
}
maps.Copy(finalClaims, customClaims)
maps.Copy(finalClaims, rawClaims)

payload, err := json.Marshal(finalClaims)
if err != nil {
return []byte{}, fmt.Errorf("could not serialize claims: %v", err)
}
return payload, nil
}

0 comments on commit d14abf1

Please sign in to comment.