Skip to content

Commit

Permalink
Nicer pointer handling for jwks/support serializing secp ecdsa keys (#…
Browse files Browse the repository at this point in the history
…338)

* handle ptrs better

* better
  • Loading branch information
decentralgabe committed Apr 1, 2023
1 parent 9d44b15 commit bca28f8
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 50 deletions.
64 changes: 27 additions & 37 deletions crypto/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"encoding/base64"
"fmt"
"reflect"

secp256k1 "github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/goccy/go-json"
Expand Down Expand Up @@ -51,7 +51,7 @@ func (k PrivateKeyJWK) ToPublicKeyJWK() PublicKeyJWK {
}
}

func (k PrivateKeyJWK) ToKey() (crypto.PrivateKey, error) {
func (k PrivateKeyJWK) ToPrivateKey() (crypto.PrivateKey, error) {
gotJWK, err := JWKFromPrivateKeyJWK(k)
if err != nil {
return nil, errors.Wrap(err, "creating JWK from private key")
Expand All @@ -60,6 +60,10 @@ func (k PrivateKeyJWK) ToKey() (crypto.PrivateKey, error) {
if err = gotJWK.Raw(&goKey); err != nil {
return nil, errors.Wrap(err, "converting JWK to go key")
}
// dereference the ptr
if reflect.ValueOf(goKey).Kind() == reflect.Ptr {
goKey = reflect.ValueOf(goKey).Elem().Interface().(crypto.PrivateKey)
}
return goKey, nil
}

Expand All @@ -77,7 +81,7 @@ type PublicKeyJWK struct {
KID string `json:"kid,omitempty"`
}

func (k PublicKeyJWK) ToKey() (crypto.PublicKey, error) {
func (k PublicKeyJWK) ToPublicKey() (crypto.PublicKey, error) {
gotJWK, err := JWKFromPublicKeyJWK(k)
if err != nil {
return nil, errors.Wrap(err, "creating JWK from public key")
Expand All @@ -86,6 +90,10 @@ func (k PublicKeyJWK) ToKey() (crypto.PublicKey, error) {
if err = gotJWK.Raw(&goKey); err != nil {
return nil, errors.Wrap(err, "converting JWK to go key")
}
// dereference the ptr
if reflect.ValueOf(goKey).Kind() == reflect.Ptr {
goKey = reflect.ValueOf(goKey).Elem().Interface().(crypto.PublicKey)
}
return goKey, nil
}

Expand Down Expand Up @@ -135,90 +143,76 @@ func JWKFromPrivateKeyJWK(key PrivateKeyJWK) (jwk.Key, error) {

// PublicKeyToJWK converts a public key to a JWK
func PublicKeyToJWK(key crypto.PublicKey) (jwk.Key, error) {
// dereference the ptr
if reflect.ValueOf(key).Kind() == reflect.Ptr {
key = reflect.ValueOf(key).Elem().Interface().(crypto.PublicKey)
}
switch k := key.(type) {
case rsa.PublicKey:
return jwkKeyFromRSAPublicKey(k)
case *rsa.PublicKey:
return jwkKeyFromRSAPublicKey(*k)
case ed25519.PublicKey:
return jwkKeyFromEd25519PublicKey(k)
case *ed25519.PublicKey:
return jwkKeyFromEd25519PublicKey(*k)
case x25519.PublicKey:
return jwkKeyFromX25519PublicKey(k)
case *x25519.PublicKey:
return jwkKeyFromX25519PublicKey(*k)
case secp256k1.PublicKey:
return jwkKeyFromSECP256k1PublicKey(k)
case *secp256k1.PublicKey:
return jwkKeyFromSECP256k1PublicKey(*k)
case ecdsa.PublicKey:
return jwkKeyFromECDSAPublicKey(k)
case *ecdsa.PublicKey:
return jwkKeyFromECDSAPublicKey(*k)
default:
return nil, fmt.Errorf("unsupported public key type: %T", k)
}
}

// PublicKeyToPublicKeyJWK converts a public key to a PublicKeyJWK
func PublicKeyToPublicKeyJWK(key crypto.PublicKey) (*PublicKeyJWK, error) {
// dereference the ptr
if reflect.ValueOf(key).Kind() == reflect.Ptr {
key = reflect.ValueOf(key).Elem().Interface().(crypto.PublicKey)
}
switch k := key.(type) {
case rsa.PublicKey:
return jwkFromRSAPublicKey(k)
case *rsa.PublicKey:
return jwkFromRSAPublicKey(*k)
case ed25519.PublicKey:
return jwkFromEd25519PublicKey(k)
case *ed25519.PublicKey:
return jwkFromEd25519PublicKey(*k)
case x25519.PublicKey:
return jwkFromX25519PublicKey(k)
case *x25519.PublicKey:
return jwkFromX25519PublicKey(*k)
case secp256k1.PublicKey:
return jwkFromSECP256k1PublicKey(k)
case *secp256k1.PublicKey:
return jwkFromSECP256k1PublicKey(*k)
case ecdsa.PublicKey:
return jwkFromECDSAPublicKey(k)
case *ecdsa.PublicKey:
return jwkFromECDSAPublicKey(*k)
default:
return nil, fmt.Errorf("unsupported public key type: %T", k)
}
}

// PrivateKeyToJWK converts a private key to a JWK
func PrivateKeyToJWK(key crypto.PrivateKey) (jwk.Key, error) {
// dereference the ptr
if reflect.ValueOf(key).Kind() == reflect.Ptr {
key = reflect.ValueOf(key).Elem().Interface().(crypto.PrivateKey)
}
switch k := key.(type) {
case rsa.PrivateKey:
return jwkKeyFromRSAPrivateKey(k)
case *rsa.PrivateKey:
return jwkKeyFromRSAPrivateKey(*k)
case ed25519.PrivateKey:
return jwkKeyFromEd25519PrivateKey(k)
case *ed25519.PrivateKey:
return jwkKeyFromEd25519PrivateKey(*k)
case x25519.PrivateKey:
return jwkKeyFromX25519PrivateKey(k)
case *x25519.PrivateKey:
return jwkKeyFromX25519PrivateKey(*k)
case secp256k1.PrivateKey:
return jwkKeyFromSECP256k1PrivateKey(k)
case *secp256k1.PrivateKey:
return jwkKeyFromSECP256k1PrivateKey(*k)
case ecdsa.PrivateKey:
return jwkKeyFromECDSAPrivateKey(k)
case *ecdsa.PrivateKey:
return jwkKeyFromECDSAPrivateKey(*k)
default:
return nil, fmt.Errorf("unsupported private key type: %T", k)
}
}

// PrivateKeyToPrivateKeyJWK converts a private key to a PrivateKeyJWK
func PrivateKeyToPrivateKeyJWK(key crypto.PrivateKey) (*PublicKeyJWK, *PrivateKeyJWK, error) {
// dereference the ptr
if reflect.ValueOf(key).Kind() == reflect.Ptr {
key = reflect.ValueOf(key).Elem().Interface().(crypto.PrivateKey)
}
switch k := key.(type) {
case rsa.PrivateKey:
return jwkFromRSAPrivateKey(k)
Expand Down Expand Up @@ -551,7 +545,3 @@ func jwkFromECDSAPublicKey(key ecdsa.PublicKey) (*PublicKeyJWK, error) {
}
return &publicKeyJWK, nil
}

func encodeToBase64RawURL(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
103 changes: 101 additions & 2 deletions crypto/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package crypto
import (
"testing"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/stretchr/testify/assert"
)
Expand All @@ -27,7 +28,7 @@ func TestJWKToPrivateKeyJWK(t *testing.T) {
assert.Equal(t, "Ed25519", privKeyJWK.CRV)

// convert back
gotPrivKey, err := privKeyJWK.ToKey()
gotPrivKey, err := privKeyJWK.ToPrivateKey()
assert.NoError(t, err)
assert.NotEmpty(t, gotPrivKey)
assert.Equal(t, privateKey, gotPrivKey)
Expand All @@ -53,7 +54,7 @@ func TestJWKToPublicKeyJWK(t *testing.T) {
assert.Equal(t, "Ed25519", pubKeyJWK.CRV)

// convert back
gotPubKey, err := pubKeyJWK.ToKey()
gotPubKey, err := pubKeyJWK.ToPublicKey()
assert.NoError(t, err)
assert.NotEmpty(t, gotPubKey)
assert.Equal(t, publicKey, gotPubKey)
Expand Down Expand Up @@ -111,6 +112,104 @@ func TestJWKFromPublicKeyJWK(t *testing.T) {
assert.Equal(t, key, gotJWK)
}

func TestPublicKeyToJWK(t *testing.T) {
t.Run("RSA", func(tt *testing.T) {
pubKey, _, err := GenerateRSA2048Key()
assert.NoError(t, err)

jwk, err := PublicKeyToJWK(pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk)
assert.Equal(tt, jwa.RSA, jwk.KeyType())

jwk2, err := PublicKeyToJWK(&pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk2)
assert.Equal(tt, jwa.RSA, jwk2.KeyType())
})

t.Run("Ed25519", func(tt *testing.T) {
pubKey, _, err := GenerateEd25519Key()
assert.NoError(t, err)

jwk, err := PublicKeyToJWK(pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk)
assert.Equal(tt, jwa.OKP, jwk.KeyType())

jwk2, err := PublicKeyToJWK(&pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk2)
assert.Equal(tt, jwa.OKP, jwk2.KeyType())
})

t.Run("X25519", func(tt *testing.T) {
pubKey, _, err := GenerateX25519Key()
assert.NoError(t, err)

jwk, err := PublicKeyToJWK(pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk)
assert.Equal(tt, jwa.OKP, jwk.KeyType())

jwk2, err := PublicKeyToJWK(&pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk2)
assert.Equal(tt, jwa.OKP, jwk2.KeyType())
})

t.Run("secp256k1", func(tt *testing.T) {
pubKey, _, err := GenerateSECP256k1Key()
assert.NoError(t, err)

jwk, err := PublicKeyToJWK(pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk)
assert.Equal(tt, jwa.EC, jwk.KeyType())

jwk2, err := PublicKeyToJWK(&pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk2)
assert.Equal(tt, jwa.EC, jwk2.KeyType())
})

t.Run("ecdsa P-256", func(tt *testing.T) {
pubKey, _, err := GenerateP256Key()
assert.NoError(t, err)

jwk, err := PublicKeyToJWK(pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk)
assert.Equal(tt, jwa.EC, jwk.KeyType())

jwk2, err := PublicKeyToJWK(&pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk2)
assert.Equal(tt, jwa.EC, jwk.KeyType())
})

t.Run("ecdsa P-384", func(tt *testing.T) {
pubKey, _, err := GenerateP384Key()
assert.NoError(t, err)

jwk, err := PublicKeyToJWK(pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk)
assert.Equal(tt, jwa.EC, jwk.KeyType())

jwk2, err := PublicKeyToJWK(&pubKey)
assert.NoError(tt, err)
assert.NotEmpty(tt, jwk2)
assert.Equal(tt, jwa.EC, jwk2.KeyType())
})

t.Run("unsupported", func(tt *testing.T) {
jwk, err := PublicKeyToJWK(nil)
assert.Error(tt, err)
assert.Empty(tt, jwk)
})
}

func TestPublicKeyToPublicKeyJWK(t *testing.T) {
t.Run("RSA", func(tt *testing.T) {
pubKey, _, err := GenerateRSA2048Key()
Expand Down
3 changes: 3 additions & 0 deletions crypto/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ func TestSignVerifyJWTForEachSupportedKeyType(t *testing.T) {
{
kt: SECP256k1,
},
{
kt: SECP256k1ECDSA,
},
{
kt: P256,
},
Expand Down
Loading

0 comments on commit bca28f8

Please sign in to comment.