Skip to content

Commit

Permalink
Add builder options to help setting custom token headers (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
danikarik authored Jan 14, 2021
1 parent 831cdfc commit 14f248f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 10 deletions.
22 changes: 20 additions & 2 deletions build.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ var (
b64EncodedLen = base64.RawURLEncoding.EncodedLen
)

// BuilderOption is used to modify builder properties.
type BuilderOption func(*Builder)

// WithKeyID sets `kid` header for token.
func WithKeyID(kid string) BuilderOption {
return func(b *Builder) { b.header.KeyID = kid }
}

// WithContentType sets `cty` header for token.
func WithContentType(cty string) BuilderOption {
return func(b *Builder) { b.header.ContentType = cty }
}

// Builder is used to create a new token.
type Builder struct {
signer Signer
Expand All @@ -28,14 +41,19 @@ func Build(signer Signer, claims interface{}) (*Token, error) {
}

// NewBuilder returns new instance of Builder.
func NewBuilder(signer Signer) *Builder {
func NewBuilder(signer Signer, opts ...BuilderOption) *Builder {
b := &Builder{
signer: signer,
header: Header{
Algorithm: signer.Algorithm(),
Type: "JWT",
},
}

for _, opt := range opts {
opt(b)
}

b.headerRaw = encodeHeader(b.header)
return b
}
Expand Down Expand Up @@ -105,7 +123,7 @@ func encodeClaims(claims interface{}) ([]byte, error) {
}

func encodeHeader(header Header) []byte {
if header.Type == "JWT" && header.ContentType == "" {
if header.Type == "JWT" && header.ContentType == "" && header.KeyID == "" {
if h := getPredefinedHeader(header); h != "" {
return []byte(h)
}
Expand Down
32 changes: 24 additions & 8 deletions build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ func TestBuild(t *testing.T) {
}

func TestBuildHeader(t *testing.T) {
f := func(signer Signer, header Header, want string) {
f := func(signer Signer, want string, opts ...BuilderOption) {
t.Helper()

token, err := NewBuilder(signer).Build(&StandardClaims{})
token, err := NewBuilder(signer, opts...).Build(&StandardClaims{})
if err != nil {
t.Error(err)
}
Expand All @@ -55,35 +55,51 @@ func TestBuildHeader(t *testing.T) {
key := []byte("key")
f(
mustSigner(NewSignerHS(HS256, key)),
Header{Algorithm: HS256, Type: "JWT"},
`{"alg":"HS256","typ":"JWT"}`,
)
f(
mustSigner(NewSignerHS(HS384, key)),
Header{Algorithm: HS384, Type: "JWT"},
`{"alg":"HS384","typ":"JWT"}`,
)
f(
mustSigner(NewSignerHS(HS512, key)),
Header{Algorithm: HS512, Type: "JWT"},
`{"alg":"HS512","typ":"JWT"}`,
)

f(
mustSigner(NewSignerRS(RS256, rsaPrivateKey1)),
Header{Algorithm: RS256, Type: "JWT"},
`{"alg":"RS256","typ":"JWT"}`,
)
f(
mustSigner(NewSignerRS(RS384, rsaPrivateKey1)),
Header{Algorithm: RS384, Type: "JWT"},
`{"alg":"RS384","typ":"JWT"}`,
)
f(
mustSigner(NewSignerRS(RS512, rsaPrivateKey1)),
Header{Algorithm: RS512, Type: "JWT"},
`{"alg":"RS512","typ":"JWT"}`,
)

f(
mustSigner(NewSignerHS(HS256, key)),
`{"alg":"HS256","typ":"JWT","kid":"test"}`,
WithKeyID("test"),
)
f(
mustSigner(NewSignerHS(HS512, key)),
`{"alg":"HS512","typ":"JWT","cty":"jwk+json"}`,
WithContentType("jwk+json"),
)

f(
mustSigner(NewSignerRS(RS256, rsaPrivateKey1)),
`{"alg":"RS256","typ":"JWT","kid":"test"}`,
WithKeyID("test"),
)
f(
mustSigner(NewSignerRS(RS512, rsaPrivateKey1)),
`{"alg":"RS512","typ":"JWT","cty":"jwk+json"}`,
WithContentType("jwk+json"),
)
}

func TestBuildMalformed(t *testing.T) {
Expand Down

0 comments on commit 14f248f

Please sign in to comment.