diff --git a/encoding.go b/encoding.go index b3062ab8..dea533e8 100644 --- a/encoding.go +++ b/encoding.go @@ -23,13 +23,12 @@ import ( "encoding/binary" "io" "math/big" - "regexp" + "strings" + "unicode" "github.com/square/go-jose/v3/json" ) -var stripWhitespaceRegex = regexp.MustCompile(`\s`) - // Helper function to serialize known-good objects. // Precondition: value is not a nil pointer. func mustSerializeJSON(value interface{}) []byte { @@ -56,7 +55,16 @@ func mustSerializeJSON(value interface{}) []byte { // Strip all newlines and whitespace func stripWhitespace(data string) string { - return stripWhitespaceRegex.ReplaceAllString(data, "") + buf := strings.Builder{} + buf.Grow(len(data)) + + for _, r := range data { + if !unicode.IsSpace(r) { + buf.WriteRune(r) + } + } + + return buf.String() } // Perform compression based on algorithm diff --git a/jws.go b/jws.go index 6bb310e6..a05c56bd 100644 --- a/jws.go +++ b/jws.go @@ -102,14 +102,14 @@ func (sig Signature) mergedHeaders() rawHeader { } // Compute data to be signed -func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) []byte { +func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) ([]byte, error) { var authData bytes.Buffer protectedHeader := new(rawHeader) if signature.original != nil && signature.original.Protected != nil { if err := json.Unmarshal(signature.original.Protected.bytes(), protectedHeader); err != nil { - panic(err) + return nil, err } authData.WriteString(signature.original.Protected.base64()) } else if signature.protected != nil { @@ -134,7 +134,7 @@ func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature authData.Write(payload) } - return authData.Bytes() + return authData.Bytes(), nil } // parseSignedFull parses a message in full format. diff --git a/jws_test.go b/jws_test.go index 64b0b7c0..2ed51973 100644 --- a/jws_test.go +++ b/jws_test.go @@ -18,8 +18,11 @@ package jose import ( "crypto/x509" + "encoding/base64" "strings" "testing" + + "github.com/stretchr/testify/assert" ) const trustedCA = ` @@ -649,3 +652,39 @@ func TestDetachedCompactSerialization(t *testing.T) { t.Fatalf("got '%s', expected '%s'", ser, msg) } } + +func TestJWSComputeAuthDataBase64(t *testing.T) { + jws := JSONWebSignature{} + + _, err := jws.computeAuthData([]byte{0x01}, &Signature{ + original: &rawSignatureInfo{ + Protected: newBuffer([]byte("{!invalid-json}")), + }, + }) + // Invalid header, should return error + assert.NotNil(t, err) + + payload := []byte{0x01} + encodedPayload := base64.RawURLEncoding.EncodeToString(payload) + + b64TrueHeader := newBuffer([]byte(`{"alg":"RSA-OAEP","enc":"A256GCM","b64":true}`)) + b64FalseHeader := newBuffer([]byte(`{"alg":"RSA-OAEP","enc":"A256GCM","b64":false}`)) + + data, err := jws.computeAuthData(payload, &Signature{ + original: &rawSignatureInfo{ + Protected: b64TrueHeader, + }, + }) + assert.Nil(t, err) + // Payload should be b64 encoded + assert.Len(t, data, len(b64TrueHeader.base64())+len(encodedPayload)+1) + + data, err = jws.computeAuthData(payload, &Signature{ + original: &rawSignatureInfo{ + Protected: b64FalseHeader, + }, + }) + assert.Nil(t, err) + // Payload should *not* be b64 encoded + assert.Len(t, data, len(b64FalseHeader.base64())+len(payload)+1) +} diff --git a/jwt/validation.go b/jwt/validation.go index 045d5dfb..6f3ff4e8 100644 --- a/jwt/validation.go +++ b/jwt/validation.go @@ -35,7 +35,7 @@ type Expected struct { Audience Audience // ID matches the "jti" claim exactly. ID string - // Time matches the "exp" and "nbf" claims with leeway. + // Time matches the "exp", "nbf" and "iat" claims with leeway. Time time.Time } diff --git a/signing.go b/signing.go index 0a631bb9..cd086053 100644 --- a/signing.go +++ b/signing.go @@ -370,7 +370,11 @@ func (obj JSONWebSignature) DetachedVerify(payload []byte, verificationKey inter } } - input := obj.computeAuthData(payload, &signature) + input, err := obj.computeAuthData(payload, &signature) + if err != nil { + return ErrCryptoFailure + } + alg := headers.getSignatureAlgorithm() err = verifier.verifyPayload(input, signature.Signature, alg) if err == nil { @@ -421,7 +425,11 @@ outer: } } - input := obj.computeAuthData(payload, &signature) + input, err := obj.computeAuthData(payload, &signature) + if err != nil { + continue + } + alg := headers.getSignatureAlgorithm() err = verifier.verifyPayload(input, signature.Signature, alg) if err == nil { diff --git a/signing_test.go b/signing_test.go index c6c20791..a73488b6 100644 --- a/signing_test.go +++ b/signing_test.go @@ -583,3 +583,14 @@ func TestSignerB64(t *testing.T) { t.Errorf("Input/output do not match, got '%s', expected '%s'", output, input) } } + +func BenchmarkParseSigned(b *testing.B) { + msg := `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c` + + for i := 0; i < b.N; i++ { + _, err := ParseSigned(msg) + if err != nil { + b.Errorf("Error on parse: %s", err) + } + } +}