Skip to content

Commit

Permalink
Refactor TestNewAccount with Table-Driven Tests and Enhance Readability
Browse files Browse the repository at this point in the history
  • Loading branch information
sanowl authored and sethterashima committed Jun 17, 2024
1 parent 5e966bd commit d60de3b
Showing 1 changed file with 39 additions and 31 deletions.
70 changes: 39 additions & 31 deletions pkg/account/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,44 @@ import (
"testing"
)

// b64Encode encodes a string to base64 without padding.
func b64Encode(payload string) string {
return base64.RawStdEncoding.EncodeToString([]byte(payload))
}

// TestNewAccount tests the creation of a new account with various JWT scenarios.
func TestNewAccount(t *testing.T) {
validDomain := "fleet-api.example.tesla.com"
if _, err := New("", ""); err == nil {
t.Error("Returned success empty JWT")
}
if _, err := New(b64Encode(validDomain), ""); err == nil {
t.Error("Returned success on one-field JWT")
}
if _, err := New("x."+b64Encode(validDomain), ""); err == nil {
t.Error("Returned success on two-field JWT")
}
if _, err := New("x."+b64Encode(validDomain)+"y.z", ""); err == nil {
t.Error("Returned success on four-field JWT")
}
if _, err := New("x."+validDomain+".y", ""); err == nil {
t.Error("Returned success on non-base64 encoded JWT")
}
if _, err := New("x."+b64Encode("{\"aud\": \"example.com\"}")+".y", ""); err == nil {
t.Error("Returned success on untrusted domain")
}
if _, err := New("x."+b64Encode(fmt.Sprintf("{\"aud\": \"%s\"}", validDomain))+".y", ""); err == nil {
t.Error("Returned when aud field not a list")
}

acct, err := New("x."+b64Encode(fmt.Sprintf("{\"aud\": [\"%s\"]}", validDomain))+".y", "")
if err != nil {
t.Fatalf("Returned error on valid JWT: %s", err)
tests := []struct {
jwt string
shouldError bool
description string
}{
{"", true, "empty JWT"},
{b64Encode(validDomain), true, "one-field JWT"},
{"x." + b64Encode(validDomain), true, "two-field JWT"},
{"x." + b64Encode(validDomain) + "y.z", true, "four-field JWT"},
{"x." + validDomain + ".y", true, "non-base64 encoded JWT"},
{"x." + b64Encode("{\"aud\": \"example.com\"}") + ".y", true, "untrusted domain"},
{"x." + b64Encode(fmt.Sprintf("{\"aud\": \"%s\"}", validDomain)) + ".y", true, "aud field not a list"},
{"x." + b64Encode(fmt.Sprintf("{\"aud\": [\"%s\"]}", validDomain)) + ".y", false, "valid JWT"},
}
if acct == nil || acct.Host != validDomain {
t.Errorf("acct = %+v", acct)

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
acct, err := New(test.jwt, "")
if (err != nil) != test.shouldError {
t.Errorf("Unexpected result: err = %v, shouldError = %v", err, test.shouldError)
}
if !test.shouldError && (acct == nil || acct.Host != validDomain) {
t.Errorf("acct = %+v, expected Host = %s", acct, validDomain)
}
})
}
}

// TestDomainDefault tests the default domain extraction.
func TestDomainDefault(t *testing.T) {
payload := &oauthPayload{
Audiences: []string{"https://auth.tesla.com/nts"},
Expand All @@ -54,25 +55,32 @@ func TestDomainDefault(t *testing.T) {
t.Fatalf("Returned error on valid JWT: %s", err)
}
if acct == nil || acct.Host != defaultDomain {
t.Errorf("acct = %+v", acct)
t.Errorf("acct = %+v, expected Host = %s", acct, defaultDomain)
}
}

// TestDomainExtraction tests the extraction of the correct domain based on OUCode.
func TestDomainExtraction(t *testing.T) {
payload := &oauthPayload{
Audiences: []string{"https://auth.tesla.com/nts", "https://fleet-api.prd.na.vn.cloud.tesla.com", "https://fleet-api.prd.eu.vn.cloud.tesla.com"},
OUCode: "EU",
Audiences: []string{
"https://auth.tesla.com/nts",
"https://fleet-api.prd.na.vn.cloud.tesla.com",
"https://fleet-api.prd.eu.vn.cloud.tesla.com",
},
OUCode: "EU",
}

acct, err := New(makeTestJWT(payload), "")
if err != nil {
t.Fatalf("Returned error on valid JWT: %s", err)
}
if acct == nil || acct.Host != "fleet-api.prd.eu.vn.cloud.tesla.com" {
t.Errorf("acct = %+v", acct)
expectedHost := "fleet-api.prd.eu.vn.cloud.tesla.com"
if acct == nil || acct.Host != expectedHost {
t.Errorf("acct = %+v, expected Host = %s", acct, expectedHost)
}
}

// makeTestJWT creates a JWT string with the given payload.
func makeTestJWT(payload *oauthPayload) string {
jwtBody, _ := json.Marshal(payload)
return fmt.Sprintf("x.%s.y", b64Encode(string(jwtBody)))
Expand Down

0 comments on commit d60de3b

Please sign in to comment.