From 87416b12b5518a098cea2ff67ba2cf252c039fd0 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Wed, 6 May 2020 09:14:05 +0900 Subject: [PATCH 01/22] excercise non-openid keys in openid.Token --- jwt/internal/types/date.go | 4 +- jwt/openid/openid_test.go | 192 +++++++++++++++++++++++++++++-------- 2 files changed, 156 insertions(+), 40 deletions(-) diff --git a/jwt/internal/types/date.go b/jwt/internal/types/date.go index 4d35c2402..669699cb0 100644 --- a/jwt/internal/types/date.go +++ b/jwt/internal/types/date.go @@ -52,14 +52,14 @@ func (n *NumericDate) Accept(v interface{}) error { case string: i, err := strconv.ParseInt(x[:], 10, 64) if err != nil { - return errors.Errorf(`invalid epoch value`) + return errors.Errorf(`invalid epoch value %#v`, x) } t = time.Unix(i, 0) case json.Number: intval, err := x.Int64() if err != nil { - return errors.Wrap(err, `failed to convert json value to int64`) + return errors.Wrapf(err, `failed to convert json value %#v to int64`, x) } t = time.Unix(intval, 0) case time.Time: diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index 9dff4b3a4..a847458c4 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -2,15 +2,22 @@ package openid_test import ( "encoding/json" + "fmt" "testing" "time" + "github.com/lestrrat-go/jwx/jwt/internal/types" "github.com/lestrrat-go/jwx/jwt/openid" "github.com/stretchr/testify/assert" ) const aLongLongTimeAgo = 233431200 const aLongLongTimeAgoString = "233431200" +const ( + tokenTime = 233431200 +) + +var expectedTokenTime = time.Unix(tokenTime, 0).UTC() func assertStockAddressClaim(t *testing.T, x *openid.AddressClaim) bool { t.Helper() @@ -85,132 +92,211 @@ func TestOpenIDClaims(t *testing.T) { return assert.Equal(t, v, expected) } - var base = map[string]struct { - Value interface{} - Key string - Check func(openid.Token) bool + var base = []struct { + Value interface{} + Expected func(interface{}) interface{} + Key string + Check func(openid.Token) bool }{ - "name": { + { + Key: openid.AudienceKey, + Value: []string{"developers", "secops", "tac"}, + Check: func(token openid.Token) bool { + return assert.Equal(t, token.Audience(), []string{"developers", "secops", "tac"}) + }, + }, + { + Key: openid.ExpirationKey, + Value: tokenTime, + Expected: func(v interface{}) interface{} { + var n types.NumericDate + if err := n.Accept(v); err != nil { + panic(err) + } + return n.Get() + }, + Check: func(token openid.Token) bool { + return assert.Equal(t, token.Expiration(), expectedTokenTime) + }, + }, + { + Key: openid.IssuedAtKey, + Value: tokenTime, + Expected: func(v interface{}) interface{} { + var n types.NumericDate + if err := n.Accept(v); err != nil { + panic(err) + } + return n.Get() + }, + Check: func(token openid.Token) bool { + return assert.Equal(t, token.Expiration(), expectedTokenTime) + }, + }, + { + Key: openid.IssuerKey, + Value: "http://www.example.com", + Check: func(token openid.Token) bool { + return assert.Equal(t, token.Issuer(), "http://www.example.com") + }, + }, + { + Key: openid.JwtIDKey, + Value: "e9bc097a-ce51-4036-9562-d2ade882db0d", + Check: func(token openid.Token) bool { + return assert.Equal(t, token.JwtID(), "e9bc097a-ce51-4036-9562-d2ade882db0d") + }, + }, + { + Key: openid.NotBeforeKey, + Value: tokenTime, + Expected: func(v interface{}) interface{} { + var n types.NumericDate + if err := n.Accept(v); err != nil { + panic(err) + } + return n.Get() + }, + Check: func(token openid.Token) bool { + return assert.Equal(t, token.NotBefore(), expectedTokenTime) + }, + }, + { + Key: openid.SubjectKey, + Value: "unit test", + Check: func(token openid.Token) bool { + return assert.Equal(t, token.Subject(), "unit test") + }, + }, + { Value: "jwx", Key: openid.NameKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Name(), "jwx") }, }, - "given_name": { + { Value: "jay", Key: openid.GivenNameKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.GivenName(), "jay") }, }, - "middle_name": { + { Value: "weee", Key: openid.MiddleNameKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.MiddleName(), "weee") }, }, - "family_name": { + { Value: "xi", Key: openid.FamilyNameKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.FamilyName(), "xi") }, }, - "nickname": { + { Value: "jayweexi", Key: openid.NicknameKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Nickname(), "jayweexi") }, }, - "preferred_username": { + { Value: "jwx", Key: openid.PreferredUsernameKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.PreferredUsername(), "jwx") }, }, - "profile": { + { Value: "https://github.com/lestrrat-go/jwx", Key: openid.ProfileKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Profile(), "https://github.com/lestrrat-go/jwx") }, }, - "picture": { + { Value: "https://avatars1.githubusercontent.com/u/36653903?s=400&v=4", Key: openid.PictureKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Picture(), "https://avatars1.githubusercontent.com/u/36653903?s=400&v=4") }, }, - "website": { + { Value: "https://github.com/lestrrat-go/jwx", Key: openid.WebsiteKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Website(), "https://github.com/lestrrat-go/jwx") }, }, - "email": { + { Value: "lestrrat+github@gmail.com", Key: openid.EmailKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Email(), "lestrrat+github@gmail.com") }, }, - "email_verified": { + { Value: true, Key: openid.EmailVerifiedKey, Check: func(token openid.Token) bool { return assert.True(t, token.EmailVerified()) }, }, - "gender": { + { Value: "n/a", Key: openid.GenderKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Gender(), "n/a") }, }, - "birthdate": { + { Value: "2015-11-04", + Key: openid.BirthdateKey, + Expected: func(v interface{}) interface{} { + var b openid.BirthdateClaim + if err := b.Accept(v); err != nil { + panic(err) + } + return &b + }, Check: func(token openid.Token) bool { var b openid.BirthdateClaim b.Accept("2015-11-04") return assert.Equal(t, token.Birthdate(), &b) }, }, - "zoneinfo": { + { Value: "Asia/Tokyo", Key: openid.ZoneinfoKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Zoneinfo(), "Asia/Tokyo") }, }, - "locale": { + { Value: "ja_JP", Key: openid.LocaleKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.Locale(), "ja_JP") }, }, - "phone_number": { + { Value: "819012345678", Key: openid.PhoneNumberKey, Check: func(token openid.Token) bool { return assert.Equal(t, token.PhoneNumber(), "819012345678") }, }, - "phone_number_verified": { + { Value: true, Key: openid.PhoneNumberVerifiedKey, Check: func(token openid.Token) bool { return assert.True(t, token.PhoneNumberVerified()) }, }, - "address": { + { Value: map[string]interface{}{ "formatted": "〒105-0011 東京都港区芝公園4丁目2−8", "street_address": "芝公園4丁目2−8", @@ -219,12 +305,34 @@ func TestOpenIDClaims(t *testing.T) { "country": "日本", "postal_code": "105-0011", }, + Key: openid.AddressKey, + Expected: func(v interface{}) interface{} { + address := openid.NewAddress() + m, ok := v.(map[string]interface{}) + if !ok { + panic(fmt.Sprintf("expected map[string]interface{}, got %T", v)) + } + for name, val := range m { + if !assert.NoError(t, address.Set(name, val), `address.Set should succeed`) { + return nil + } + } + return address + }, Check: func(token openid.Token) bool { return assertStockAddressClaim(t, token.Address()) }, }, - "updated_at": { + { Value: aLongLongTimeAgoString, + Key: openid.UpdatedAtKey, + Expected: func(v interface{}) interface{} { + var n types.NumericDate + if err := n.Accept(v); err != nil { + panic(err) + } + return n.Get() + }, Check: func(token openid.Token) bool { return assert.Equal(t, time.Unix(aLongLongTimeAgo, 0).UTC(), token.UpdatedAt()) }, @@ -232,11 +340,15 @@ func TestOpenIDClaims(t *testing.T) { } var data = map[string]interface{}{} - for name, value := range base { - data[name] = value.Value + for _, value := range base { + data[value.Key] = value.Value } - var tokens []openid.Token + type openidTokTestCase struct { + Name string + Token openid.Token + } + var tokens []openidTokTestCase { // one with Set() token := openid.New() @@ -245,11 +357,11 @@ func TestOpenIDClaims(t *testing.T) { return } } - tokens = append(tokens, token) + tokens = append(tokens, openidTokTestCase{Name: `token constructed by calling Set()`, Token: token}) } { // one with json.Marshal / json.Unmarshal - src, err := json.Marshal(data) + src, err := json.MarshalIndent(data, "", " ") if !assert.NoError(t, err, `failed to marshal base map`) { return } @@ -260,22 +372,26 @@ func TestOpenIDClaims(t *testing.T) { if !assert.NoError(t, json.Unmarshal(src, &token), `json.Unmarshal should succeed`) { return } - tokens = append(tokens, token) + tokens = append(tokens, openidTokTestCase{Name: `token constructed by Marshal+Unmashal`, Token: token}) } for _, token := range tokens { token := token - for name, value := range base { - value := value - t.Run(name, func(t *testing.T) { - value.Check(token) - }) - if value.Key != "" { - t.Run(name+" via Get()", func(t *testing.T) { - getVerify(token, value.Key, value.Value) + t.Run(token.Name, func(t *testing.T) { + for _, value := range base { + value := value + t.Run(value.Key, func(t *testing.T) { + value.Check(token.Token) + }) + t.Run(value.Key+" via Get()", func(t *testing.T) { + expected := value.Value + if expf := value.Expected; expf != nil { + expected = expf(value.Value) + } + getVerify(token.Token, value.Key, expected) }) } - } + }) } } From a753742310467049c76c11f976349f80825f8a15 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Wed, 6 May 2020 09:18:01 +0900 Subject: [PATCH 02/22] appease golangci-lint --- jwt/openid/openid_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index a847458c4..f6a2e3b4b 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -106,7 +106,7 @@ func TestOpenIDClaims(t *testing.T) { }, }, { - Key: openid.ExpirationKey, + Key: openid.ExpirationKey, Value: tokenTime, Expected: func(v interface{}) interface{} { var n types.NumericDate @@ -120,7 +120,7 @@ func TestOpenIDClaims(t *testing.T) { }, }, { - Key: openid.IssuedAtKey, + Key: openid.IssuedAtKey, Value: tokenTime, Expected: func(v interface{}) interface{} { var n types.NumericDate @@ -134,21 +134,21 @@ func TestOpenIDClaims(t *testing.T) { }, }, { - Key: openid.IssuerKey, + Key: openid.IssuerKey, Value: "http://www.example.com", Check: func(token openid.Token) bool { return assert.Equal(t, token.Issuer(), "http://www.example.com") }, }, { - Key: openid.JwtIDKey, + Key: openid.JwtIDKey, Value: "e9bc097a-ce51-4036-9562-d2ade882db0d", Check: func(token openid.Token) bool { return assert.Equal(t, token.JwtID(), "e9bc097a-ce51-4036-9562-d2ade882db0d") }, }, { - Key: openid.NotBeforeKey, + Key: openid.NotBeforeKey, Value: tokenTime, Expected: func(v interface{}) interface{} { var n types.NumericDate @@ -162,7 +162,7 @@ func TestOpenIDClaims(t *testing.T) { }, }, { - Key: openid.SubjectKey, + Key: openid.SubjectKey, Value: "unit test", Check: func(token openid.Token) bool { return assert.Equal(t, token.Subject(), "unit test") From 6d7b4932f09748698e205cd333f3f9759f7aa05d Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Wed, 6 May 2020 09:24:05 +0900 Subject: [PATCH 03/22] excercise MarshalJSON --- jwt/openid/openid_test.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index f6a2e3b4b..bd2223393 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -360,7 +360,7 @@ func TestOpenIDClaims(t *testing.T) { tokens = append(tokens, openidTokTestCase{Name: `token constructed by calling Set()`, Token: token}) } - { // one with json.Marshal / json.Unmarshal + { // two with json.Marshal / json.Unmarshal src, err := json.MarshalIndent(data, "", " ") if !assert.NoError(t, err, `failed to marshal base map`) { return @@ -372,7 +372,20 @@ func TestOpenIDClaims(t *testing.T) { if !assert.NoError(t, json.Unmarshal(src, &token), `json.Unmarshal should succeed`) { return } - tokens = append(tokens, openidTokTestCase{Name: `token constructed by Marshal+Unmashal`, Token: token}) + tokens = append(tokens, openidTokTestCase{Name: `token constructed by Marshal(map)+Unmashal`, Token: token}) + + + // One more... Marshal the token, _and_ re-unmarshal + buf, err := json.Marshal(token) + if !assert.NoError(t, err, `json.Marshal should succeed`) { + return + } + + token2 := openid.New() + if !assert.NoError(t, json.Unmarshal(buf, &token2), `json.Unmarshal should succeed`) { + return + } + tokens = append(tokens, openidTokTestCase{Name: `token constructed by Marshal(openid.Token)+Unmashal`, Token: token2}) } for _, token := range tokens { From 3230d23233bbc4bb00064d0fa451f2cb60c7bd1a Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Wed, 6 May 2020 09:33:11 +0900 Subject: [PATCH 04/22] excercise iterators --- jwt/openid/interface.go | 2 +- jwt/openid/openid_test.go | 50 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/jwt/openid/interface.go b/jwt/openid/interface.go index dba28486b..7ef2645d5 100644 --- a/jwt/openid/interface.go +++ b/jwt/openid/interface.go @@ -8,4 +8,4 @@ import ( type ClaimPair = mapiter.Pair type Iterator = mapiter.Iterator type Visitor = iter.MapVisitor -type VisitorFunc iter.MapVisitorFunc +type VisitorFunc = iter.MapVisitorFunc diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index bd2223393..4d7c74982 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -1,6 +1,7 @@ package openid_test import ( + "context" "encoding/json" "fmt" "testing" @@ -340,8 +341,14 @@ func TestOpenIDClaims(t *testing.T) { } var data = map[string]interface{}{} + var expected = map[string]interface{}{} for _, value := range base { data[value.Key] = value.Value + if expf := value.Expected; expf != nil { + expected[value.Key] = expf(value.Value) + } else { + expected[value.Key] = value.Value + } } type openidTokTestCase struct { @@ -374,7 +381,6 @@ func TestOpenIDClaims(t *testing.T) { } tokens = append(tokens, openidTokTestCase{Name: `token constructed by Marshal(map)+Unmashal`, Token: token}) - // One more... Marshal the token, _and_ re-unmarshal buf, err := json.Marshal(token) if !assert.NoError(t, err, `json.Marshal should succeed`) { @@ -406,6 +412,48 @@ func TestOpenIDClaims(t *testing.T) { } }) } + + t.Run("Iterator", func(t *testing.T) { + v := tokens[0].Token + t.Run("Iterate", func(t *testing.T) { + seen := make(map[string]interface{}) + for iter := v.Iterate(context.TODO()); iter.Next(context.TODO()); { + pair := iter.Pair() + seen[pair.Key.(string)] = pair.Value + + getV, ok := v.Get(pair.Key.(string)) + if !assert.True(t, ok, `v.Get should succeed for key %#v`, pair.Key) { + return + } + if !assert.Equal(t, pair.Value, getV, `pair.Value should match value from v.Get()`) { + return + } + } + if !assert.Equal(t, expected, seen, `values should match`) { + return + } + }) + t.Run("Walk", func(t *testing.T) { + seen := make(map[string]interface{}) + v.Walk(context.TODO(), openid.VisitorFunc(func(key string, value interface{}) error { + seen[key] = value + return nil + })) + if !assert.Equal(t, expected, seen, `values should match`) { + return + } + }) + t.Run("AsMap", func(t *testing.T) { + seen, err := v.AsMap(context.TODO()) + if !assert.NoError(t, err, `v.AsMap should succeed`) { + return + } + if !assert.Equal(t, expected, seen, `values should match`) { + return + } + }) + }) + } func TestBirthdateClaim(t *testing.T) { From b533edc64bcdbe73bbfb9081ae0ed07606385b4f Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Wed, 6 May 2020 09:35:25 +0900 Subject: [PATCH 05/22] appease golangci-lint --- jwt/openid/openid_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index 4d7c74982..cdc3288b3 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -453,7 +453,6 @@ func TestOpenIDClaims(t *testing.T) { } }) }) - } func TestBirthdateClaim(t *testing.T) { From a9805898746cc695052d11dde3bc392583cd1a33 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Wed, 6 May 2020 09:47:30 +0900 Subject: [PATCH 06/22] excercise private params --- jwt/openid/openid_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index cdc3288b3..4c2e314ac 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -338,6 +338,17 @@ func TestOpenIDClaims(t *testing.T) { return assert.Equal(t, time.Unix(aLongLongTimeAgo, 0).UTC(), token.UpdatedAt()) }, }, + { + Value: `dummy`, + Key: `dummy`, + Check: func(token openid.Token) bool { + v, ok := token.Get(`dummy`) + if !assert.True(t, ok, `token.Get should return valid value`) { + return false + } + return assert.Equal(t, `dummy`, v, `values should match`) + }, + }, } var data = map[string]interface{}{} From 06d5b54e13e2bd8a47076605d53090b99b978c03 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Wed, 6 May 2020 09:50:18 +0900 Subject: [PATCH 07/22] appease golangci-lint --- jwt/openid/openid_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index 4c2e314ac..c78e127cb 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -340,7 +340,7 @@ func TestOpenIDClaims(t *testing.T) { }, { Value: `dummy`, - Key: `dummy`, + Key: `dummy`, Check: func(token openid.Token) bool { v, ok := token.Get(`dummy`) if !assert.True(t, ok, `token.Get should return valid value`) { From 2c55d0e147c093c36df522fe21fd78bc0151b234 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Wed, 6 May 2020 09:59:28 +0900 Subject: [PATCH 08/22] excercise types.NumericDate --- jwt/internal/types/date_test.go | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/jwt/internal/types/date_test.go b/jwt/internal/types/date_test.go index 3b78f4bc1..1f51f63b4 100644 --- a/jwt/internal/types/date_test.go +++ b/jwt/internal/types/date_test.go @@ -3,14 +3,32 @@ package types_test import ( "encoding/json" "fmt" - "reflect" "testing" "time" "github.com/lestrrat-go/jwx/jwt" + "github.com/lestrrat-go/jwx/jwt/internal/types" + "github.com/stretchr/testify/assert" ) func TestDate(t *testing.T) { + t.Run("Get from a nil NumericDate", func(t *testing.T) { + var n *types.NumericDate + if !assert.Equal(t, time.Time{}, n.Get()) { + return + } + }) + t.Run("MarshalJSON with a zero value", func(t *testing.T) { + var n *types.NumericDate + buf, err := json.Marshal(n) + if !assert.NoError(t, err, `json.Marshal against a zero value should succeed`) { + return + } + + if !assert.Equal(t, []byte(`null`), buf, `result should be null`) { + return + } + }) t.Run("Accept values", func(t *testing.T) { // NumericDate allows assignment from various different Go types, // so that it's easier for the devs, and conversion to/from JSON @@ -21,16 +39,16 @@ func TestDate(t *testing.T) { t.Run(fmt.Sprintf("%T", ut), func(t *testing.T) { t1 := jwt.New() err := t1.Set(jwt.IssuedAtKey, ut) - if err != nil { - t.Fatalf("Failed to set IssuedAt value: %v", ut) + if !assert.NoError(t, err) { + return } v, ok := t1.Get(jwt.IssuedAtKey) - if !ok { - t.Fatal("Failed to retrieve IssuedAt value") + if !assert.True(t, ok) { + return } realized := v.(time.Time) - if !reflect.DeepEqual(now, realized) { - t.Fatalf("Token time mistmatch. Expected:Realized (%v:%v)", now, realized) + if !assert.Equal(t, now, realized) { + return } }) } From 00bd224da1d13249b1f80a2d6c5b33f4e0cb6b95 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 07:57:03 +0900 Subject: [PATCH 09/22] Excercise Thumbprint() --- Changes | 3 ++ jwk/jwk.go | 30 +++++++++++++ jwk/jwk_test.go | 115 +++++++++++++++++++++++++++++++++++++++--------- jwk/option.go | 8 +++- 4 files changed, 134 insertions(+), 22 deletions(-) diff --git a/Changes b/Changes index 7ef045883..581f309ed 100644 --- a/Changes +++ b/Changes @@ -1,6 +1,9 @@ Changes ======= +v1.0.2 + * Add jwk.AssignKeyID to automatically assign a `kid` field to a JWK + v1.0.1 - 04 May 2020 * Normalize all JWK serialization to use padding-less base64 encoding (#185) * Fix edge case unmarshaling openid.AddressClaim within a openid.Token diff --git a/jwk/jwk.go b/jwk/jwk.go index 100a5b323..74351f399 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -6,6 +6,7 @@ package jwk import ( "bytes" "context" + "crypto" "crypto/ecdsa" "crypto/rsa" "encoding/json" @@ -18,6 +19,7 @@ import ( "strings" "github.com/lestrrat-go/iter/arrayiter" + "github.com/lestrrat-go/jwx/internal/base64" "github.com/lestrrat-go/jwx/jwa" "github.com/pkg/errors" ) @@ -346,3 +348,31 @@ func assignRawResult(v, t interface{}) error { return nil } + +// AssignKeyID is a convenience function to automatically assign the "kid" +// section of the key, if it already doesn't have one. It uses Key.Thumbprint +// method with crypto.SHA256 as the default hashing algorithm +func AssignKeyID(key Key, options ...Option) error { + if _, ok := key.Get(KeyIDKey); ok { + return nil + } + + hash := crypto.SHA256 + for _, option := range options { + switch option.Name() { + case optkeyThumbprintHash: + hash = option.Value().(crypto.Hash) + } + } + + h, err := key.Thumbprint(hash) + if err != nil { + return errors.Wrap(err, `failed to generate thumbprint`) + } + + if err := key.Set(KeyIDKey, base64.EncodeToString(h)); err != nil { + return errors.Wrap(err, `failed to set "kid"`) + } + + return nil +} diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index 711883f0c..0d9ab0509 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -116,16 +116,69 @@ func TestParse(t *testing.T) { }) } -func TestRoundtrip(t *testing.T) { - generateRSA := func(use string, keyID string) (jwk.Key, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) +func generateRSAPrivateKey() (jwk.Key, error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, errors.Wrap(err, `failed to generate RSA private key`) + } + + k, err := jwk.New(key) + if err != nil { + return nil, errors.Wrap(err, `failed to generate jwk.RSAPrivateKey`) + } + + return k, nil +} + +func generateRSAPublicKey() (jwk.Key, error) { + k, err := generateRSAPrivateKey() + if err != nil { + return nil, err + } + + return k.(jwk.RSAPrivateKey).PublicKey() +} + +func generateECDSAPrivateKey() (jwk.Key, error) { + key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if err != nil { + return nil, errors.Wrap(err, `failed to generate ECDSA private key`) + } + + k, err := jwk.New(key) + if err != nil { + return nil, errors.Wrap(err, `failed to generate jwk.ECDSAPrivateKey`) + } + + return k, nil +} + +func generateECDSAPublicKey() (jwk.Key, error) { + k, err := generateECDSAPrivateKey() + if err != nil { + return nil, err + } + + return k.(jwk.ECDSAPrivateKey).PublicKey() +} + +func generateSymmetricKey() (jwk.Key, error) { + sharedKey := make([]byte, 64) + rand.Read(sharedKey) + + key, err := jwk.New(sharedKey) if err != nil { - return nil, errors.Wrap(err, `failed to generate RSA private key`) + return nil, errors.Wrap(err, `failed to generate jwk.SymmetricKey`) } - k, err := jwk.New(key) + return key, nil +} + +func TestRoundtrip(t *testing.T) { + generateRSA := func(use string, keyID string) (jwk.Key, error) { + k, err := generateRSAPrivateKey() if err != nil { - return nil, errors.Wrap(err, `failed to generate jwk.RSAPrivateKey`) + return nil, err } k.Set(jwk.KeyUsageKey, use) @@ -134,14 +187,9 @@ func TestRoundtrip(t *testing.T) { } generateECDSA := func(use, keyID string) (jwk.Key, error) { - key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + k, err := generateECDSAPrivateKey() if err != nil { - return nil, errors.Wrap(err, `failed to generate ECDSA private key`) - } - - k, err := jwk.New(key) - if err != nil { - return nil, errors.Wrap(err, `failed to generate jwk.ECDSAPrivateKey`) + return nil, err } k.Set(jwk.KeyUsageKey, use) @@ -150,17 +198,14 @@ func TestRoundtrip(t *testing.T) { } generateSymmetric := func(use, keyID string) (jwk.Key, error) { - sharedKey := make([]byte, 64) - rand.Read(sharedKey) - - key, err := jwk.New(sharedKey) + k, err := generateSymmetricKey() if err != nil { - return nil, errors.Wrap(err, `failed to generate jwk.SymmetricKey`) + return nil, err } - key.Set(jwk.KeyUsageKey, use) - key.Set(jwk.KeyIDKey, keyID) - return key, nil + k.Set(jwk.KeyUsageKey, use) + k.Set(jwk.KeyIDKey, keyID) + return k, nil } tests := []struct { @@ -296,3 +341,31 @@ func TestKeyOperation(t *testing.T) { } } } + +func TestAssignKeyID(t *testing.T) { + generators := []func() (jwk.Key, error){ + generateRSAPrivateKey, + generateRSAPublicKey, + generateECDSAPrivateKey, + generateECDSAPublicKey, + generateSymmetricKey, + } + + for _, generator := range generators { + k, err := generator() + if !assert.NoError(t, err, `jwk generation should be successful`) { + return + } + + if !assert.Empty(t, k.KeyID(), `k.KeyID should be non-empty`) { + return + } + if !assert.NoError(t, jwk.AssignKeyID(k), `AssignKeyID shuld be successful`) { + return + } + + if !assert.NotEmpty(t, k.KeyID(), `k.KeyID should be non-empty`) { + return + } + } +} diff --git a/jwk/option.go b/jwk/option.go index 80751670d..c560cb7d4 100644 --- a/jwk/option.go +++ b/jwk/option.go @@ -1,6 +1,7 @@ package jwk import ( + "crypto" "net/http" "github.com/lestrrat-go/jwx/internal/option" @@ -9,9 +10,14 @@ import ( type Option = option.Interface const ( - optkeyHTTPClient = `http-client` + optkeyHTTPClient = `http-client` + optkeyThumbprintHash = `thumbprint-hash` ) func WithHTTPClient(cl *http.Client) Option { return option.New(optkeyHTTPClient, cl) } + +func WithThumbprintHash(h crypto.Hash) Option { + return option.New(optkeyThumbprintHash, h) +} From c690076ecba6084b2e800457d34ff5ead726043b Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 08:03:19 +0900 Subject: [PATCH 10/22] oops --- jwk/jwk_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index 0d9ab0509..4d7892689 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -163,13 +163,13 @@ func generateECDSAPublicKey() (jwk.Key, error) { } func generateSymmetricKey() (jwk.Key, error) { - sharedKey := make([]byte, 64) - rand.Read(sharedKey) + sharedKey := make([]byte, 64) + rand.Read(sharedKey) - key, err := jwk.New(sharedKey) - if err != nil { - return nil, errors.Wrap(err, `failed to generate jwk.SymmetricKey`) - } + key, err := jwk.New(sharedKey) + if err != nil { + return nil, errors.Wrap(err, `failed to generate jwk.SymmetricKey`) + } return key, nil } From 0d29afff9516a6f7ea762fb3030d710f0fe874f6 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 08:11:27 +0900 Subject: [PATCH 11/22] don't use StringStd --- internal/base64/base64.go | 4 ---- jwk/certchain.go | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/internal/base64/base64.go b/internal/base64/base64.go index 1833b3a7a..b1b36b060 100644 --- a/internal/base64/base64.go +++ b/internal/base64/base64.go @@ -6,10 +6,6 @@ import ( "strings" ) -func EncodeToStringStd(src []byte) string { - return base64.StdEncoding.EncodeToString(src) -} - func EncodeToString(src []byte) string { return base64.RawURLEncoding.EncodeToString(src) } diff --git a/jwk/certchain.go b/jwk/certchain.go index 5bd002e97..b4cd1b8a3 100644 --- a/jwk/certchain.go +++ b/jwk/certchain.go @@ -12,7 +12,7 @@ func (c CertificateChain) MarshalJSON() ([]byte, error) { certs := c.Get() encoded := make([]string, len(certs)) for i := 0; i < len(certs); i++ { - encoded[i] = base64.EncodeToStringStd(certs[i].Raw) + encoded[i] = base64.EncodeToString(certs[i].Raw) } return json.Marshal(encoded) } From e8da78e94e0cba9bbbc88d5133be0ef72b5fd2ae Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 17:05:28 +0900 Subject: [PATCH 12/22] Fix content compression --- Changes | 1 + jwe/compress.go | 35 ++++++++++++++++++++++++++ jwe/encrypt.go | 16 +++++++++++- jwe/interface.go | 1 + jwe/jwe.go | 1 + jwe/jwe_test.go | 55 +++++++++++++++++++++-------------------- jwe/message.go | 64 ++++++++++++++++++++++++------------------------ 7 files changed, 114 insertions(+), 59 deletions(-) create mode 100644 jwe/compress.go diff --git a/Changes b/Changes index 581f309ed..9fb052a8e 100644 --- a/Changes +++ b/Changes @@ -3,6 +3,7 @@ Changes v1.0.2 * Add jwk.AssignKeyID to automatically assign a `kid` field to a JWK + * Fix jwe.Encrypt / jwe.Decrypt to properly look at the `zip` field v1.0.1 - 04 May 2020 * Normalize all JWK serialization to use padding-less base64 encoding (#185) diff --git a/jwe/compress.go b/jwe/compress.go new file mode 100644 index 000000000..2fa97a707 --- /dev/null +++ b/jwe/compress.go @@ -0,0 +1,35 @@ +package jwe + +import ( + "bytes" + "compress/flate" + "io/ioutil" + + "github.com/lestrrat-go/jwx/jwa" + "github.com/pkg/errors" +) + +func uncompress(plaintext []byte) ([]byte, error) { + return ioutil.ReadAll(flate.NewReader(bytes.NewReader(plaintext))) +} + +func compress(plaintext []byte, alg jwa.CompressionAlgorithm) ([]byte, error) { + if alg == jwa.NoCompress { + return plaintext, nil + } + + var output bytes.Buffer + w, _ := flate.NewWriter(&output, 1) + in := plaintext + for len(in) > 0 { + n, err := w.Write(in) + if err != nil { + return nil, errors.Wrap(err, `failed to write to compression writer`) + } + in = in[n:] + } + if err := w.Close(); err != nil { + return nil, errors.Wrap(err, "failed to close compression writer") + } + return output.Bytes(), nil +} diff --git a/jwe/encrypt.go b/jwe/encrypt.go index fc2c3a716..f52c4e2db 100644 --- a/jwe/encrypt.go +++ b/jwe/encrypt.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/pdebug" "github.com/pkg/errors" ) @@ -22,6 +23,7 @@ func releaseEncryptCtx(ctx *encryptCtx) { ctx.contentEncrypter = nil ctx.generator = nil ctx.keyEncrypters = nil + ctx.compress = jwa.NoCompress encryptCtxPool.Put(ctx) } @@ -42,7 +44,14 @@ func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) { protected := NewHeaders() if err := protected.Set(ContentEncryptionKey, e.contentEncrypter.Algorithm()); err != nil { - return nil, errors.Wrap(err, "failed to set enc in protected header") + return nil, errors.Wrap(err, `failed to set "enc" in protected header`) + } + + compression := e.compress + if compression != jwa.NoCompress { + if err := protected.Set(CompressionKey, compression); err != nil { + return nil, errors.Wrap(err, `failed to set "zip" in protected header`) + } } // In JWE, multiple recipients may exist -- they receive an @@ -95,6 +104,11 @@ func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) { return nil, errors.Wrap(err, "failed to base64 encode protected headers") } + plaintext, err = compress(plaintext, compression) + if err != nil { + return nil, errors.Wrap(err, `failed to compress payload before encryption`) + } + // ...on the other hand, there's only one content cipher. iv, ciphertext, tag, err := e.contentEncrypter.Encrypt(cek, plaintext, aad) if err != nil { diff --git a/jwe/interface.go b/jwe/interface.go index b84558290..15616f9e3 100644 --- a/jwe/interface.go +++ b/jwe/interface.go @@ -49,6 +49,7 @@ type encryptCtx struct { contentEncrypter contentEncrypter generator keygen.Generator keyEncrypters []keyenc.Encrypter + compress jwa.CompressionAlgorithm } // populater is an interface for things that may modify the diff --git a/jwe/jwe.go b/jwe/jwe.go index b58b81c9a..728d617a0 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -108,6 +108,7 @@ func Encrypt(payload []byte, keyalg jwa.KeyEncryptionAlgorithm, key interface{}, encctx.contentEncrypter = contentcrypt encctx.generator = keygen.NewRandom(keysize) encctx.keyEncrypters = []keyenc.Encrypter{enc} + encctx.compress = compressalg msg, err := encctx.Encrypt(payload) if err != nil { if pdebug.Enabled { diff --git a/jwe/jwe_test.go b/jwe/jwe_test.go index 36b61ec64..af73838c1 100644 --- a/jwe/jwe_test.go +++ b/jwe/jwe_test.go @@ -223,32 +223,35 @@ func TestParse_RSAES_OAEP_AES_GCM(t *testing.T) { for _, serializer := range serializers { serializer := serializer - t.Run(serializer.Name, func(t *testing.T) { - jsonbuf, err := serializer.Func(msg) - if !assert.NoError(t, err, "serialize succeeded") { - return - } - - if !assert.Equal(t, serializer.Expected, string(jsonbuf), "serialize result matches") { - jsonbuf, _ = jwe.JSON(msg, jwe.WithPrettyJSONFormat(true)) - t.Logf("%s", jsonbuf) - return - } - - encrypted, err := jwe.Encrypt(plaintext, jwa.RSA_OAEP, rawkey.PublicKey, jwa.A256GCM, jwa.NoCompress) - if !assert.NoError(t, err, "jwe.Encrypt should succeed") { - return - } - - plaintext, err = jwe.Decrypt(encrypted, jwa.RSA_OAEP, rawkey) - if !assert.NoError(t, err, "jwe.Decrypt should succeed") { - return - } - - if !assert.Equal(t, payload, string(plaintext), "jwe.Decrypt should produce the same plaintext") { - return - } - }) + for _, compression := range []jwa.CompressionAlgorithm{jwa.NoCompress, jwa.Deflate} { + compression := compression + t.Run(serializer.Name+" (compression="+compression.String()+")", func(t *testing.T) { + jsonbuf, err := serializer.Func(msg) + if !assert.NoError(t, err, "serialize succeeded") { + return + } + + if !assert.Equal(t, serializer.Expected, string(jsonbuf), "serialize result matches") { + jsonbuf, _ = jwe.JSON(msg, jwe.WithPrettyJSONFormat(true)) + t.Logf("%s", jsonbuf) + return + } + + encrypted, err := jwe.Encrypt(plaintext, jwa.RSA_OAEP, rawkey.PublicKey, jwa.A256GCM, compression) + if !assert.NoError(t, err, "jwe.Encrypt should succeed") { + return + } + + plaintext, err = jwe.Decrypt(encrypted, jwa.RSA_OAEP, rawkey) + if !assert.NoError(t, err, "jwe.Decrypt should succeed") { + return + } + + if !assert.Equal(t, payload, string(plaintext), "jwe.Decrypt should produce the same plaintext") { + return + } + }) + } } // Test direct marshaling and unmarshaling diff --git a/jwe/message.go b/jwe/message.go index 2e371a2c5..93fcb2e3f 100644 --- a/jwe/message.go +++ b/jwe/message.go @@ -2,7 +2,6 @@ package jwe import ( "bytes" - "compress/flate" "context" "encoding/json" "fmt" @@ -295,58 +294,76 @@ func (m *Message) Decrypt(alg jwa.KeyEncryptionAlgorithm, key interface{}) ([]by var plaintext []byte var lastError error for _, recipient := range m.recipients { + // strategy: try each recipient. If we fail in one of the steps, + // keep looping because there might be another key with the same algo + if pdebug.Enabled { pdebug.Printf("Attempting to check if we can decode for recipient (alg = %s)", recipient.Headers().Algorithm()) } + if recipient.Headers().Algorithm() != alg { + // algorithms don't match continue } h2, err := mergeHeaders(context.TODO(), nil, h) if err != nil { + lastError = errors.Wrap(err, `failed to copy headers (1)`) if pdebug.Enabled { - pdebug.Printf("failed to copy header: %s", err) + pdebug.Printf(`%s`, lastError) } - lastError = errors.Wrap(err, `failed to copy headers (1)`) continue } h2, err = mergeHeaders(context.TODO(), h2, recipient.Headers()) if err != nil { + lastError = errors.Wrap(err, `failed to copy headers (2)`) if pdebug.Enabled { - pdebug.Printf("Failed to merge! %s", err) + pdebug.Printf(`%s`, lastError) } - lastError = errors.Wrap(err, `failed to copy headers (2)`) continue } k, err := buildKeyDecrypter(h2.Algorithm(), h2, key, keysize) if err != nil { + lastError = errors.Wrap(err, `failed to build key decrypter`) if pdebug.Enabled { - pdebug.Printf("failed to create key decrypter: %s", err) + pdebug.Printf(`%s`, lastError) } - lastError = errors.Wrap(err, `failed to build key decrypter`) continue } cek, err := k.Decrypt(recipient.EncryptedKey().Bytes()) if err != nil { + lastError = errors.Wrap(err, `failed to decrypt key`) if pdebug.Enabled { - pdebug.Printf("failed to decrypt key: %s", err) + pdebug.Printf(`%s`, lastError) } - lastError = errors.Wrap(err, `failed to decrypt key`) continue } plaintext, err = cipher.Decrypt(cek, iv, ciphertext, tag, aad) - if err == nil { - break + if err != nil { + lastError = errors.Wrap(err, `failed to decrypt payload`) + if pdebug.Enabled { + pdebug.Printf(`%s`, lastError) + } + continue } - if pdebug.Enabled { - pdebug.Printf("DecryptMessage: failed to decrypt using %s: %s", h2.Algorithm(), err) + + if h2.Compression() == jwa.Deflate { + buf, err := uncompress(plaintext) + if err != nil { + lastError = errors.Wrap(err, `failed to uncompress payload`) + if pdebug.Enabled { + pdebug.Printf(`%s`, lastError) + } + continue + } + plaintext = buf } - lastError = errors.Wrap(err, `failed to decrypt ciphertext`) - // Keep looping because there might be another key with the same algo + + break } if plaintext == nil { @@ -356,23 +373,6 @@ func (m *Message) Decrypt(alg jwa.KeyEncryptionAlgorithm, key interface{}) ([]by return nil, errors.New("failed to find matching recipient to decrypt key") } - if h.Compression() == jwa.Deflate { - var output bytes.Buffer - w, _ := flate.NewWriter(&output, 1) - in := plaintext - for len(in) > 0 { - n, err := w.Write(in) - if err != nil { - return nil, errors.Wrap(err, `failed to write to compression writer`) - } - in = in[n:] - } - if err := w.Close(); err != nil { - return nil, errors.Wrap(err, "failed to close compression writer") - } - plaintext = output.Bytes() - } - return plaintext, nil } From 5b953867f335738df96c0ecf7341e55913be1d67 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 17:26:19 +0900 Subject: [PATCH 13/22] Excercise CertificateChain Marshal/Unmarshal --- internal/base64/base64.go | 4 ++++ jwk/certchain.go | 2 +- jwk/x5c_test.go | 47 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/internal/base64/base64.go b/internal/base64/base64.go index b1b36b060..47e3ce99b 100644 --- a/internal/base64/base64.go +++ b/internal/base64/base64.go @@ -6,6 +6,10 @@ import ( "strings" ) +func EncodeToStringStd(src []byte) string { + return base64.RawStdEncoding.EncodeToString(src) +} + func EncodeToString(src []byte) string { return base64.RawURLEncoding.EncodeToString(src) } diff --git a/jwk/certchain.go b/jwk/certchain.go index b4cd1b8a3..5bd002e97 100644 --- a/jwk/certchain.go +++ b/jwk/certchain.go @@ -12,7 +12,7 @@ func (c CertificateChain) MarshalJSON() ([]byte, error) { certs := c.Get() encoded := make([]string, len(certs)) for i := 0; i < len(certs); i++ { - encoded[i] = base64.EncodeToString(certs[i].Raw) + encoded[i] = base64.EncodeToStringStd(certs[i].Raw) } return json.Marshal(encoded) } diff --git a/jwk/x5c_test.go b/jwk/x5c_test.go index 03e3b0b32..c150dd46d 100644 --- a/jwk/x5c_test.go +++ b/jwk/x5c_test.go @@ -1,6 +1,7 @@ package jwk_test import ( + "encoding/json" "testing" "github.com/lestrrat-go/jwx/jwk" @@ -14,6 +15,52 @@ func Test_X5CHeader(t *testing.T) { "MIIC5zCCAlACAQEwDQYJKoZIhvcNAQEFBQAwgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMB4XDTk5MDYyNjAwMTk1NFoXDTE5MDYyNjAwMTk1NFowgbsxJDAiBgNVBAcTG1ZhbGlDZXJ0IFZhbGlkYXRpb24gTmV0d29yazEXMBUGA1UEChMOVmFsaUNlcnQsIEluYy4xNTAzBgNVBAsTLFZhbGlDZXJ0IENsYXNzIDIgUG9saWN5IFZhbGlkYXRpb24gQXV0aG9yaXR5MSEwHwYDVQQDExhodHRwOi8vd3d3LnZhbGljZXJ0LmNvbS8xIDAeBgkqhkiG9w0BCQEWEWluZm9AdmFsaWNlcnQuY29tMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDOOnHK5avIWZJV16vYdA757tn2VUdZZUcOBVXc65g2PFxTXdMwzzjsvUGJ7SVCCSRrCl6zfN1SLUzm1NZ9WlmpZdRJEy0kTRxQb7XBhVQ7/nHk01xC+YDgkRoKWzk2Z/M/VXwbP7RfZHM047QSv4dk+NoS/zcnwbNDu+97bi5p9wIDAQABMA0GCSqGSIb3DQEBBQUAA4GBADt/UG9vUJSZSWI4OB9L+KXIPqeCgfYrx+jFzug6EILLGACOTb2oWH+heQC1u+mNr0HZDzTuIYEZoDJJKPTEjlbVUjP9UNV+mWwD5MlM/Mtsq2azSiGM5bUMMj4QssxsodyamEwCW/POuZ6lcg5Ktz885hZo+L7tdEy8W9ViH0Pd", } + t.Run("Marshal/Unmarshal", func(t *testing.T) { + // The input contains padding. We can accept either as input, but only emit + // strings encoded without padding + certsNopad := make([]string, len(certs)) + for i, cert := range certs { + for len(cert) > 0 && cert[len(cert)-1] == '=' { + cert = cert[:len(cert)-1] + } + certsNopad[i] = cert + } + + expected, err := json.Marshal(certsNopad) + if !assert.NoError(t, err, `json.Marshal should succeed`) { + return + } + + inputs := map[string][]string{ + "with padding": certs, + "without padding": certsNopad, + } + for k, input := range inputs { + input := input + t.Run(k, func(t *testing.T) { + // Take the input, and create a json + jsonbuf, err := json.Marshal(input) + if !assert.NoError(t, err, `json.Marshal should succeed (for input)`) { + return + } + + var c jwk.CertificateChain + if !assert.NoError(t, json.Unmarshal(jsonbuf, &c), `json.Unmarshal should succeed`) { + return + } + + if !assert.Len(t, c.Get(), 3, `should have three certs`) { + return + } + + buf, err := json.Marshal(c) + if !assert.Equal(t, expected, buf, `json output should match`) { + return + } + }) + } + }) + for _, key := range []jwk.Key{ jwk.NewRSAPrivateKey(), jwk.NewRSAPublicKey(), From 4845f4ea10078d57cce584492f60d92619c0ea58 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 17:28:42 +0900 Subject: [PATCH 14/22] appease golangci-lint --- jwk/x5c_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/jwk/x5c_test.go b/jwk/x5c_test.go index c150dd46d..11ba22b48 100644 --- a/jwk/x5c_test.go +++ b/jwk/x5c_test.go @@ -54,6 +54,10 @@ func Test_X5CHeader(t *testing.T) { } buf, err := json.Marshal(c) + if !assert.NoError(t, err, `json.Marshal should succeed`) { + return + } + if !assert.Equal(t, expected, buf, `json output should match`) { return } From 858cb4a5c71d5a23357db32db1f82805390363a7 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 17:53:21 +0900 Subject: [PATCH 15/22] Excercise PublicKeyOf --- jwk/jwk.go | 6 ++++ jwk/jwk_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/jwk/jwk.go b/jwk/jwk.go index 74351f399..6643bf0c0 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -83,10 +83,16 @@ func New(key interface{}) (Key, error) { // PublicKeyOf returns the corresponding public key of the given // value `v`. For example, if v is a `*rsa.PrivateKey`, then // `*rsa.PublicKey` is returned. +// // If given a public key, then the same public key will be returned. // For example, if v is a `*rsa.PublicKey`, then the same value // is returned. +// // If v is of a type that we don't support, an error is returned. +// +// This is useful when you are dealing with the jwk.Key interface +// alone and you don't know before hand what the underlying key +// type is, but you still want to obtain the corresponding public key func PublicKeyOf(v interface{}) (interface{}, error) { // may be a silly idea, but if the user gave us a non-pointer value... var ptr interface{} diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index 4d7892689..d98f72b0c 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -116,8 +116,13 @@ func TestParse(t *testing.T) { }) } +func generateRawRSAPrivateKey() (*rsa.PrivateKey, error) { + return rsa.GenerateKey(rand.Reader, 2048) +} + func generateRSAPrivateKey() (jwk.Key, error) { - key, err := rsa.GenerateKey(rand.Reader, 2048) + key, err := generateRawRSAPrivateKey() + if err != nil { return nil, errors.Wrap(err, `failed to generate RSA private key`) } @@ -139,8 +144,12 @@ func generateRSAPublicKey() (jwk.Key, error) { return k.(jwk.RSAPrivateKey).PublicKey() } +func generateRawECDSAPrivateKey() (*ecdsa.PrivateKey, error) { + return ecdsa.GenerateKey(elliptic.P521(), rand.Reader) +} + func generateECDSAPrivateKey() (jwk.Key, error) { - key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + key, err := generateRawECDSAPrivateKey() if err != nil { return nil, errors.Wrap(err, `failed to generate ECDSA private key`) } @@ -162,11 +171,14 @@ func generateECDSAPublicKey() (jwk.Key, error) { return k.(jwk.ECDSAPrivateKey).PublicKey() } -func generateSymmetricKey() (jwk.Key, error) { +func generateRawSymmetricKey() []byte { sharedKey := make([]byte, 64) rand.Read(sharedKey) + return sharedKey +} - key, err := jwk.New(sharedKey) +func generateSymmetricKey() (jwk.Key, error) { + key, err := jwk.New(generateRawSymmetricKey()) if err != nil { return nil, errors.Wrap(err, `failed to generate jwk.SymmetricKey`) } @@ -369,3 +381,75 @@ func TestAssignKeyID(t *testing.T) { } } } + +func TestPublicKeyOf(t *testing.T) { + rsakey, err := generateRawRSAPrivateKey() + if !assert.NoError(t, err, `generating raw RSA key should succeed`) { + return + } + + ecdsakey, err := generateRawECDSAPrivateKey() + if !assert.NoError(t, err, `generating raw ECDSA key should succeed`) { + return + } + + octets := generateRawSymmetricKey() + + keys := []struct { + Key interface{} + PublicKeyType reflect.Type + }{ + { + Key: rsakey, + PublicKeyType: reflect.PtrTo(reflect.TypeOf(rsakey.PublicKey)), + }, + { + Key: *rsakey, + PublicKeyType: reflect.PtrTo(reflect.TypeOf(rsakey.PublicKey)), + }, + { + Key: rsakey.PublicKey, + PublicKeyType: reflect.PtrTo(reflect.TypeOf(rsakey.PublicKey)), + }, + { + Key: &rsakey.PublicKey, + PublicKeyType: reflect.PtrTo(reflect.TypeOf(rsakey.PublicKey)), + }, + { + Key: ecdsakey, + PublicKeyType: reflect.PtrTo(reflect.TypeOf(ecdsakey.PublicKey)), + }, + { + Key: *ecdsakey, + PublicKeyType: reflect.PtrTo(reflect.TypeOf(ecdsakey.PublicKey)), + }, + { + Key: ecdsakey.PublicKey, + PublicKeyType: reflect.PtrTo(reflect.TypeOf(ecdsakey.PublicKey)), + }, + { + Key: &ecdsakey.PublicKey, + PublicKeyType: reflect.PtrTo(reflect.TypeOf(ecdsakey.PublicKey)), + }, + { + Key: octets, + PublicKeyType: reflect.TypeOf(octets), + }, + } + + for _, key := range keys { + key := key + t.Run(fmt.Sprintf("%T", key.Key), func(t *testing.T) { + t.Parallel() + + pubkey, err := jwk.PublicKeyOf(key.Key) + if !assert.NoError(t, err, `jwk.PublicKeyOf(%T) should succeed`) { + return + } + + if !assert.Equal(t, key.PublicKeyType, reflect.TypeOf(pubkey), `public key types should match`) { + return + } + }) + } +} From 5ffa8547a20c1cb0158948e81007731e2d5213d2 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 18:08:47 +0900 Subject: [PATCH 16/22] Excercise (openid.Token).Get --- jwt/openid/openid_test.go | 197 ++++++++++++++++++++++---------------- 1 file changed, 115 insertions(+), 82 deletions(-) diff --git a/jwt/openid/openid_test.go b/jwt/openid/openid_test.go index c78e127cb..e42dc1e56 100644 --- a/jwt/openid/openid_test.go +++ b/jwt/openid/openid_test.go @@ -20,36 +20,69 @@ const ( var expectedTokenTime = time.Unix(tokenTime, 0).UTC() -func assertStockAddressClaim(t *testing.T, x *openid.AddressClaim) bool { +func testStockAddressClaim(t *testing.T, x *openid.AddressClaim) { t.Helper() if !assert.NotNil(t, x) { - return false - } - - if !assert.Equal(t, "〒105-0011 東京都港区芝公園4丁目2−8", x.Formatted(), "formatted should match") { - return false - } - - if !assert.Equal(t, "日本", x.Country(), "country should match") { - return false - } - - if !assert.Equal(t, "東京都", x.Region(), "region should match") { - return false - } - - if !assert.Equal(t, "港区", x.Locality(), "locality should match") { - return false + return } - if !assert.Equal(t, "芝公園4丁目2−8", x.StreetAddress(), "street_address should match") { - return false + tests := []struct { + Accessor func() string + KeyName string + Value string + }{ + { + Accessor: x.Formatted, + KeyName: openid.AddressFormattedKey, + Value: "〒105-0011 東京都港区芝公園4丁目2−8", + }, + { + Accessor: x.Country, + KeyName: openid.AddressCountryKey, + Value: "日本", + }, + { + Accessor: x.Region, + KeyName: openid.AddressRegionKey, + Value: "東京都", + }, + { + Accessor: x.Locality, + KeyName: openid.AddressLocalityKey, + Value: "港区", + }, + { + Accessor: x.StreetAddress, + KeyName: openid.AddressStreetAddressKey, + Value: "芝公園4丁目2−8", + }, + { + Accessor: x.PostalCode, + KeyName: openid.AddressPostalCodeKey, + Value: "105-0011", + }, } - if !assert.Equal(t, "105-0011", x.PostalCode(), "postal_code should match") { - return false + for _, tc := range tests { + tc := tc + t.Run(tc.KeyName, func(t *testing.T) { + t.Parallel() + t.Run("Accessor", func(t *testing.T) { + if !assert.Equal(t, tc.Value, tc.Accessor(), "values should match") { + return + } + }) + t.Run("Get", func(t *testing.T) { + v, ok := x.Get(tc.KeyName) + if !assert.True(t, ok, `x.Get should succeed`) { + return + } + if !assert.Equal(t, tc.Value, v, `values should match`) { + return + } + }) + }) } - return true } func TestAdressClaim(t *testing.T) { @@ -78,9 +111,7 @@ func TestAdressClaim(t *testing.T) { } for _, x := range []*openid.AddressClaim{&address, &roundtrip} { - if !assertStockAddressClaim(t, x) { - return - } + testStockAddressClaim(t, x) } } @@ -97,13 +128,13 @@ func TestOpenIDClaims(t *testing.T) { Value interface{} Expected func(interface{}) interface{} Key string - Check func(openid.Token) bool + Check func(openid.Token) }{ { Key: openid.AudienceKey, Value: []string{"developers", "secops", "tac"}, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Audience(), []string{"developers", "secops", "tac"}) + Check: func(token openid.Token) { + assert.Equal(t, token.Audience(), []string{"developers", "secops", "tac"}) }, }, { @@ -116,8 +147,8 @@ func TestOpenIDClaims(t *testing.T) { } return n.Get() }, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Expiration(), expectedTokenTime) + Check: func(token openid.Token) { + assert.Equal(t, token.Expiration(), expectedTokenTime) }, }, { @@ -130,22 +161,22 @@ func TestOpenIDClaims(t *testing.T) { } return n.Get() }, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Expiration(), expectedTokenTime) + Check: func(token openid.Token) { + assert.Equal(t, token.Expiration(), expectedTokenTime) }, }, { Key: openid.IssuerKey, Value: "http://www.example.com", - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Issuer(), "http://www.example.com") + Check: func(token openid.Token) { + assert.Equal(t, token.Issuer(), "http://www.example.com") }, }, { Key: openid.JwtIDKey, Value: "e9bc097a-ce51-4036-9562-d2ade882db0d", - Check: func(token openid.Token) bool { - return assert.Equal(t, token.JwtID(), "e9bc097a-ce51-4036-9562-d2ade882db0d") + Check: func(token openid.Token) { + assert.Equal(t, token.JwtID(), "e9bc097a-ce51-4036-9562-d2ade882db0d") }, }, { @@ -158,99 +189,99 @@ func TestOpenIDClaims(t *testing.T) { } return n.Get() }, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.NotBefore(), expectedTokenTime) + Check: func(token openid.Token) { + assert.Equal(t, token.NotBefore(), expectedTokenTime) }, }, { Key: openid.SubjectKey, Value: "unit test", - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Subject(), "unit test") + Check: func(token openid.Token) { + assert.Equal(t, token.Subject(), "unit test") }, }, { Value: "jwx", Key: openid.NameKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Name(), "jwx") + Check: func(token openid.Token) { + assert.Equal(t, token.Name(), "jwx") }, }, { Value: "jay", Key: openid.GivenNameKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.GivenName(), "jay") + Check: func(token openid.Token) { + assert.Equal(t, token.GivenName(), "jay") }, }, { Value: "weee", Key: openid.MiddleNameKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.MiddleName(), "weee") + Check: func(token openid.Token) { + assert.Equal(t, token.MiddleName(), "weee") }, }, { Value: "xi", Key: openid.FamilyNameKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.FamilyName(), "xi") + Check: func(token openid.Token) { + assert.Equal(t, token.FamilyName(), "xi") }, }, { Value: "jayweexi", Key: openid.NicknameKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Nickname(), "jayweexi") + Check: func(token openid.Token) { + assert.Equal(t, token.Nickname(), "jayweexi") }, }, { Value: "jwx", Key: openid.PreferredUsernameKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.PreferredUsername(), "jwx") + Check: func(token openid.Token) { + assert.Equal(t, token.PreferredUsername(), "jwx") }, }, { Value: "https://github.com/lestrrat-go/jwx", Key: openid.ProfileKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Profile(), "https://github.com/lestrrat-go/jwx") + Check: func(token openid.Token) { + assert.Equal(t, token.Profile(), "https://github.com/lestrrat-go/jwx") }, }, { Value: "https://avatars1.githubusercontent.com/u/36653903?s=400&v=4", Key: openid.PictureKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Picture(), "https://avatars1.githubusercontent.com/u/36653903?s=400&v=4") + Check: func(token openid.Token) { + assert.Equal(t, token.Picture(), "https://avatars1.githubusercontent.com/u/36653903?s=400&v=4") }, }, { Value: "https://github.com/lestrrat-go/jwx", Key: openid.WebsiteKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Website(), "https://github.com/lestrrat-go/jwx") + Check: func(token openid.Token) { + assert.Equal(t, token.Website(), "https://github.com/lestrrat-go/jwx") }, }, { Value: "lestrrat+github@gmail.com", Key: openid.EmailKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Email(), "lestrrat+github@gmail.com") + Check: func(token openid.Token) { + assert.Equal(t, token.Email(), "lestrrat+github@gmail.com") }, }, { Value: true, Key: openid.EmailVerifiedKey, - Check: func(token openid.Token) bool { - return assert.True(t, token.EmailVerified()) + Check: func(token openid.Token) { + assert.True(t, token.EmailVerified()) }, }, { Value: "n/a", Key: openid.GenderKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Gender(), "n/a") + Check: func(token openid.Token) { + assert.Equal(t, token.Gender(), "n/a") }, }, { @@ -263,38 +294,38 @@ func TestOpenIDClaims(t *testing.T) { } return &b }, - Check: func(token openid.Token) bool { + Check: func(token openid.Token) { var b openid.BirthdateClaim b.Accept("2015-11-04") - return assert.Equal(t, token.Birthdate(), &b) + assert.Equal(t, token.Birthdate(), &b) }, }, { Value: "Asia/Tokyo", Key: openid.ZoneinfoKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Zoneinfo(), "Asia/Tokyo") + Check: func(token openid.Token) { + assert.Equal(t, token.Zoneinfo(), "Asia/Tokyo") }, }, { Value: "ja_JP", Key: openid.LocaleKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.Locale(), "ja_JP") + Check: func(token openid.Token) { + assert.Equal(t, token.Locale(), "ja_JP") }, }, { Value: "819012345678", Key: openid.PhoneNumberKey, - Check: func(token openid.Token) bool { - return assert.Equal(t, token.PhoneNumber(), "819012345678") + Check: func(token openid.Token) { + assert.Equal(t, token.PhoneNumber(), "819012345678") }, }, { Value: true, Key: openid.PhoneNumberVerifiedKey, - Check: func(token openid.Token) bool { - return assert.True(t, token.PhoneNumberVerified()) + Check: func(token openid.Token) { + assert.True(t, token.PhoneNumberVerified()) }, }, { @@ -320,8 +351,8 @@ func TestOpenIDClaims(t *testing.T) { } return address }, - Check: func(token openid.Token) bool { - return assertStockAddressClaim(t, token.Address()) + Check: func(token openid.Token) { + testStockAddressClaim(t, token.Address()) }, }, { @@ -334,19 +365,21 @@ func TestOpenIDClaims(t *testing.T) { } return n.Get() }, - Check: func(token openid.Token) bool { - return assert.Equal(t, time.Unix(aLongLongTimeAgo, 0).UTC(), token.UpdatedAt()) + Check: func(token openid.Token) { + assert.Equal(t, time.Unix(aLongLongTimeAgo, 0).UTC(), token.UpdatedAt()) }, }, { Value: `dummy`, Key: `dummy`, - Check: func(token openid.Token) bool { + Check: func(token openid.Token) { v, ok := token.Get(`dummy`) if !assert.True(t, ok, `token.Get should return valid value`) { - return false + return + } + if !assert.Equal(t, `dummy`, v, `values should match`) { + return } - return assert.Equal(t, `dummy`, v, `values should match`) }, }, } From 61285a508b72e755e91839a00ea49702969ce44e Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 18:30:04 +0900 Subject: [PATCH 17/22] excercise some accessors --- jwk/ecdsa_gen.go | 38 +++++++++++++++----------- jwk/ecdsa_test.go | 44 +++++++++++++++++++++++++++--- jwk/internal/cmd/genheader/main.go | 2 +- jwk/rsa_gen.go | 38 +++++++++++++++----------- jwk/symmetric_gen.go | 21 ++++++++------ 5 files changed, 97 insertions(+), 46 deletions(-) diff --git a/jwk/ecdsa_gen.go b/jwk/ecdsa_gen.go index 5ccb7ed5f..a1646bf8e 100644 --- a/jwk/ecdsa_gen.go +++ b/jwk/ecdsa_gen.go @@ -40,9 +40,9 @@ type ecdsaPrivateKey struct { algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 crv *jwa.EllipticCurveAlgorithm d []byte - keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 - keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 - keyops KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 + keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 + keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 + keyops *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 x []byte x509CertChain *CertificateChain // https://tools.ietf.org/html/rfc7515#section-4.1.6 x509CertThumbprint *string // https://tools.ietf.org/html/rfc7515#section-4.1.7 @@ -59,7 +59,7 @@ type ecdsaPrivateKeyMarshalProxy struct { Xd *string `json:"d,omitempty"` XkeyID *string `json:"kid,omitempty"` XkeyUsage *string `json:"use,omitempty"` - Xkeyops KeyOperationList `json:"key_ops,omitempty"` + Xkeyops *KeyOperationList `json:"key_ops,omitempty"` Xx *string `json:"x,omitempty"` Xx509CertChain *CertificateChain `json:"x5c,omitempty"` Xx509CertThumbprint *string `json:"x5t,omitempty"` @@ -105,7 +105,10 @@ func (h *ecdsaPrivateKey) KeyUsage() string { } func (h *ecdsaPrivateKey) KeyOps() KeyOperationList { - return h.keyops + if h.keyops != nil { + return *(h.keyops) + } + return nil } func (h *ecdsaPrivateKey) X() []byte { @@ -165,7 +168,7 @@ func (h *ecdsaPrivateKey) iterate(ctx context.Context, ch chan *HeaderPair) { pairs = append(pairs, &HeaderPair{Key: KeyUsageKey, Value: *(h.keyUsage)}) } if h.keyops != nil { - pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: h.keyops}) + pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: *(h.keyops)}) } if h.x != nil { pairs = append(pairs, &HeaderPair{Key: ECDSAXKey, Value: h.x}) @@ -234,7 +237,7 @@ func (h *ecdsaPrivateKey) Get(name string) (interface{}, bool) { if h.keyops == nil { return nil, false } - return h.keyops, true + return *(h.keyops), true case ECDSAXKey: if h.x == nil { return nil, false @@ -315,7 +318,7 @@ func (h *ecdsaPrivateKey) Set(name string, value interface{}) error { if err := acceptor.Accept(value); err != nil { return errors.Wrapf(err, `invalid value for %s key`, KeyOpsKey) } - h.keyops = acceptor + h.keyops = &acceptor return nil case ECDSAXKey: if v, ok := value.([]byte); ok { @@ -507,9 +510,9 @@ type ECDSAPublicKey interface { type ecdsaPublicKey struct { algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 crv *jwa.EllipticCurveAlgorithm - keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 - keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 - keyops KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 + keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 + keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 + keyops *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 x []byte x509CertChain *CertificateChain // https://tools.ietf.org/html/rfc7515#section-4.1.6 x509CertThumbprint *string // https://tools.ietf.org/html/rfc7515#section-4.1.7 @@ -525,7 +528,7 @@ type ecdsaPublicKeyMarshalProxy struct { Xcrv *jwa.EllipticCurveAlgorithm `json:"crv,omitempty"` XkeyID *string `json:"kid,omitempty"` XkeyUsage *string `json:"use,omitempty"` - Xkeyops KeyOperationList `json:"key_ops,omitempty"` + Xkeyops *KeyOperationList `json:"key_ops,omitempty"` Xx *string `json:"x,omitempty"` Xx509CertChain *CertificateChain `json:"x5c,omitempty"` Xx509CertThumbprint *string `json:"x5t,omitempty"` @@ -567,7 +570,10 @@ func (h *ecdsaPublicKey) KeyUsage() string { } func (h *ecdsaPublicKey) KeyOps() KeyOperationList { - return h.keyops + if h.keyops != nil { + return *(h.keyops) + } + return nil } func (h *ecdsaPublicKey) X() []byte { @@ -624,7 +630,7 @@ func (h *ecdsaPublicKey) iterate(ctx context.Context, ch chan *HeaderPair) { pairs = append(pairs, &HeaderPair{Key: KeyUsageKey, Value: *(h.keyUsage)}) } if h.keyops != nil { - pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: h.keyops}) + pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: *(h.keyops)}) } if h.x != nil { pairs = append(pairs, &HeaderPair{Key: ECDSAXKey, Value: h.x}) @@ -688,7 +694,7 @@ func (h *ecdsaPublicKey) Get(name string) (interface{}, bool) { if h.keyops == nil { return nil, false } - return h.keyops, true + return *(h.keyops), true case ECDSAXKey: if h.x == nil { return nil, false @@ -763,7 +769,7 @@ func (h *ecdsaPublicKey) Set(name string, value interface{}) error { if err := acceptor.Accept(value); err != nil { return errors.Wrapf(err, `invalid value for %s key`, KeyOpsKey) } - h.keyops = acceptor + h.keyops = &acceptor return nil case ECDSAXKey: if v, ok := value.([]byte); ok { diff --git a/jwk/ecdsa_test.go b/jwk/ecdsa_test.go index 27bbb3fc6..e7475e3b7 100644 --- a/jwk/ecdsa_test.go +++ b/jwk/ecdsa_test.go @@ -35,12 +35,32 @@ func TestECDSA(t *testing.T) { return } - if _, ok := set.Keys[0].(jwk.ECDSAPrivateKey); !assert.True(t, ok, `should be jwk.ECDSAPrivateKey`) { + privKey, ok := set.Keys[0].(jwk.ECDSAPrivateKey) + if !assert.True(t, ok, `should be jwk.ECDSAPrivateKey`) { + return + } + + if !assert.Empty(t, privKey.KeyUsage(), `KeyUsage() should be empty`) { + return + } + + if !assert.NotEmpty(t, privKey.KeyOps(), `KeyOps() should be non-empty`) { + return + } + + if !assert.NotEmpty(t, privKey.X(), `X() should be non-empty`) { + return + } + + if !assert.NotEmpty(t, privKey.Y(), `Y() should be non-empty`) { + return + } + + if !assert.NotEmpty(t, privKey.D(), `D() should be non-empty`) { return } var rawPrivKey ecdsa.PrivateKey - privKey := set.Keys[0].(jwk.ECDSAPrivateKey) if !assert.NoError(t, privKey.Raw(&rawPrivKey), "Raw should succeed") { return } @@ -49,13 +69,29 @@ func TestECDSA(t *testing.T) { return } - pubkey, err := privKey.PublicKey() + pubKey, err := privKey.PublicKey() if !assert.NoError(t, err, "Should be able to get ECDSA public key") { return } + if !assert.Empty(t, pubKey.KeyUsage(), `KeyUsage() should be empty`) { + return + } + + if !assert.Empty(t, pubKey.KeyOps(), `KeyOps() should be empty`) { + return + } + + if !assert.NotEmpty(t, pubKey.X(), `X() should be non-empty`) { + return + } + + if !assert.NotEmpty(t, pubKey.Y(), `Y() should be non-empty`) { + return + } + var rawPubKey ecdsa.PublicKey - if !assert.NoError(t, pubkey.Raw(&rawPubKey), "Raw should succeed") { + if !assert.NoError(t, pubKey.Raw(&rawPubKey), "Raw should succeed") { return } diff --git a/jwk/internal/cmd/genheader/main.go b/jwk/internal/cmd/genheader/main.go index 63ee6e665..6a1f85567 100644 --- a/jwk/internal/cmd/genheader/main.go +++ b/jwk/internal/cmd/genheader/main.go @@ -100,7 +100,7 @@ func fieldStorageType(s string) string { } func fieldStorageTypeIsIndirect(s string) bool { - return !(strings.HasPrefix(s, `*`) || strings.HasPrefix(s, `[]`) || strings.HasSuffix(s, `List`)) + return s == "KeyOperationList" || !(strings.HasPrefix(s, `*`) || strings.HasPrefix(s, `[]`) || strings.HasSuffix(s, `List`)) } var standardHeaders []headerField diff --git a/jwk/rsa_gen.go b/jwk/rsa_gen.go index 0a7a16f6e..7c91840e6 100644 --- a/jwk/rsa_gen.go +++ b/jwk/rsa_gen.go @@ -50,9 +50,9 @@ type rsaPrivateKey struct { dp []byte dq []byte e []byte - keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 - keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 - keyops KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 + keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 + keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 + keyops *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 n []byte p []byte q []byte @@ -73,7 +73,7 @@ type rsaPrivateKeyMarshalProxy struct { Xe *string `json:"e,omitempty"` XkeyID *string `json:"kid,omitempty"` XkeyUsage *string `json:"use,omitempty"` - Xkeyops KeyOperationList `json:"key_ops,omitempty"` + Xkeyops *KeyOperationList `json:"key_ops,omitempty"` Xn *string `json:"n,omitempty"` Xp *string `json:"p,omitempty"` Xq *string `json:"q,omitempty"` @@ -126,7 +126,10 @@ func (h *rsaPrivateKey) KeyUsage() string { } func (h *rsaPrivateKey) KeyOps() KeyOperationList { - return h.keyops + if h.keyops != nil { + return *(h.keyops) + } + return nil } func (h *rsaPrivateKey) N() []byte { @@ -200,7 +203,7 @@ func (h *rsaPrivateKey) iterate(ctx context.Context, ch chan *HeaderPair) { pairs = append(pairs, &HeaderPair{Key: KeyUsageKey, Value: *(h.keyUsage)}) } if h.keyops != nil { - pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: h.keyops}) + pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: *(h.keyops)}) } if h.n != nil { pairs = append(pairs, &HeaderPair{Key: RSANKey, Value: h.n}) @@ -285,7 +288,7 @@ func (h *rsaPrivateKey) Get(name string) (interface{}, bool) { if h.keyops == nil { return nil, false } - return h.keyops, true + return *(h.keyops), true case RSANKey: if h.n == nil { return nil, false @@ -388,7 +391,7 @@ func (h *rsaPrivateKey) Set(name string, value interface{}) error { if err := acceptor.Accept(value); err != nil { return errors.Wrapf(err, `invalid value for %s key`, KeyOpsKey) } - h.keyops = acceptor + h.keyops = &acceptor return nil case RSANKey: if v, ok := value.([]byte); ok { @@ -654,9 +657,9 @@ type RSAPublicKey interface { type rsaPublicKey struct { algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 e []byte - keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 - keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 - keyops KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 + keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 + keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 + keyops *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 n []byte x509CertChain *CertificateChain // https://tools.ietf.org/html/rfc7515#section-4.1.6 x509CertThumbprint *string // https://tools.ietf.org/html/rfc7515#section-4.1.7 @@ -671,7 +674,7 @@ type rsaPublicKeyMarshalProxy struct { Xe *string `json:"e,omitempty"` XkeyID *string `json:"kid,omitempty"` XkeyUsage *string `json:"use,omitempty"` - Xkeyops KeyOperationList `json:"key_ops,omitempty"` + Xkeyops *KeyOperationList `json:"key_ops,omitempty"` Xn *string `json:"n,omitempty"` Xx509CertChain *CertificateChain `json:"x5c,omitempty"` Xx509CertThumbprint *string `json:"x5t,omitempty"` @@ -709,7 +712,10 @@ func (h *rsaPublicKey) KeyUsage() string { } func (h *rsaPublicKey) KeyOps() KeyOperationList { - return h.keyops + if h.keyops != nil { + return *(h.keyops) + } + return nil } func (h *rsaPublicKey) N() []byte { @@ -762,7 +768,7 @@ func (h *rsaPublicKey) iterate(ctx context.Context, ch chan *HeaderPair) { pairs = append(pairs, &HeaderPair{Key: KeyUsageKey, Value: *(h.keyUsage)}) } if h.keyops != nil { - pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: h.keyops}) + pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: *(h.keyops)}) } if h.n != nil { pairs = append(pairs, &HeaderPair{Key: RSANKey, Value: h.n}) @@ -823,7 +829,7 @@ func (h *rsaPublicKey) Get(name string) (interface{}, bool) { if h.keyops == nil { return nil, false } - return h.keyops, true + return *(h.keyops), true case RSANKey: if h.n == nil { return nil, false @@ -893,7 +899,7 @@ func (h *rsaPublicKey) Set(name string, value interface{}) error { if err := acceptor.Accept(value); err != nil { return errors.Wrapf(err, `invalid value for %s key`, KeyOpsKey) } - h.keyops = acceptor + h.keyops = &acceptor return nil case RSANKey: if v, ok := value.([]byte); ok { diff --git a/jwk/symmetric_gen.go b/jwk/symmetric_gen.go index ca210b050..665db9f51 100644 --- a/jwk/symmetric_gen.go +++ b/jwk/symmetric_gen.go @@ -29,10 +29,10 @@ type SymmetricKey interface { } type symmetricKey struct { - algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 - keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 - keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 - keyops KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 + algorithm *string // https://tools.ietf.org/html/rfc7517#section-4.4 + keyID *string // https://tools.ietf.org/html/rfc7515#section-4.1.4 + keyUsage *string // https://tools.ietf.org/html/rfc7517#section-4.2 + keyops *KeyOperationList // https://tools.ietf.org/html/rfc7517#section-4.3 octets []byte x509CertChain *CertificateChain // https://tools.ietf.org/html/rfc7515#section-4.1.6 x509CertThumbprint *string // https://tools.ietf.org/html/rfc7515#section-4.1.7 @@ -46,7 +46,7 @@ type symmetricSymmetricKeyMarshalProxy struct { Xalgorithm *string `json:"alg,omitempty"` XkeyID *string `json:"kid,omitempty"` XkeyUsage *string `json:"use,omitempty"` - Xkeyops KeyOperationList `json:"key_ops,omitempty"` + Xkeyops *KeyOperationList `json:"key_ops,omitempty"` Xoctets *string `json:"k,omitempty"` Xx509CertChain *CertificateChain `json:"x5c,omitempty"` Xx509CertThumbprint *string `json:"x5t,omitempty"` @@ -80,7 +80,10 @@ func (h *symmetricKey) KeyUsage() string { } func (h *symmetricKey) KeyOps() KeyOperationList { - return h.keyops + if h.keyops != nil { + return *(h.keyops) + } + return nil } func (h *symmetricKey) Octets() []byte { @@ -130,7 +133,7 @@ func (h *symmetricKey) iterate(ctx context.Context, ch chan *HeaderPair) { pairs = append(pairs, &HeaderPair{Key: KeyUsageKey, Value: *(h.keyUsage)}) } if h.keyops != nil { - pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: h.keyops}) + pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: *(h.keyops)}) } if h.octets != nil { pairs = append(pairs, &HeaderPair{Key: SymmetricOctetsKey, Value: h.octets}) @@ -186,7 +189,7 @@ func (h *symmetricKey) Get(name string) (interface{}, bool) { if h.keyops == nil { return nil, false } - return h.keyops, true + return *(h.keyops), true case SymmetricOctetsKey: if h.octets == nil { return nil, false @@ -250,7 +253,7 @@ func (h *symmetricKey) Set(name string, value interface{}) error { if err := acceptor.Accept(value); err != nil { return errors.Wrapf(err, `invalid value for %s key`, KeyOpsKey) } - h.keyops = acceptor + h.keyops = &acceptor return nil case SymmetricOctetsKey: if v, ok := value.([]byte); ok { From 470ab5fc4ccd5d24e576c2fe85ecf4472d205585 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 18:33:18 +0900 Subject: [PATCH 18/22] RSA content cipher is not supported as of this moment Remove, for the sake of the coverage game --- jwe/internal/cipher/cipher.go | 7 ------- jwe/internal/cipher/interface.go | 6 ------ 2 files changed, 13 deletions(-) diff --git a/jwe/internal/cipher/cipher.go b/jwe/internal/cipher/cipher.go index e2ccbf3ee..ef1bde49c 100644 --- a/jwe/internal/cipher/cipher.go +++ b/jwe/internal/cipher/cipher.go @@ -3,7 +3,6 @@ package cipher import ( "crypto/aes" "crypto/cipher" - "crypto/rsa" "fmt" "github.com/lestrrat-go/jwx/jwa" @@ -185,9 +184,3 @@ func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintex plaintext, err = aead.Open(nil, iv, combined, aad) return } - -func NewRsaContentCipher(alg jwa.ContentEncryptionAlgorithm, pubkey *rsa.PublicKey) (*RsaContentCipher, error) { - return &RsaContentCipher{ - pubkey: pubkey, - }, nil -} diff --git a/jwe/internal/cipher/interface.go b/jwe/internal/cipher/interface.go index ad819667a..8af8202d4 100644 --- a/jwe/internal/cipher/interface.go +++ b/jwe/internal/cipher/interface.go @@ -2,7 +2,6 @@ package cipher import ( "crypto/cipher" - "crypto/rsa" "github.com/lestrrat-go/jwx/jwe/internal/keygen" ) @@ -33,8 +32,3 @@ type AesContentCipher struct { keysize int tagsize int } - -// RsaContentCipher represents a cipher based on RSA -type RsaContentCipher struct { - pubkey *rsa.PublicKey -} From ba6d5516769561b0860240a5c4d7f1a5bee9ac08 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 19:51:58 +0900 Subject: [PATCH 19/22] Excercise jwe header accessors --- jwe/headers_test.go | 84 +++++++++++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 22 deletions(-) diff --git a/jwe/headers_test.go b/jwe/headers_test.go index c579da9ca..ecb07b8e0 100644 --- a/jwe/headers_test.go +++ b/jwe/headers_test.go @@ -2,8 +2,12 @@ package jwe_test import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "testing" + "github.com/lestrrat-go/jwx/buffer" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwe" "github.com/lestrrat-go/jwx/jwk" @@ -11,34 +15,62 @@ import ( ) func TestHeaders(t *testing.T) { - t.Run("Set/Get", func(t *testing.T) { - h := jwe.NewHeaders() + rawKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + if !assert.NoError(t, err, `ecdsa.GenerateKey should succeed`) { + return + } + privKey, err := jwk.New(rawKey) + if !assert.NoError(t, err, `jwk.New should succeed`) { + return + } + + pubKey, err := jwk.New(rawKey.PublicKey) + if !assert.NoError(t, err, `jwk.PublicKey should succeed`) { + return + } + + data := []struct { + Key string + Value interface{} + Expected interface{} + }{ + {Key: jwe.AgreementPartyUInfoKey, Value: []byte("apu foobarbaz"), Expected: buffer.Buffer("apu foobarbaz")}, + {Key: jwe.AgreementPartyVInfoKey, Value: []byte("apv foobarbaz"), Expected: buffer.Buffer("apv foobarbaz")}, + {Key: jwe.CompressionKey, Value: jwa.Deflate}, + {Key: jwe.ContentEncryptionKey, Value: jwa.A128GCM}, + {Key: jwe.ContentTypeKey, Value: "application/json"}, + {Key: jwe.CriticalKey, Value: []string{"crit blah"}}, + {Key: jwe.EphemeralPublicKeyKey, Value: pubKey}, + {Key: jwe.JWKKey, Value: privKey}, + {Key: jwe.JWKSetURLKey, Value: "http://github.com/lestrrat-go/jwx"}, + {Key: jwe.KeyIDKey, Value: "kid blah"}, + {Key: jwe.TypeKey, Value: "typ blah"}, + {Key: jwe.X509CertThumbprintKey, Value: "x5t blah"}, + {Key: jwe.X509CertThumbprintS256Key, Value: "x5t#256 blah"}, + {Key: jwe.X509URLKey, Value: "http://github.com/lestrrat-go/jwx"}, + } - data := map[string]struct { - Value interface{} - Expected interface{} - }{ - "kid": {Value: "kid blah"}, - "enc": {Value: jwa.A128GCM}, - "cty": {Value: "application/json"}, - "typ": {Value: "typ blah"}, - "x5t": {Value: "x5t blah"}, - "x5t#256": {Value: "x5t#256 blah"}, - "crit": {Value: []string{"crit blah"}}, - "jku": {Value: "http://github.com/lestrrat-go/jwx"}, - "x5u": {Value: "http://github.com/lestrrat-go/jwx"}, + base := jwe.NewHeaders() + + t.Run("Set values", func(t *testing.T) { + for _, tc := range data { + if !assert.NoError(t, base.Set(tc.Key, tc.Value), "Headers.Set should succeed") { + return + } } + }) - for name, testcase := range data { - h.Set(name, testcase.Value) - got, ok := h.Get(name) - if !assert.True(t, ok, "value should exist") { + t.Run("Set/Get", func(t *testing.T) { + h := base + for _, tc := range data { + got, ok := h.Get(tc.Key) + if !assert.True(t, ok, "value for %s should exist", tc.Key) { return } - expected := testcase.Expected + expected := tc.Expected if expected == nil { - expected = testcase.Value + expected = tc.Value } if !assert.Equal(t, expected, got, "value should match") { return @@ -67,7 +99,15 @@ func TestHeaders(t *testing.T) { t.Run("Iterator", func(t *testing.T) { expected := map[string]interface{}{} - v := jwe.NewHeaders() + for _, tc := range data { + v := tc.Value + if expected := tc.Expected; expected != nil { + v = expected + } + expected[tc.Key] = v + } + + v := base t.Run("Iterate", func(t *testing.T) { seen := make(map[string]interface{}) for iter := v.Iterate(context.TODO()); iter.Next(context.TODO()); { From 2b1f848b970e29d90400d9e4500a8d3ede70f9b1 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 20:49:07 +0900 Subject: [PATCH 20/22] Excercise a bit more on the jwe.Message Some signatures have been changed, but I doubt that anybody was actually using jwe.Message directly --- Changes | 1 + jwe/encrypt.go | 16 +++--- jwe/interface.go | 8 +-- jwe/jwe.go | 14 ++--- jwe/lowlevel_test.go | 12 ++--- jwe/message.go | 120 ++++++++++++++++++++++++++++++++++--------- 6 files changed, 125 insertions(+), 46 deletions(-) diff --git a/Changes b/Changes index 9fb052a8e..c3e769f86 100644 --- a/Changes +++ b/Changes @@ -4,6 +4,7 @@ Changes v1.0.2 * Add jwk.AssignKeyID to automatically assign a `kid` field to a JWK * Fix jwe.Encrypt / jwe.Decrypt to properly look at the `zip` field + * Change jwe.Message accessors to return []byte, not buffer.Buffer v1.0.1 - 04 May 2020 * Normalize all JWK serialization to use padding-less base64 encoding (#185) diff --git a/jwe/encrypt.go b/jwe/encrypt.go index f52c4e2db..21f2e19e1 100644 --- a/jwe/encrypt.go +++ b/jwe/encrypt.go @@ -4,6 +4,7 @@ import ( "context" "sync" + "github.com/lestrrat-go/jwx/buffer" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/pdebug" "github.com/pkg/errors" @@ -127,14 +128,17 @@ func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) { } msg := NewMessage() - if err := msg.authenticatedData.Base64Decode(aad); err != nil { + + decodedAad, err := buffer.FromBase64(aad) + if err != nil { return nil, errors.Wrap(err, "failed to decode base64") } - msg.cipherText = ciphertext - msg.initializationVector = iv - msg.protectedHeaders = protected - msg.recipients = recipients - msg.tag = tag + msg.Set(AuthenticatedDataKey, decodedAad.Bytes()) + msg.Set(CipherTextKey, ciphertext) + msg.Set(InitializationVectorKey, iv) + msg.Set(ProtectedHeadersKey, protected) + msg.Set(RecipientsKey, recipients) + msg.Set(TagKey, tag) return msg, nil } diff --git a/jwe/interface.go b/jwe/interface.go index 15616f9e3..ed1f3384c 100644 --- a/jwe/interface.go +++ b/jwe/interface.go @@ -29,12 +29,12 @@ type stdRecipient struct { // Message contains the entire encrypted JWE message type Message struct { - authenticatedData buffer.Buffer - cipherText buffer.Buffer - initializationVector buffer.Buffer + authenticatedData *buffer.Buffer + cipherText *buffer.Buffer + initializationVector *buffer.Buffer protectedHeaders Headers recipients []Recipient - tag buffer.Buffer + tag *buffer.Buffer unprotectedHeaders Headers } diff --git a/jwe/jwe.go b/jwe/jwe.go index 728d617a0..bf5a6713b 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -213,17 +213,17 @@ func parseCompact(buf []byte) (*Message, error) { } m := NewMessage() - m.authenticatedData.SetBytes(hdrbuf.Bytes()) - m.protectedHeaders = protected - m.tag = tagbuf - m.cipherText = ctbuf - m.initializationVector = ivbuf - m.recipients = []Recipient{ + m.Set(AuthenticatedDataKey, hdrbuf.Bytes()) + m.Set(CipherTextKey, ctbuf) + m.Set(InitializationVectorKey, ivbuf) + m.Set(ProtectedHeadersKey, protected) + m.Set(RecipientsKey, []Recipient{ &stdRecipient{ headers: hdr, encryptedKey: enckeybuf, }, - } + }) + m.Set(TagKey, tagbuf) return m, nil } diff --git a/jwe/lowlevel_test.go b/jwe/lowlevel_test.go index a47e1036d..abbf2f29b 100644 --- a/jwe/lowlevel_test.go +++ b/jwe/lowlevel_test.go @@ -103,12 +103,12 @@ func TestLowLevelParts_A128KW_A128CBCHS256(t *testing.T) { protected.Set(ContentEncryptionKey, jwa.A128CBC_HS256) msg := NewMessage() - msg.protectedHeaders = protected - msg.authenticatedData = aad - msg.cipherText = ciphertext - msg.initializationVector = iv - msg.tag = tag - msg.recipients = []Recipient{r} + msg.Set(ProtectedHeadersKey, protected) + msg.Set(AuthenticatedDataKey, aad) + msg.Set(CipherTextKey, ciphertext) + msg.Set(InitializationVectorKey, iv) + msg.Set(TagKey, tag) + msg.Set(RecipientsKey, []Recipient{r}) serialized, err := Compact(msg) if !assert.NoError(t, err, "compact serialization is successful") { diff --git a/jwe/message.go b/jwe/message.go index 93fcb2e3f..118fb0d6b 100644 --- a/jwe/message.go +++ b/jwe/message.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/lestrrat-go/jwx/buffer" + "github.com/lestrrat-go/jwx/internal/base64" "github.com/lestrrat-go/jwx/jwa" "github.com/lestrrat-go/jwx/jwe/internal/cipher" "github.com/lestrrat-go/pdebug" @@ -91,16 +92,32 @@ func NewMessage() *Message { return &Message{} } -func (m *Message) AuthenticatedData() buffer.Buffer { - return m.authenticatedData +func (m *Message) AuthenticatedData() []byte { + if m.authenticatedData == nil { + return nil + } + return m.authenticatedData.Bytes() } -func (m *Message) CipherText() buffer.Buffer { - return m.cipherText +func (m *Message) CipherText() []byte { + if m.cipherText == nil { + return nil + } + return m.cipherText.Bytes() } -func (m *Message) InitializationVector() buffer.Buffer { - return m.initializationVector +func (m *Message) InitializationVector() []byte { + if m.initializationVector == nil { + return nil + } + return m.initializationVector.Bytes() +} + +func (m *Message) Tag() []byte { + if m.tag == nil { + return nil + } + return m.tag.Bytes() } func (m *Message) ProtectedHeaders() Headers { @@ -125,13 +142,70 @@ const ( UnprotectedHeadersKey = "unprotected" ) +func (m *Message) Set(k string, v interface{}) error { + switch k { + case AuthenticatedDataKey: + var acceptor buffer.Buffer + if err := acceptor.Accept(v); err != nil { + return errors.Wrapf(err, `invalid value %T for %s key`, v, AuthenticatedDataKey) + } + m.authenticatedData = &acceptor + return nil + case CipherTextKey: + var acceptor buffer.Buffer + if err := acceptor.Accept(v); err != nil { + return errors.Wrapf(err, `invalid value %T for %s key`, v, CipherTextKey) + } + m.cipherText = &acceptor + return nil + case InitializationVectorKey: + var acceptor buffer.Buffer + if err := acceptor.Accept(v); err != nil { + return errors.Wrapf(err, `invalid value %T for %s key`, v, InitializationVectorKey) + } + m.initializationVector = &acceptor + return nil + case ProtectedHeadersKey: + cv, ok := v.(Headers) + if !ok { + return errors.Errorf(`invalid value %T for %s key`, v, ProtectedHeadersKey) + } + m.protectedHeaders = cv + case RecipientsKey: + cv, ok := v.([]Recipient) + if !ok { + return errors.Errorf(`invalid value %T for %s key`, v, RecipientsKey) + } + m.recipients = cv + case TagKey: + var acceptor buffer.Buffer + if err := acceptor.Accept(v); err != nil { + return errors.Wrapf(err, `invalid value %T for %s key`, v, TagKey) + } + m.tag = &acceptor + return nil + case UnprotectedHeadersKey: + cv, ok := v.(Headers) + if !ok { + return errors.Errorf(`invalid value %T for %s key`, v, UnprotectedHeadersKey) + } + m.unprotectedHeaders = cv + default: + if m.unprotectedHeaders == nil { + m.unprotectedHeaders = NewHeaders() + } + return m.unprotectedHeaders.Set(k, v) + } + return errors.New(`unreached`) +} + type messageMarshalProxy struct { - AuthenticatedData buffer.Buffer `json:"aad,omitempty"` - CipherText buffer.Buffer `json:"ciphertext"` - InitializationVector buffer.Buffer `json:"iv,omitempty"` + AuthenticatedData *buffer.Buffer `json:"aad,omitempty"` + CipherText *buffer.Buffer `json:"ciphertext"` + InitializationVector *buffer.Buffer `json:"iv,omitempty"` ProtectedHeaders json.RawMessage `json:"protected"` Recipients []json.RawMessage `json:"recipients"` - Tag buffer.Buffer `json:"tag,omitempty"` + Tag *buffer.Buffer `json:"tag,omitempty"` UnprotectedHeaders Headers `json:"unprotected,omitempty"` } @@ -143,37 +217,37 @@ func (m *Message) MarshalJSON() ([]byte, error) { fmt.Fprintf(&buf, `{`) var wrote bool - if m.authenticatedData.Len() > 0 { + if aad := m.AuthenticatedData(); len(aad) > 0 { wrote = true fmt.Fprintf(&buf, `%#v:`, AuthenticatedDataKey) - if err := enc.Encode(m.authenticatedData); err != nil { + if err := enc.Encode(base64.EncodeToString(aad)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, AuthenticatedDataKey) } } - if m.cipherText.Len() > 0 { + if cipherText := m.CipherText(); len(cipherText) > 0 { if wrote { fmt.Fprintf(&buf, `,`) } wrote = true fmt.Fprintf(&buf, `%#v:`, CipherTextKey) - if err := enc.Encode(m.cipherText); err != nil { + if err := enc.Encode(base64.EncodeToString(cipherText)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, CipherTextKey) } } - if m.initializationVector.Len() > 0 { + if iv := m.InitializationVector(); len(iv) > 0 { if wrote { fmt.Fprintf(&buf, `,`) } wrote = true fmt.Fprintf(&buf, `%#v:`, InitializationVectorKey) - if err := enc.Encode(m.initializationVector); err != nil { + if err := enc.Encode(base64.EncodeToString(iv)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, InitializationVectorKey) } } - if m.protectedHeaders != nil { - encodedHeaders, err := m.protectedHeaders.Encode() + if h := m.ProtectedHeaders(); h != nil { + encodedHeaders, err := h.Encode() if err != nil { return nil, errors.Wrap(err, `failed to encode protected headers`) } @@ -191,19 +265,19 @@ func (m *Message) MarshalJSON() ([]byte, error) { fmt.Fprintf(&buf, `,`) } fmt.Fprintf(&buf, `%#v:`, RecipientsKey) - if err := enc.Encode(m.recipients); err != nil { + if err := enc.Encode(m.Recipients()); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, RecipientsKey) } - if m.tag.Len() > 0 { + if tag := m.Tag(); len(tag) > 0 { fmt.Fprintf(&buf, `,%#v:`, TagKey) - if err := enc.Encode(m.tag); err != nil { + if err := enc.Encode(base64.EncodeToString(tag)); err != nil { return nil, errors.Wrapf(err, `failed to encode %s field`, TagKey) } } - if m.unprotectedHeaders != nil { - unprotected, err := json.Marshal(m.unprotectedHeaders) + if h := m.UnprotectedHeaders(); h != nil { + unprotected, err := json.Marshal(h) if err != nil { return nil, errors.Wrap(err, `failed to encode unprotected headers`) } From 0847df580937bea80dd3032d6b99723b97d337a9 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 20:56:14 +0900 Subject: [PATCH 21/22] appease golangci-lint I'm sure there's more of the same --- jwe/encrypt.go | 24 ++++++++++++++++++------ jwe/message.go | 6 +----- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/jwe/encrypt.go b/jwe/encrypt.go index 21f2e19e1..ee5b30a96 100644 --- a/jwe/encrypt.go +++ b/jwe/encrypt.go @@ -133,12 +133,24 @@ func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) { if err != nil { return nil, errors.Wrap(err, "failed to decode base64") } - msg.Set(AuthenticatedDataKey, decodedAad.Bytes()) - msg.Set(CipherTextKey, ciphertext) - msg.Set(InitializationVectorKey, iv) - msg.Set(ProtectedHeadersKey, protected) - msg.Set(RecipientsKey, recipients) - msg.Set(TagKey, tag) + if err := msg.Set(AuthenticatedDataKey, decodedAad.Bytes()); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, AuthenticatedDataKey) + } + if err := msg.Set(CipherTextKey, ciphertext); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey) + } + if err := msg.Set(InitializationVectorKey, iv); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey) + } + if err := msg.Set(ProtectedHeadersKey, protected); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey) + } + if err := msg.Set(RecipientsKey, recipients); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, RecipientsKey) + } + if err := msg.Set(TagKey, tag); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, TagKey) + } return msg, nil } diff --git a/jwe/message.go b/jwe/message.go index 118fb0d6b..dfcd249d5 100644 --- a/jwe/message.go +++ b/jwe/message.go @@ -150,21 +150,18 @@ func (m *Message) Set(k string, v interface{}) error { return errors.Wrapf(err, `invalid value %T for %s key`, v, AuthenticatedDataKey) } m.authenticatedData = &acceptor - return nil case CipherTextKey: var acceptor buffer.Buffer if err := acceptor.Accept(v); err != nil { return errors.Wrapf(err, `invalid value %T for %s key`, v, CipherTextKey) } m.cipherText = &acceptor - return nil case InitializationVectorKey: var acceptor buffer.Buffer if err := acceptor.Accept(v); err != nil { return errors.Wrapf(err, `invalid value %T for %s key`, v, InitializationVectorKey) } m.initializationVector = &acceptor - return nil case ProtectedHeadersKey: cv, ok := v.(Headers) if !ok { @@ -183,7 +180,6 @@ func (m *Message) Set(k string, v interface{}) error { return errors.Wrapf(err, `invalid value %T for %s key`, v, TagKey) } m.tag = &acceptor - return nil case UnprotectedHeadersKey: cv, ok := v.(Headers) if !ok { @@ -196,7 +192,7 @@ func (m *Message) Set(k string, v interface{}) error { } return m.unprotectedHeaders.Set(k, v) } - return errors.New(`unreached`) + return nil } type messageMarshalProxy struct { From 35a33c4e848c4c6095011eb180c2a817607a57d3 Mon Sep 17 00:00:00 2001 From: Daisuke Maki Date: Thu, 7 May 2020 21:01:34 +0900 Subject: [PATCH 22/22] appease golangci-lint --- jwe/jwe.go | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/jwe/jwe.go b/jwe/jwe.go index bf5a6713b..c821d9644 100644 --- a/jwe/jwe.go +++ b/jwe/jwe.go @@ -213,17 +213,30 @@ func parseCompact(buf []byte) (*Message, error) { } m := NewMessage() - m.Set(AuthenticatedDataKey, hdrbuf.Bytes()) - m.Set(CipherTextKey, ctbuf) - m.Set(InitializationVectorKey, ivbuf) - m.Set(ProtectedHeadersKey, protected) - m.Set(RecipientsKey, []Recipient{ + if err := m.Set(AuthenticatedDataKey, hdrbuf.Bytes()); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, AuthenticatedDataKey) + } + if err := m.Set(CipherTextKey, ctbuf); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey) + } + if err := m.Set(InitializationVectorKey, ivbuf); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey) + } + if err := m.Set(ProtectedHeadersKey, protected); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey) + } + + if err := m.Set(RecipientsKey, []Recipient{ &stdRecipient{ headers: hdr, encryptedKey: enckeybuf, }, - }) - m.Set(TagKey, tagbuf) + }); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, RecipientsKey) + } + if err := m.Set(TagKey, tagbuf); err != nil { + return nil, errors.Wrapf(err, `failed to set %s`, TagKey) + } return m, nil }