Skip to content

Commit

Permalink
fix: migrate to a different jwt library
Browse files Browse the repository at this point in the history
  • Loading branch information
shaj13 committed Feb 14, 2021
1 parent 8642557 commit e678879
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 126 deletions.
25 changes: 11 additions & 14 deletions auth/strategies/jwt/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,12 @@ import (
"net/http"
"time"

"github.com/shaj13/go-guardian/v2/auth/strategies/token"

"github.com/shaj13/go-guardian/v2/auth"
"github.com/shaj13/go-guardian/v2/auth/strategies/jwt"
"github.com/shaj13/go-guardian/v2/auth/strategies/token"

gojwt "github.com/dgrijalva/jwt-go/v4"
"github.com/shaj13/libcache"
_ "github.com/shaj13/libcache/lru"

"github.com/shaj13/go-guardian/v2/auth/strategies/jwt"
)

type RotatedSecrets struct {
Expand All @@ -32,21 +29,21 @@ func (r RotatedSecrets) KID() string {
return r.LatestID
}

func (r RotatedSecrets) Get(kid string) (key interface{}, m gojwt.SigningMethod, err error) {
func (r RotatedSecrets) Get(kid string) (key interface{}, alg string, err error) {
s, ok := r.Secrtes[kid]
if ok {
return s, gojwt.SigningMethodHS256, nil
return s, jwt.HS256, nil
}
return nil, nil, fmt.Errorf("Invalid KID %s", kid)
return nil, "", fmt.Errorf("Invalid KID %s", kid)
}

func Example() {
u := auth.NewUserInfo("example", "example", nil, nil)
c := libcache.LRU.New(0)
s := jwt.StaticSecret{
ID: "id",
Method: gojwt.SigningMethodHS256,
Secret: []byte("your secret"),
ID: "id",
Algorithm: jwt.HS256,
Secret: []byte("your secret"),
}

token, err := jwt.IssueAccessToken(u, s)
Expand All @@ -71,9 +68,9 @@ func Example_scope() {
u := auth.NewUserInfo("example", "example", nil, nil)
c := libcache.LRU.New(0)
s := jwt.StaticSecret{
ID: "id",
Method: gojwt.SigningMethodHS256,
Secret: []byte("your secret"),
ID: "id",
Algorithm: jwt.HS256,
Secret: []byte("your secret"),
}

token, err := jwt.IssueAccessToken(u, s, ns)
Expand Down
2 changes: 1 addition & 1 deletion auth/strategies/jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func GetAuthenticateFunc(s SecretsKeeper, opts ...auth.Option) token.Authenticat
if len(c.Scopes) > 0 {
token.WithNamedScopes(c.UserInfo, c.Scopes...)
}
return c.UserInfo, c.ExpiresAt.Time, err
return c.UserInfo, c.Expiry.Time(), err
}
}

Expand Down
8 changes: 3 additions & 5 deletions auth/strategies/jwt/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package jwt
import (
"time"

"github.com/dgrijalva/jwt-go/v4"

"github.com/shaj13/go-guardian/v2/auth"
)

Expand All @@ -13,13 +11,13 @@ import (
func SetAudience(aud string) auth.Option {
return auth.OptionFunc(func(v interface{}) {
if t, ok := v.(*accessToken); ok {
t.aud = jwt.ClaimStrings{aud}
t.aud = aud
}
})
}

// SetIssuer sets token issuer(iss),
// Default Value "go-guardian".
// no default value.
func SetIssuer(iss string) auth.Option {
return auth.OptionFunc(func(v interface{}) {
if t, ok := v.(*accessToken); ok {
Expand All @@ -33,7 +31,7 @@ func SetIssuer(iss string) auth.Option {
func SetExpDuration(d time.Duration) auth.Option {
return auth.OptionFunc(func(v interface{}) {
if t, ok := v.(*accessToken); ok {
t.d = d
t.dur = d
}
})
}
Expand Down
7 changes: 4 additions & 3 deletions auth/strategies/jwt/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import (
)

func TestSetAudience(t *testing.T) {
opt := SetAudience("test")
aud := "test"
opt := SetAudience(aud)
tk := newAccessToken(nil, opt)
assert.Equal(t, "test", tk.aud[0])
assert.Equal(t, aud, tk.aud)
}

func TestSetIssuer(t *testing.T) {
Expand All @@ -22,5 +23,5 @@ func TestSetIssuer(t *testing.T) {
func TestSetExpDuration(t *testing.T) {
opt := SetExpDuration(time.Hour)
tk := newAccessToken(nil, opt)
assert.Equal(t, time.Hour, tk.d)
assert.Equal(t, time.Hour, tk.dur)
}
20 changes: 9 additions & 11 deletions auth/strategies/jwt/secrets_keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package jwt

import (
"errors"

"github.com/dgrijalva/jwt-go/v4"
)

// SecretsKeeper hold all secrets/keys to sign and parse JWT token
Expand All @@ -12,28 +10,28 @@ type SecretsKeeper interface {
// KID must return the most recently used id if more than one secret/key exists.
// https://tools.ietf.org/html/rfc7515#section-4.1.4
KID() string
// Get return's secret/key and the corresponding sign method.
Get(kid string) (key interface{}, m jwt.SigningMethod, err error)
// Get return's secret/key and the corresponding sign algorithm.
Get(kid string) (key interface{}, algorithm string, err error)
}

// StaticSecret implements the SecretsKeeper and holds only a single secret.
type StaticSecret struct {
Secret interface{}
ID string
Method jwt.SigningMethod
Secret interface{}
ID string
Algorithm string
}

// KID return's secret/key id.
func (s StaticSecret) KID() string {
return s.ID
}

// Get return's secret/key and the corresponding sign method.
func (s StaticSecret) Get(kid string) (key interface{}, m jwt.SigningMethod, err error) {
// Get return's secret/key and the corresponding sign algorithm.
func (s StaticSecret) Get(kid string) (key interface{}, algorithm string, err error) {
if kid != s.ID {
msg := "strategies/jwt: Invalid " + kid + " KID"
return nil, nil, errors.New(msg)
return nil, "", errors.New(msg)
}

return s.Secret, s.Method, nil
return s.Secret, s.Algorithm, nil
}
16 changes: 7 additions & 9 deletions auth/strategies/jwt/secrets_keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,31 @@ package jwt
import (
"testing"

"github.com/dgrijalva/jwt-go/v4"
"github.com/stretchr/testify/assert"
)

func TestStaticSecretGet(t *testing.T) {
t.Run("StaticSecretGet always return same secret", func(t *testing.T) {
method := jwt.SigningMethodHS256
kid := "test-kid"
secret := []byte("test-secret")
s := StaticSecret{
ID: kid,
Method: method,
Secret: secret,
ID: kid,
Algorithm: HS256,
Secret: secret,
}
for i := 0; i < 10; i++ {
gotSecret, gotMethod, err := s.Get(kid)
gotSecret, gotAlg, err := s.Get(kid)
assert.NoError(t, err)
assert.Equal(t, secret, gotSecret)
assert.Equal(t, method, gotMethod)
assert.Equal(t, HS256, gotAlg)
}
})

t.Run("StaticSecretGet return error when kid invalid", func(t *testing.T) {
s := StaticSecret{}
secret, method, err := s.Get("kid")
secret, alg, err := s.Get("kid")
assert.Error(t, err)
assert.Nil(t, secret)
assert.Nil(t, method)
assert.Empty(t, alg)
})
}
Loading

0 comments on commit e678879

Please sign in to comment.