Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor code by using switch instead of if-else #318

Merged
merged 2 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 44 additions & 39 deletions cmd/jwt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ func main() {

// Figure out which thing to do and then do that
func start() error {
if *flagSign != "" {
switch {
case *flagSign != "":
return signToken()
} else if *flagVerify != "" {
case *flagVerify != "":
return verifyToken()
} else if *flagShow != "" {
case *flagShow != "":
return showToken()
} else {
default:
flag.Usage()
return fmt.Errorf("none of the required flags are present. What do you want me to do?")
}
Expand All @@ -79,17 +80,18 @@ func loadData(p string) ([]byte, error) {
}

var rdr io.Reader
if p == "-" {
switch p {
case "-":
rdr = os.Stdin
} else if p == "+" {
case "+":
return []byte("{}"), nil
} else {
if f, err := os.Open(p); err == nil {
rdr = f
defer f.Close()
} else {
default:
f, err := os.Open(p)
if err != nil {
return nil, err
}
rdr = f
defer f.Close()
}
return io.ReadAll(rdr)
}
Expand Down Expand Up @@ -136,14 +138,16 @@ func verifyToken() error {
if err != nil {
return nil, err
}
if isEs() {
switch {
case isEs():
return jwt.ParseECPublicKeyFromPEM(data)
} else if isRs() {
case isRs():
return jwt.ParseRSAPublicKeyFromPEM(data)
} else if isEd() {
case isEd():
return jwt.ParseEdPublicKeyFromPEM(data)
default:
return data, nil
}
return data, nil
})

// Print some debug data
Expand Down Expand Up @@ -221,40 +225,41 @@ func signToken() error {
}
}

if isEs() {
if k, ok := key.([]byte); !ok {
switch {
case isEs():
k, ok := key.([]byte)
if !ok {
return fmt.Errorf("couldn't convert key data to key")
} else {
key, err = jwt.ParseECPrivateKeyFromPEM(k)
if err != nil {
return err
}
}
} else if isRs() {
if k, ok := key.([]byte); !ok {
key, err = jwt.ParseECPrivateKeyFromPEM(k)
if err != nil {
return err
}
case isRs():
k, ok := key.([]byte)
if !ok {
return fmt.Errorf("couldn't convert key data to key")
} else {
key, err = jwt.ParseRSAPrivateKeyFromPEM(k)
if err != nil {
return err
}
}
} else if isEd() {
if k, ok := key.([]byte); !ok {
key, err = jwt.ParseRSAPrivateKeyFromPEM(k)
if err != nil {
return err
}
case isEd():
k, ok := key.([]byte)
if !ok {
return fmt.Errorf("couldn't convert key data to key")
} else {
key, err = jwt.ParseEdPrivateKeyFromPEM(k)
if err != nil {
return err
}
}
key, err = jwt.ParseEdPrivateKeyFromPEM(k)
if err != nil {
return err
}
}

if out, err := token.SignedString(key); err == nil {
fmt.Println(out)
} else {
out, err := token.SignedString(key)
if err != nil {
return fmt.Errorf("error signing token: %w", err)
}
fmt.Println(out)

return nil
}
Expand Down
11 changes: 6 additions & 5 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,18 @@ func ExampleParse_errorChecking() {
return []byte("AllYourBase"), nil
})

if token.Valid {
switch {
case token.Valid:
fmt.Println("You look nice today")
} else if errors.Is(err, jwt.ErrTokenMalformed) {
case errors.Is(err, jwt.ErrTokenMalformed):
fmt.Println("That's not even a token")
} else if errors.Is(err, jwt.ErrTokenSignatureInvalid) {
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
// Invalid signature
fmt.Println("Invalid signature")
} else if errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet) {
case errors.Is(err, jwt.ErrTokenExpired) || errors.Is(err, jwt.ErrTokenNotValidYet):
// Token is either expired or not active yet
fmt.Println("Timing is everything")
} else {
default:
fmt.Println("Couldn't handle this token:", err)
}

Expand Down
21 changes: 11 additions & 10 deletions hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,17 @@ func TestHMACVerify(t *testing.T) {

func TestHMACSign(t *testing.T) {
for _, data := range hmacTestData {
if data.valid {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
if !data.valid {
continue
}
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), hmacTestKey)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
}
Expand Down
21 changes: 11 additions & 10 deletions none_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,17 @@ func TestNoneVerify(t *testing.T) {

func TestNoneSign(t *testing.T) {
for _, data := range noneTestData {
if data.valid {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
if !data.valid {
continue
}
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), data.key)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
}
25 changes: 13 additions & 12 deletions rsa_pss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,19 @@ func TestRSAPSSSign(t *testing.T) {
}

for _, data := range rsaPSSTestData {
if data.valid {
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}

ssig := encodeSegment(sig)
if ssig == parts[2] {
t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2])
}
if !data.valid {
continue
}
parts := strings.Split(data.tokenString, ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(strings.Join(parts[0:2], "."), rsaPSSKey)
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}

ssig := encodeSegment(sig)
if ssig == parts[2] {
t.Errorf("[%v] Signatures shouldn't match\nnew:\n%v\noriginal:\n%v", data.name, ssig, parts[2])
}
}
}
Expand Down