diff --git a/authorization.go b/authorization.go index 9f8d8bc..741fd65 100644 --- a/authorization.go +++ b/authorization.go @@ -9,7 +9,6 @@ import ( "hash" "io" "net/url" - "regexp" "strings" "time" ) @@ -49,6 +48,13 @@ func newAuthorization(dr *DigestRequest) (*authorization, error) { return ah.refreshAuthorization(dr) } +const ( + algorithmMD5 = "MD5" + algorithmMD5Sess = "MD5-SESS" + algorithmSHA256 = "SHA-256" + algorithmSHA256Sess = "SHA-256-SESS" +) + func (ah *authorization) refreshAuthorization(dr *DigestRequest) (*authorization, error) { ah.Username = dr.Username @@ -82,11 +88,13 @@ func (ah *authorization) computeResponse(dr *DigestRequest) (s string) { func (ah *authorization) computeA1(dr *DigestRequest) string { - if ah.Algorithm == "" || ah.Algorithm == "MD5" || ah.Algorithm == "SHA-256" { + algorithm := strings.ToUpper(ah.Algorithm) + + if algorithm == "" || algorithm == algorithmMD5 || algorithm == algorithmSHA256 { return fmt.Sprintf("%s:%s:%s", ah.Username, ah.Realm, dr.Password) } - if ah.Algorithm == "MD5-sess" || ah.Algorithm == "SHA-256-sess" { + if algorithm == algorithmMD5Sess || algorithm == algorithmSHA256Sess { upHash := ah.hash(fmt.Sprintf("%s:%s:%s", ah.Username, ah.Realm, dr.Password)) return fmt.Sprintf("%s:%s:%s", upHash, ah.Nonce, ah.Cnonce) } @@ -96,7 +104,7 @@ func (ah *authorization) computeA1(dr *DigestRequest) string { func (ah *authorization) computeA2(dr *DigestRequest) string { - if matched, _ := regexp.MatchString("auth-int", dr.Wa.Qop); matched { + if strings.Contains(dr.Wa.Qop, "auth-int") { ah.Qop = "auth-int" return fmt.Sprintf("%s:%s:%s", dr.Method, ah.URI, ah.hash(dr.Body)) } @@ -109,20 +117,21 @@ func (ah *authorization) computeA2(dr *DigestRequest) string { return "" } -func (ah *authorization) hash(a string) (s string) { - +func (ah *authorization) hash(a string) string { var h hash.Hash + algorithm := strings.ToUpper(ah.Algorithm) - if ah.Algorithm == "" || ah.Algorithm == "MD5" || ah.Algorithm == "MD5-sess" { + if algorithm == "" || algorithm == algorithmMD5 || algorithm == algorithmMD5Sess { h = md5.New() - } else if ah.Algorithm == "SHA-256" || ah.Algorithm == "SHA-256-sess" { + } else if algorithm == algorithmSHA256 || algorithm == algorithmSHA256Sess { h = sha256.New() + } else { + // unknown algorithm + return "" } io.WriteString(h, a) - s = hex.EncodeToString(h.Sum(nil)) - - return + return hex.EncodeToString(h.Sum(nil)) } func (ah *authorization) toString() string { diff --git a/authorization_test.go b/authorization_test.go new file mode 100644 index 0000000..76cb721 --- /dev/null +++ b/authorization_test.go @@ -0,0 +1,174 @@ +package digest_auth_client + +import "testing" + +func TestHash(t *testing.T) { + testCases := []struct { + name string + algorithm string + expRes string + }{ + { + name: "empty algorithm", + algorithm: "", + expRes: "1a79a4d60de6718e8e5b326e338ae533", + }, + { + name: "MD5 algorithm", + algorithm: "MD5", + expRes: "1a79a4d60de6718e8e5b326e338ae533", + }, + { + name: "MD5-sess algorithm", + algorithm: "MD5", + expRes: "1a79a4d60de6718e8e5b326e338ae533", + }, + { + name: "SHA256 algorithm", + algorithm: "SHA-256", + expRes: "50d858e0985ecc7f60418aaf0cc5ab587f42c2570a884095a9e8ccacd0f6545c", + }, + { + name: "SHA256-sess algorithm", + algorithm: "SHA-256", + expRes: "50d858e0985ecc7f60418aaf0cc5ab587f42c2570a884095a9e8ccacd0f6545c", + }, + { + name: "md5 algorithm", + algorithm: "md5", + expRes: "1a79a4d60de6718e8e5b326e338ae533", + }, + { + name: "unknown algorithm", + algorithm: "unknown", + expRes: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ah := &authorization{Algorithm: tc.algorithm} + res := ah.hash("example") + if res != tc.expRes { + t.Errorf("got: %q, want: %q", res, tc.expRes) + } + }) + } +} + +func TestComputeA1(t *testing.T) { + testCases := []struct { + name string + algorithm string + expRes string + }{ + { + name: "empty algorithm", + algorithm: "", + expRes: "username:realm:secret", + }, + { + name: "MD5 algorithm", + algorithm: "MD5", + expRes: "username:realm:secret", + }, + { + name: "MD5-sess algorithm", + algorithm: "MD5", + expRes: "username:realm:secret", + }, + { + name: "SHA256 algorithm", + algorithm: "SHA-256", + expRes: "username:realm:secret", + }, + { + name: "SHA256-sess algorithm", + algorithm: "SHA-256", + expRes: "username:realm:secret", + }, + { + name: "md5 algorithm", + algorithm: "md5", + expRes: "username:realm:secret", + }, + { + name: "unknown algorithm", + algorithm: "unknown", + expRes: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dr := &DigestRequest{Password: "secret"} + ah := &authorization{ + Algorithm: tc.algorithm, + Nonce: "nonce", + Cnonce: "cnonce", + Username: "username", + Realm: "realm", + } + res := ah.computeA1(dr) + if res != tc.expRes { + t.Errorf("got: %q, want: %q", res, tc.expRes) + } + }) + } +} + +func TestComputeA2(t *testing.T) { + testCases := []struct { + name string + qop string + expRes string + expAuthQop string + }{ + { + name: "empty qop", + qop: "", + expRes: "method:uri", + expAuthQop: "auth", + }, + { + name: "qop is auth", + qop: "auth", + expRes: "method:uri", + expAuthQop: "auth", + }, + { + name: "qop is auth-int", + qop: "qop is auth-int", + expRes: "method:uri:841a2d689ad86bd1611447453c22c6fc", + expAuthQop: "auth-int", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + dr := &DigestRequest{ + Method: "method", + Body: "body", + Wa: &wwwAuthenticate{ + Qop: tc.qop, + }, + } + ah := &authorization{ + Algorithm: "MD5", + Nonce: "nonce", + Cnonce: "cnonce", + Username: "username", + Realm: "realm", + URI: "uri", + Qop: tc.qop, + } + res := ah.computeA2(dr) + if res != tc.expRes { + t.Errorf("wrong result, got: %q, want: %q", res, tc.expRes) + } + if ah.Qop != tc.expAuthQop { + t.Errorf("wrong qop, got: %q, want: %q", ah.Qop, tc.expAuthQop) + } + }) + } +}