diff --git a/build.go b/build.go index 3aad244..6858d88 100644 --- a/build.go +++ b/build.go @@ -10,6 +10,19 @@ var ( b64EncodedLen = base64.RawURLEncoding.EncodedLen ) +// BuilderOption is used to modify builder properties. +type BuilderOption func(*Builder) + +// WithKeyID sets `kid` header for token. +func WithKeyID(kid string) BuilderOption { + return func(b *Builder) { b.header.KeyID = kid } +} + +// WithContentType sets `cty` header for token. +func WithContentType(cty string) BuilderOption { + return func(b *Builder) { b.header.ContentType = cty } +} + // Builder is used to create a new token. type Builder struct { signer Signer @@ -28,7 +41,7 @@ func Build(signer Signer, claims interface{}) (*Token, error) { } // NewBuilder returns new instance of Builder. -func NewBuilder(signer Signer) *Builder { +func NewBuilder(signer Signer, opts ...BuilderOption) *Builder { b := &Builder{ signer: signer, header: Header{ @@ -36,6 +49,11 @@ func NewBuilder(signer Signer) *Builder { Type: "JWT", }, } + + for _, opt := range opts { + opt(b) + } + b.headerRaw = encodeHeader(b.header) return b } @@ -105,7 +123,7 @@ func encodeClaims(claims interface{}) ([]byte, error) { } func encodeHeader(header Header) []byte { - if header.Type == "JWT" && header.ContentType == "" { + if header.Type == "JWT" && header.ContentType == "" && header.KeyID == "" { if h := getPredefinedHeader(header); h != "" { return []byte(h) } diff --git a/build_test.go b/build_test.go index 11b409d..bac937e 100644 --- a/build_test.go +++ b/build_test.go @@ -37,10 +37,10 @@ func TestBuild(t *testing.T) { } func TestBuildHeader(t *testing.T) { - f := func(signer Signer, header Header, want string) { + f := func(signer Signer, want string, opts ...BuilderOption) { t.Helper() - token, err := NewBuilder(signer).Build(&StandardClaims{}) + token, err := NewBuilder(signer, opts...).Build(&StandardClaims{}) if err != nil { t.Error(err) } @@ -55,35 +55,51 @@ func TestBuildHeader(t *testing.T) { key := []byte("key") f( mustSigner(NewSignerHS(HS256, key)), - Header{Algorithm: HS256, Type: "JWT"}, `{"alg":"HS256","typ":"JWT"}`, ) f( mustSigner(NewSignerHS(HS384, key)), - Header{Algorithm: HS384, Type: "JWT"}, `{"alg":"HS384","typ":"JWT"}`, ) f( mustSigner(NewSignerHS(HS512, key)), - Header{Algorithm: HS512, Type: "JWT"}, `{"alg":"HS512","typ":"JWT"}`, ) f( mustSigner(NewSignerRS(RS256, rsaPrivateKey1)), - Header{Algorithm: RS256, Type: "JWT"}, `{"alg":"RS256","typ":"JWT"}`, ) f( mustSigner(NewSignerRS(RS384, rsaPrivateKey1)), - Header{Algorithm: RS384, Type: "JWT"}, `{"alg":"RS384","typ":"JWT"}`, ) f( mustSigner(NewSignerRS(RS512, rsaPrivateKey1)), - Header{Algorithm: RS512, Type: "JWT"}, `{"alg":"RS512","typ":"JWT"}`, ) + + f( + mustSigner(NewSignerHS(HS256, key)), + `{"alg":"HS256","typ":"JWT","kid":"test"}`, + WithKeyID("test"), + ) + f( + mustSigner(NewSignerHS(HS512, key)), + `{"alg":"HS512","typ":"JWT","cty":"jwk+json"}`, + WithContentType("jwk+json"), + ) + + f( + mustSigner(NewSignerRS(RS256, rsaPrivateKey1)), + `{"alg":"RS256","typ":"JWT","kid":"test"}`, + WithKeyID("test"), + ) + f( + mustSigner(NewSignerRS(RS512, rsaPrivateKey1)), + `{"alg":"RS512","typ":"JWT","cty":"jwk+json"}`, + WithContentType("jwk+json"), + ) } func TestBuildMalformed(t *testing.T) {