diff --git a/rsa_test.go b/rsa_test.go index cba41001..87734760 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -1,6 +1,11 @@ package jwt_test import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "os" "reflect" "strings" @@ -115,6 +120,17 @@ func TestRSAKeyParsing(t *testing.T) { pubKey, _ := os.ReadFile("test/sample_key.pub") badKey := []byte("All your base are belong to key") + randomKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Errorf("Failed to generate RSA private key: %v", err) + } + + publicKeyBytes := x509.MarshalPKCS1PublicKey(&randomKey.PublicKey) + pkcs1Buffer := new(bytes.Buffer) + if err = pem.Encode(pkcs1Buffer, &pem.Block{Type: "RSA PUBLIC KEY", Bytes: publicKeyBytes}); err != nil { + t.Errorf("Failed to encode public pem: %v", err) + } + // Test parsePrivateKey if _, e := jwt.ParseRSAPrivateKeyFromPEM(key); e != nil { t.Errorf("Failed to parse valid private key: %v", e) @@ -149,6 +165,9 @@ func TestRSAKeyParsing(t *testing.T) { t.Errorf("Parsed invalid key as valid private key: %v", k) } + if _, err := jwt.ParseRSAPublicKeyFromPEM(pkcs1Buffer.Bytes()); err != nil { + t.Errorf("failed to parse RSA public key: %v", err) + } } func BenchmarkRSAParsing(b *testing.B) { diff --git a/rsa_utils.go b/rsa_utils.go index 1966c450..b3aeebbe 100644 --- a/rsa_utils.go +++ b/rsa_utils.go @@ -75,7 +75,7 @@ func ParseRSAPrivateKeyFromPEMWithPassword(key []byte, password string) (*rsa.Pr return pkey, nil } -// ParseRSAPublicKeyFromPEM parses a PEM encoded PKCS1 or PKCS8 public key +// ParseRSAPublicKeyFromPEM parses a certificate or a PEM encoded PKCS1 or PKIX public key func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) { var err error @@ -91,7 +91,9 @@ func ParseRSAPublicKeyFromPEM(key []byte) (*rsa.PublicKey, error) { if cert, err := x509.ParseCertificate(block.Bytes); err == nil { parsedKey = cert.PublicKey } else { - return nil, err + if parsedKey, err = x509.ParsePKCS1PublicKey(block.Bytes); err != nil { + return nil, err + } } }