diff --git a/.github/.codecov.yml b/.github/.codecov.yml index 7c6d863..4b598be 100644 --- a/.github/.codecov.yml +++ b/.github/.codecov.yml @@ -1,5 +1,6 @@ coverage: status: + patch: off project: default: - target: 89% \ No newline at end of file + target: 89% diff --git a/algorithm.go b/algorithm.go index 7b95ed7..7e68535 100644 --- a/algorithm.go +++ b/algorithm.go @@ -2,6 +2,7 @@ package cose import ( "crypto" + "fmt" "strconv" ) @@ -36,10 +37,12 @@ const ( // PureEdDSA by RFC 8152. AlgorithmEd25519 Algorithm = -8 + + // An invalid/unrecognised algorithm. + AlgorithmInvalid Algorithm = 0 ) // Algorithm represents an IANA algorithm entry in the COSE Algorithms registry. -// Algorithms with string values are not supported. // // # See Also // @@ -72,6 +75,35 @@ func (a Algorithm) String() string { } } +// MarshalCBOR marshals the Algorithm as a CBOR int. +func (a Algorithm) MarshalCBOR() ([]byte, error) { + return encMode.Marshal(int64(a)) +} + +// UnmarshalCBOR populates the Algorithm from the provided CBOR value (must be +// int or tstr). +func (a *Algorithm) UnmarshalCBOR(data []byte) error { + var raw intOrStr + + if err := raw.UnmarshalCBOR(data); err != nil { + return fmt.Errorf("invalid algorithm value: %w", err) + } + + if raw.IsString() { + v := algorithmFromString(raw.String()) + if v == AlgorithmInvalid { + return fmt.Errorf("unknown algorithm value %q", raw.String()) + } + + *a = v + } else { + v := raw.Int() + *a = Algorithm(v) + } + + return nil +} + // hashFunc returns the hash associated with the algorithm supported by this // library. func (a Algorithm) hashFunc() crypto.Hash { @@ -103,3 +135,8 @@ func computeHash(h crypto.Hash, data []byte) ([]byte, error) { } return hh.Sum(nil), nil } + +// NOTE: there are currently no registered string values for an algorithm. +func algorithmFromString(v string) Algorithm { + return AlgorithmInvalid +} diff --git a/algorithm_test.go b/algorithm_test.go index 7ccafeb..bd92a8d 100644 --- a/algorithm_test.go +++ b/algorithm_test.go @@ -16,41 +16,6 @@ func TestAlgorithm_String(t *testing.T) { alg Algorithm want string }{ - { - name: "PS256", - alg: AlgorithmPS256, - want: "PS256", - }, - { - name: "PS384", - alg: AlgorithmPS384, - want: "PS384", - }, - { - name: "PS512", - alg: AlgorithmPS512, - want: "PS512", - }, - { - name: "ES256", - alg: AlgorithmES256, - want: "ES256", - }, - { - name: "ES384", - alg: AlgorithmES384, - want: "ES384", - }, - { - name: "ES512", - alg: AlgorithmES512, - want: "ES512", - }, - { - name: "Ed25519", - alg: AlgorithmEd25519, - want: "EdDSA", - }, { name: "unknown algorithm", alg: 0, @@ -66,6 +31,23 @@ func TestAlgorithm_String(t *testing.T) { } } +func TestAlgorithm_CBOR(t *testing.T) { + tvs2 := []struct { + Data []byte + ExpectedError string + }{ + {[]byte{0x63, 0x66, 0x6f, 0x6f}, "unknown algorithm value \"foo\""}, + {[]byte{0x40}, "invalid algorithm value: must be int or string, found []uint8"}, + } + + for _, tv := range tvs2 { + var a Algorithm + + err := a.UnmarshalCBOR(tv.Data) + assertEqualError(t, err, tv.ExpectedError) + } +} + func TestAlgorithm_computeHash(t *testing.T) { // run tests data := []byte("hello world") diff --git a/common.go b/common.go new file mode 100644 index 0000000..32294a5 --- /dev/null +++ b/common.go @@ -0,0 +1,96 @@ +package cose + +import ( + "errors" + "fmt" +) + +// intOrStr is a value that can be either an int or a tstr when serialized to +// CBOR. +type intOrStr struct { + intVal int64 + strVal string + isString bool +} + +func newIntOrStr(v interface{}) *intOrStr { + var ios intOrStr + if err := ios.Set(v); err != nil { + return nil + } + return &ios +} + +func (ios intOrStr) Int() int64 { + return ios.intVal +} + +func (ios intOrStr) String() string { + if ios.IsString() { + return ios.strVal + } + return fmt.Sprint(ios.intVal) +} + +func (ios intOrStr) IsInt() bool { + return !ios.isString +} + +func (ios intOrStr) IsString() bool { + return ios.isString +} + +func (ios intOrStr) Value() interface{} { + if ios.IsInt() { + return ios.intVal + } + + return ios.strVal +} + +func (ios *intOrStr) Set(v interface{}) error { + switch t := v.(type) { + case int64: + ios.intVal = t + ios.strVal = "" + ios.isString = false + case int: + ios.intVal = int64(t) + ios.strVal = "" + ios.isString = false + case string: + ios.strVal = t + ios.intVal = 0 + ios.isString = true + default: + return fmt.Errorf("must be int or string, found %T", t) + } + + return nil +} + +// MarshalCBOR returns the encoded CBOR representation of the intOrString, as +// either int or tstr, depending on the value. If no value has been set, +// intOrStr is encoded as a zero-length tstr. +func (ios intOrStr) MarshalCBOR() ([]byte, error) { + if ios.IsInt() { + return encMode.Marshal(ios.intVal) + } + + return encMode.Marshal(ios.strVal) +} + +// UnmarshalCBOR unmarshals the provided CBOR encoded data (must be an int, +// uint, or tstr). +func (ios *intOrStr) UnmarshalCBOR(data []byte) error { + if len(data) == 0 { + return errors.New("zero length buffer") + } + + var val interface{} + if err := decMode.Unmarshal(data, &val); err != nil { + return err + } + + return ios.Set(val) +} diff --git a/common_test.go b/common_test.go new file mode 100644 index 0000000..5f2c8d5 --- /dev/null +++ b/common_test.go @@ -0,0 +1,140 @@ +package cose + +import ( + "bytes" + "reflect" + "testing" + + "github.com/fxamacker/cbor/v2" +) + +func Test_intOrStr(t *testing.T) { + ios := newIntOrStr(3) + assertEqual(t, true, ios.IsInt()) + assertEqual(t, false, ios.IsString()) + assertEqual(t, 3, ios.Int()) + assertEqual(t, "3", ios.String()) + + ios = newIntOrStr("foo") + assertEqual(t, false, ios.IsInt()) + assertEqual(t, true, ios.IsString()) + assertEqual(t, 0, ios.Int()) + assertEqual(t, "foo", ios.String()) + + ios = newIntOrStr(3.5) + if ios != nil { + t.Errorf("Expected nil, got %v", ios) + } +} + +func Test_intOrStr_CBOR(t *testing.T) { + ios := newIntOrStr(3) + data, err := ios.MarshalCBOR() + requireNoError(t, err) + assertEqual(t, []byte{0x03}, data) + + ios = &intOrStr{} + err = ios.UnmarshalCBOR(data) + requireNoError(t, err) + assertEqual(t, true, ios.IsInt()) + assertEqual(t, 3, ios.Int()) + + ios = newIntOrStr("foo") + data, err = ios.MarshalCBOR() + requireNoError(t, err) + assertEqual(t, []byte{0x63, 0x66, 0x6f, 0x6f}, data) + + ios = &intOrStr{} + err = ios.UnmarshalCBOR(data) + requireNoError(t, err) + assertEqual(t, true, ios.IsString()) + assertEqual(t, "foo", ios.String()) + + // empty value as field + s := struct { + Field1 intOrStr `cbor:"1,keyasint"` + Field2 int `cbor:"2,keyasint"` + }{Field1: intOrStr{}, Field2: 7} + + data, err = cbor.Marshal(s) + requireNoError(t, err) + assertEqual(t, []byte{0xa2, 0x1, 0x00, 0x2, 0x7}, data) + + ios = &intOrStr{} + data = []byte{0x22} + err = ios.UnmarshalCBOR(data) + requireNoError(t, err) + assertEqual(t, true, ios.IsInt()) + assertEqual(t, -3, ios.Int()) + + data = []byte{} + err = ios.UnmarshalCBOR(data) + assertEqualError(t, err, "zero length buffer") + + data = []byte{0x40} + err = ios.UnmarshalCBOR(data) + assertEqualError(t, err, "must be int or string, found []uint8") + + data = []byte{0xff, 0xff} + err = ios.UnmarshalCBOR(data) + assertEqualError(t, err, "cbor: unexpected \"break\" code") +} + +func requireNoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Unexpected error: %q", err) + t.Fail() + } +} + +func assertEqualError(t *testing.T, err error, expected string) { + if err == nil || err.Error() != expected { + t.Errorf("Unexpected error: want %q, got %q", expected, err) + } +} + +func assertEqual(t *testing.T, expected, actual interface{}) { + if !objectsAreEqualValues(expected, actual) { + t.Errorf("Unexpected value: want %v, got %v", expected, actual) + } +} + +// taken from github.com/stretchr/testify +func objectsAreEqualValues(expected, actual interface{}) bool { + if objectsAreEqual(expected, actual) { + return true + } + + actualType := reflect.TypeOf(actual) + if actualType == nil { + return false + } + expectedValue := reflect.ValueOf(expected) + if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { + // Attempt comparison after type conversion + return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + } + + return false +} + +// taken from github.com/stretchr/testify +func objectsAreEqual(expected, actual interface{}) bool { + if expected == nil || actual == nil { + return expected == actual + } + + exp, ok := expected.([]byte) + if !ok { + return reflect.DeepEqual(expected, actual) + } + + act, ok := actual.([]byte) + if !ok { + return false + } + if exp == nil || act == nil { + return exp == nil && act == nil + } + return bytes.Equal(exp, act) +} diff --git a/errors.go b/errors.go index 8c240e2..7d16741 100644 --- a/errors.go +++ b/errors.go @@ -14,4 +14,8 @@ var ( ErrUnavailableHashFunc = errors.New("hash function is not available") ErrVerification = errors.New("verification error") ErrInvalidPubKey = errors.New("invalid public key") + ErrInvalidPrivKey = errors.New("invalid private key") + ErrNotPrivKey = errors.New("not a private key") + ErrSignOpNotSupported = errors.New("sign key_op not supported by key") + ErrVerifyOpNotSupported = errors.New("verify key_op not supported by key") ) diff --git a/headers.go b/headers.go index 7074936..4218eee 100644 --- a/headers.go +++ b/headers.go @@ -53,7 +53,8 @@ func (h ProtectedHeader) MarshalCBOR() ([]byte, error) { // UnmarshalCBOR decodes a CBOR bstr object into ProtectedHeader. // // ProtectedHeader is an empty_or_serialized_map where -// empty_or_serialized_map = bstr .cbor header_map / bstr .size 0 +// +// empty_or_serialized_map = bstr .cbor header_map / bstr .size 0 func (h *ProtectedHeader) UnmarshalCBOR(data []byte) error { if h == nil { return errors.New("cbor: UnmarshalCBOR on nil ProtectedHeader pointer") @@ -117,8 +118,17 @@ func (h ProtectedHeader) Algorithm() (Algorithm, error) { return Algorithm(alg), nil case int64: return Algorithm(alg), nil + case string: + v := algorithmFromString(alg) + + var err error + if v == AlgorithmInvalid { + err = fmt.Errorf("unknown algorithm value %q", alg) + } + + return v, err default: - return 0, ErrInvalidAlgorithm + return AlgorithmInvalid, ErrInvalidAlgorithm } } @@ -212,22 +222,22 @@ func (h *UnprotectedHeader) UnmarshalCBOR(data []byte) error { // // It is represented by CDDL fragments: // -// Headers = ( -// protected : empty_or_serialized_map, -// unprotected : header_map -// ) +// Headers = ( +// protected : empty_or_serialized_map, +// unprotected : header_map +// ) // -// header_map = { -// Generic_Headers, -// * label => values -// } +// header_map = { +// Generic_Headers, +// * label => values +// } // -// label = int / tstr -// values = any +// label = int / tstr +// values = any // -// empty_or_serialized_map = bstr .cbor header_map / bstr .size 0 +// empty_or_serialized_map = bstr .cbor header_map / bstr .size 0 // -// See Also +// # See Also // // https://tools.ietf.org/html/rfc8152#section-3 type Headers struct { @@ -553,7 +563,7 @@ func (discardedCBORMessage) UnmarshalCBOR(data []byte) error { // validateHeaderLabelCBOR validates if all header labels are integers or // strings of a CBOR map object. // -// label = int / tstr +// label = int / tstr // // Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4 func validateHeaderLabelCBOR(data []byte) error { diff --git a/headers_test.go b/headers_test.go index 1bc505d..11d4e72 100644 --- a/headers_test.go +++ b/headers_test.go @@ -1,6 +1,7 @@ package cose import ( + "errors" "reflect" "testing" ) @@ -422,13 +423,20 @@ func TestProtectedHeader_Algorithm(t *testing.T) { h: ProtectedHeader{ HeaderLabelAlgorithm: "foo", }, + wantErr: errors.New("unknown algorithm value \"foo\""), + }, + { + name: "invalid algorithm", + h: ProtectedHeader{ + HeaderLabelAlgorithm: 2.5, + }, wantErr: ErrInvalidAlgorithm, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.h.Algorithm() - if err != tt.wantErr { + if tt.wantErr != nil && err.Error() != tt.wantErr.Error() { t.Errorf("ProtectedHeader.Algorithm() error = %v, wantErr %v", err, tt.wantErr) return } diff --git a/key.go b/key.go new file mode 100644 index 0000000..741eaf3 --- /dev/null +++ b/key.go @@ -0,0 +1,802 @@ +package cose + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "errors" + "fmt" + "math/big" + "strconv" + + cbor "github.com/fxamacker/cbor/v2" +) + +const ( + // An inviald key_op value + KeyOpInvalid KeyOp = 0 + + // The key is used to create signatures. Requires private key fields. + KeyOpSign KeyOp = 1 + + // The key is used for verification of signatures. + KeyOpVerify KeyOp = 2 + + // The key is used for key transport encryption. + KeyOpEncrypt KeyOp = 3 + + // The key is used for key transport decryption. Requires private key fields. + KeyOpDecrypt KeyOp = 4 + + // The key is used for key wrap encryption. + KeyOpWrapKey KeyOp = 5 + + // The key is used for key wrap decryption. + KeyOpUnwrapKey KeyOp = 6 + + // The key is used for deriving keys. Requires private key fields. + KeyOpDeriveKey KeyOp = 7 + + // The key is used for deriving bits not to be used as a key. Requires + // private key fields. + KeyOpDeriveBits KeyOp = 8 + + // The key is used for creating MACs. + KeyOpMACCreate KeyOp = 9 + + // The key is used for validating MACs. + KeyOpMACVerify KeyOp = 10 +) + +// KeyOp represents a key_ops value used to restrict purposes for which a Key +// may be used. +type KeyOp int64 + +// KeyOpFromString returns the KeyOp corresponding to the specified name. +// The values are taken from https://www.rfc-editor.org/rfc/rfc7517#section-4.3 +func KeyOpFromString(val string) (KeyOp, error) { + switch val { + case "sign": + return KeyOpSign, nil + case "verify": + return KeyOpVerify, nil + case "encrypt": + return KeyOpEncrypt, nil + case "decrypt": + return KeyOpDecrypt, nil + case "wrapKey": + return KeyOpWrapKey, nil + case "unwrapKey": + return KeyOpUnwrapKey, nil + case "deriveKey": + return KeyOpDeriveKey, nil + case "deriveBits": + return KeyOpDeriveBits, nil + default: + return KeyOpInvalid, fmt.Errorf("unknown key_ops value %q", val) + } +} + +// String returns a string representation of the KeyType. Note does not +// represent a valid value of the corresponding serialized entry, and must not +// be used as such. (The values returned _mostly_ correspond to those accepted +// by KeyOpFromString, except for MAC create/verify, which are not defined by +// RFC7517). +func (ko KeyOp) String() string { + switch ko { + case KeyOpSign: + return "sign" + case KeyOpVerify: + return "verify" + case KeyOpEncrypt: + return "encrypt" + case KeyOpDecrypt: + return "decrypt" + case KeyOpWrapKey: + return "wrapKey" + case KeyOpUnwrapKey: + return "unwrapKey" + case KeyOpDeriveKey: + return "deriveKey" + case KeyOpDeriveBits: + return "deriveBits" + case KeyOpMACCreate: + return "MAC create" + case KeyOpMACVerify: + return "MAC verify" + default: + return "unknown key_op value " + strconv.Itoa(int(ko)) + } +} + +// IsSupported returnns true if the specified value is represents one of the +// key_ops defined in +// https://www.rfc-editor.org/rfc/rfc9052.html#name-cose-key-common-parameters +func (ko KeyOp) IsSupported() bool { + return ko >= 1 && ko <= 10 +} + +// MarshalCBOR marshals the KeyOp as a CBOR int. +func (ko KeyOp) MarshalCBOR() ([]byte, error) { + return encMode.Marshal(int64(ko)) +} + +// UnmarshalCBOR populates the KeyOp from the provided CBOR value (must be int +// or tstr). +func (ko *KeyOp) UnmarshalCBOR(data []byte) error { + var raw intOrStr + + if err := raw.UnmarshalCBOR(data); err != nil { + return fmt.Errorf("invalid key_ops value %w", err) + } + + if raw.IsString() { + v, err := KeyOpFromString(raw.String()) + if err != nil { + return err + } + + *ko = v + } else { + v := raw.Int() + *ko = KeyOp(v) + + if !ko.IsSupported() { + return fmt.Errorf("unknown key_ops value %d", v) + } + } + + return nil +} + +// KeyType identifies the family of keys represented by the associated Key. +// This determines which files within the Key must be set in order for it to be +// valid. +type KeyType int64 + +const ( + // Invlaid key type + KeyTypeInvalid KeyType = 0 + // Octet Key Pair + KeyTypeOKP KeyType = 1 + // Elliptic Curve Keys w/ x- and y-coordinate pair + KeyTypeEC2 KeyType = 2 + // Symmetric Keys + KeyTypeSymmetric KeyType = 4 +) + +// String returns a string representation of the KeyType. Note does not +// represent a valid value of the corresponding serialized entry, and must +// not be used as such. +func (kt KeyType) String() string { + switch kt { + case KeyTypeOKP: + return "OKP" + case KeyTypeEC2: + return "EC2" + case KeyTypeSymmetric: + return "Symmetric" + default: + return "unknown key type value " + strconv.Itoa(int(kt)) + } +} + +// MarshalCBOR marshals the KeyType as a CBOR int. +func (kt KeyType) MarshalCBOR() ([]byte, error) { + return encMode.Marshal(int(kt)) +} + +// UnmarshalCBOR populates the KeyType from the provided CBOR value (must be +// int or tstr). +func (kt *KeyType) UnmarshalCBOR(data []byte) error { + var raw intOrStr + + if err := raw.UnmarshalCBOR(data); err != nil { + return fmt.Errorf("invalid key type value: %w", err) + } + + if raw.IsString() { + v, err := keyTypeFromString(raw.String()) + + if err != nil { + return err + } + + *kt = v + } else { + v := raw.Int() + + if v == 0 { + // 0 is reserved, and so can never be valid + return fmt.Errorf("invalid key type value 0") + } + + if v > 4 || v < 0 || v == 3 { + return fmt.Errorf("unknown key type value %d", v) + } + + *kt = KeyType(v) + } + + return nil +} + +// NOTE: there are currently no registered string key type values. +func keyTypeFromString(v string) (KeyType, error) { + return KeyTypeInvalid, fmt.Errorf("unknown key type value %q", v) +} + +const ( + + // Invalid/unrecognised curve + CurveInvalid Curve = 0 + + // NIST P-256 also known as secp256r1 + CurveP256 Curve = 1 + + // NIST P-384 also known as secp384r1 + CurveP384 Curve = 2 + + // NIST P-521 also known as secp521r1 + CurveP521 Curve = 3 + + // X25519 for use w/ ECDH only + CurveX25519 Curve = 4 + + // X448 for use w/ ECDH only + CurveX448 Curve = 5 + + // Ed25519 for use /w EdDSA only + CurveEd25519 Curve = 6 + + // Ed448 for use /w EdDSA only + CurveEd448 Curve = 7 +) + +// Curve represents the EC2/OKP key's curve. See: +// https://datatracker.ietf.org/doc/html/rfc8152#section-13.1 +type Curve int64 + +// String returns a string representation of the Curve. Note does not +// represent a valid value of the corresponding serialized entry, and must +// not be used as such. +func (c Curve) String() string { + switch c { + case CurveP256: + return "P-256" + case CurveP384: + return "P-384" + case CurveP521: + return "P-521" + case CurveX25519: + return "X25519" + case CurveX448: + return "X448" + case CurveEd25519: + return "Ed25519" + case CurveEd448: + return "Ed448" + default: + return "unknown curve value " + strconv.Itoa(int(c)) + } +} + +// MarshalCBOR marshals the KeyType as a CBOR int. +func (c Curve) MarshalCBOR() ([]byte, error) { + return encMode.Marshal(int(c)) +} + +// UnmarshalCBOR populates the KeyType from the provided CBOR value (must be +// int or tstr). +func (c *Curve) UnmarshalCBOR(data []byte) error { + var raw intOrStr + + if err := raw.UnmarshalCBOR(data); err != nil { + return fmt.Errorf("invalid curve value: %w", err) + } + + if raw.IsString() { + v, err := curveFromString(raw.String()) + + if err != nil { + return err + } + + *c = v + } else { + v := raw.Int() + + if v < 1 || v > 7 { + return fmt.Errorf("unknown curve value %d", v) + } + + *c = Curve(v) + } + + return nil +} + +// NOTE: there are currently no registered string values for curves. +func curveFromString(v string) (Curve, error) { + return CurveInvalid, fmt.Errorf("unknown curve value %q", v) +} + +// Key represents a COSE_Key structure, as defined by RFC8152. +// Note: currently, this does NOT support RFC8230 (RSA algorithms). +type Key struct { + // Common parameters. These are independent of the key type. Only + // KeyType common parameter MUST be set. + + // KeyType identifies the family of keys for this structure, and thus, + // which of the key-type-specific parameters need to be set. + KeyType KeyType `cbor:"1,keyasint"` + // KeyID is the identification value matched to the kid in the message. + KeyID []byte `cbor:"2,keyasint,omitempty"` + // KeyOps can be set to restrict the set of operations that the Key is used for. + KeyOps []KeyOp `cbor:"4,keyasint,omitempty"` + // BaseIV is the Base IV to be xor-ed with Partial IVs. + BaseIV []byte `cbor:"5,keyasint,omitempty"` + + // Algorithm is used to restrict the algorithm that is used with the + // key. If it is set, the application MUST verify that it matches the + // algorithm for which the Key is being used. + Algorithm Algorithm `cbor:"-"` + // Curve is EC identifier -- taken form "COSE Elliptic Curves" IANA registry. + // Populated from keyStruct.RawKeyParam when key type is EC2 or OKP. + Curve Curve `cbor:"-"` + // K is the key value. Populated from keyStruct.RawKeyParam when key + // type is Symmetric. + K []byte `cbor:"-"` + + // EC2/OKP params + + // X is the x-coordinate + X []byte `cbor:"-2,keyasint,omitempty"` + // Y is the y-coordinate (sign bits are not supported) + Y []byte `cbor:"-3,keyasint,omitempty"` + // D is the private key + D []byte `cbor:"-4,keyasint,omitempty"` +} + +// NewOKPKey returns a Key created using the provided Octet Key Pair data. +func NewOKPKey(alg Algorithm, x, d []byte) (*Key, error) { + if alg != AlgorithmEd25519 { + return nil, fmt.Errorf("unsupported algorithm %q", alg) + } + + key := &Key{ + KeyType: KeyTypeOKP, + Algorithm: alg, + Curve: CurveEd25519, + X: x, + D: d, + } + return key, key.Validate() +} + +// NewEC2Key returns a Key created using the provided elliptic curve key +// data. +func NewEC2Key(alg Algorithm, x, y, d []byte) (*Key, error) { + var curve Curve + + switch alg { + case AlgorithmES256: + curve = CurveP256 + case AlgorithmES384: + curve = CurveP384 + case AlgorithmES512: + curve = CurveP521 + default: + return nil, fmt.Errorf("unsupported algorithm %q", alg) + } + + key := &Key{ + KeyType: KeyTypeEC2, + Algorithm: alg, + Curve: curve, + X: x, + Y: y, + D: d, + } + return key, key.Validate() +} + +// NewSymmetricKey returns a Key created using the provided Symmetric key +// bytes. +func NewSymmetricKey(k []byte) (*Key, error) { + key := &Key{ + KeyType: KeyTypeSymmetric, + K: k, + } + return key, key.Validate() +} + +// NewKeyFromPublic returns a Key created using the provided crypto.PublicKey +// and Algorithm. +func NewKeyFromPublic(alg Algorithm, pub crypto.PublicKey) (*Key, error) { + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + vk, ok := pub.(*ecdsa.PublicKey) + if !ok { + return nil, fmt.Errorf("%v: %w", alg, ErrInvalidPubKey) + } + + return NewEC2Key(alg, vk.X.Bytes(), vk.Y.Bytes(), nil) + case AlgorithmEd25519: + vk, ok := pub.(ed25519.PublicKey) + if !ok { + return nil, fmt.Errorf("%v: %w", alg, ErrInvalidPubKey) + } + + return NewOKPKey(alg, []byte(vk), nil) + default: + return nil, ErrAlgorithmNotSupported + } +} + +// NewKeyFromPrivate returns a Key created using provided crypto.PrivateKey +// and Algorithm. +func NewKeyFromPrivate(alg Algorithm, priv crypto.PrivateKey) (*Key, error) { + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + sk, ok := priv.(*ecdsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("%v: %w", alg, ErrInvalidPrivKey) + } + + return NewEC2Key(alg, sk.X.Bytes(), sk.Y.Bytes(), sk.D.Bytes()) + case AlgorithmEd25519: + sk, ok := priv.(ed25519.PrivateKey) + if !ok { + return nil, fmt.Errorf("%v: %w", alg, ErrInvalidPrivKey) + } + return NewOKPKey(alg, []byte(sk[32:]), []byte(sk[:32])) + default: + return nil, ErrAlgorithmNotSupported + } +} + +// Validate ensures that the parameters set inside the Key are internally +// consistent (e.g., that the key type is appropriate to the curve.) +func (k Key) Validate() error { + switch k.KeyType { + case KeyTypeEC2: + switch k.Curve { + case CurveP256, CurveP384, CurveP521: + // ok + default: + return fmt.Errorf( + "EC2 curve must be P-256, P-384, or P-521; found %q", + k.Curve.String(), + ) + } + case KeyTypeOKP: + switch k.Curve { + case CurveX25519, CurveX448, CurveEd25519, CurveEd448: + // ok + default: + return fmt.Errorf( + "OKP curve must be X25519, X448, Ed25519, or Ed448; found %q", + k.Curve.String(), + ) + } + case KeyTypeSymmetric: + default: + return errors.New(k.KeyType.String()) + } + + // If Algorithm is set, it must match the specified key parameters. + if k.Algorithm != AlgorithmInvalid { + expectedAlg, err := k.deriveAlgorithm() + if err != nil { + return err + } + + if k.Algorithm != expectedAlg { + return fmt.Errorf( + "found algorithm %q (expected %q)", + k.Algorithm.String(), + expectedAlg.String(), + ) + } + } + + return nil +} + +type keyalias Key + +type marshaledKey struct { + keyalias + + // RawAlgorithm contains the raw Algorithm value, this is necessary + // because cbor library ignores omitempty on types that implement the + // cbor.Marshaler interface. + RawAlgorithm cbor.RawMessage `cbor:"3,keyasint,omitempty"` + + // RawKeyParam contains the raw CBOR encoded data for the label -1. + // Depending on the KeyType this is used to populate either Curve or K + // below. + RawKeyParam cbor.RawMessage `cbor:"-1,keyasint,omitempty"` +} + +// MarshalCBOR encodes Key into a COSE_Key object. +func (k *Key) MarshalCBOR() ([]byte, error) { + tmp := marshaledKey{ + keyalias: keyalias(*k), + } + var err error + + switch k.KeyType { + case KeyTypeSymmetric: + if tmp.RawKeyParam, err = encMode.Marshal(k.K); err != nil { + return nil, err + } + case KeyTypeEC2, KeyTypeOKP: + if tmp.RawKeyParam, err = encMode.Marshal(k.Curve); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("invalid key type: %q", k.KeyType.String()) + } + + if k.Algorithm != AlgorithmInvalid { + if tmp.RawAlgorithm, err = encMode.Marshal(k.Algorithm); err != nil { + return nil, err + } + } + + return encMode.Marshal(tmp) +} + +// UnmarshalCBOR decodes a COSE_Key object into Key. +func (k *Key) UnmarshalCBOR(data []byte) error { + var tmp marshaledKey + + if err := decMode.Unmarshal(data, &tmp); err != nil { + return err + } + *k = Key(tmp.keyalias) + + if tmp.RawAlgorithm != nil { + if err := decMode.Unmarshal(tmp.RawAlgorithm, &k.Algorithm); err != nil { + return err + } + } + + switch k.KeyType { + case KeyTypeEC2: + if tmp.RawKeyParam == nil { + return errors.New("missing Curve parameter (required for EC2 key type)") + } + + if err := decMode.Unmarshal(tmp.RawKeyParam, &k.Curve); err != nil { + return err + } + case KeyTypeOKP: + if tmp.RawKeyParam == nil { + return errors.New("missing Curve parameter (required for OKP key type)") + } + + if err := decMode.Unmarshal(tmp.RawKeyParam, &k.Curve); err != nil { + return err + } + case KeyTypeSymmetric: + if tmp.RawKeyParam == nil { + return errors.New("missing K parameter (required for Symmetric key type)") + } + + if err := decMode.Unmarshal(tmp.RawKeyParam, &k.K); err != nil { + return err + } + default: + // this should not be reachable as KeyType.UnmarshalCBOR would + // result in an error during decMode.Unmarshal() above, if the + // value in the data doesn't correspond to one of the above + // types. + return fmt.Errorf("unexpected key type %q", k.KeyType.String()) + } + + return k.Validate() +} + +// PublicKey returns a crypto.PublicKey generated using Key's parameters. +func (k *Key) PublicKey() (crypto.PublicKey, error) { + alg, err := k.deriveAlgorithm() + if err != nil { + return nil, err + } + + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + var curve elliptic.Curve + + switch alg { + case AlgorithmES256: + curve = elliptic.P256() + case AlgorithmES384: + curve = elliptic.P384() + case AlgorithmES512: + curve = elliptic.P521() + } + + pub := &ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)} + pub.X.SetBytes(k.X) + pub.Y.SetBytes(k.Y) + + return pub, nil + case AlgorithmEd25519: + return ed25519.PublicKey(k.X), nil + default: + return nil, ErrAlgorithmNotSupported + } +} + +// PrivateKey returns a crypto.PrivateKey generated using Key's parameters. +func (k *Key) PrivateKey() (crypto.PrivateKey, error) { + alg, err := k.deriveAlgorithm() + if err != nil { + return nil, err + } + + if len(k.D) == 0 { + return nil, ErrNotPrivKey + } + + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + var curve elliptic.Curve + + switch alg { + case AlgorithmES256: + curve = elliptic.P256() + case AlgorithmES384: + curve = elliptic.P384() + case AlgorithmES512: + curve = elliptic.P521() + } + + priv := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)}, + D: new(big.Int), + } + priv.X.SetBytes(k.X) + priv.Y.SetBytes(k.Y) + priv.D.SetBytes(k.D) + + return priv, nil + case AlgorithmEd25519: + buf := make([]byte, ed25519.PrivateKeySize) + + copy(buf, k.D) + copy(buf[32:], k.X) + + return ed25519.PrivateKey(buf), nil + default: + return nil, ErrAlgorithmNotSupported + } +} + +// AlgorithmOrDefault returns the Algorithm associated with Key. If Key.Algorithm is +// set, that is what is returned. Otherwise, the algorithm is inferred using +// Key.Curve. This method does NOT validate that Key.Algorithm, if set, aligns +// with Key.Curve. +func (k *Key) AlgorithmOrDefault() (Algorithm, error) { + if k.Algorithm != AlgorithmInvalid { + return k.Algorithm, nil + } + + return k.deriveAlgorithm() +} + +// Signer returns a Signer created using Key. +func (k *Key) Signer() (Signer, error) { + if err := k.Validate(); err != nil { + return nil, err + } + + if k.KeyOps != nil { + signFound := false + + for _, kop := range k.KeyOps { + if kop == KeyOpSign { + signFound = true + break + } + } + + if !signFound { + return nil, ErrSignOpNotSupported + } + } + + priv, err := k.PrivateKey() + if err != nil { + return nil, err + } + + alg, err := k.AlgorithmOrDefault() + if err != nil { + return nil, err + } + + var signer crypto.Signer + var ok bool + + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + signer, ok = priv.(*ecdsa.PrivateKey) + if !ok { + return nil, ErrInvalidPrivKey + } + case AlgorithmEd25519: + signer, ok = priv.(ed25519.PrivateKey) + if !ok { + return nil, ErrInvalidPrivKey + } + default: + return nil, ErrAlgorithmNotSupported + } + + return NewSigner(alg, signer) +} + +// Verifier returns a Verifier created using Key. +func (k *Key) Verifier() (Verifier, error) { + if err := k.Validate(); err != nil { + return nil, err + } + + if k.KeyOps != nil { + verifyFound := false + + for _, kop := range k.KeyOps { + if kop == KeyOpVerify { + verifyFound = true + break + } + } + + if !verifyFound { + return nil, ErrVerifyOpNotSupported + } + } + + pub, err := k.PublicKey() + if err != nil { + return nil, err + } + + alg, err := k.AlgorithmOrDefault() + if err != nil { + return nil, err + } + + return NewVerifier(alg, pub) +} + +// deriveAlgorithm derives the intended algorithm for the key from its curve. +// The deriviation is based on the recommendation in RFC8152 that SHA-256 is +// only used with P-256, etc. For other combinations, the Algorithm in the Key +// must be explicitly set,so that this derivation is not used. +func (k *Key) deriveAlgorithm() (Algorithm, error) { + switch k.KeyType { + case KeyTypeEC2, KeyTypeOKP: + switch k.Curve { + case CurveP256: + return AlgorithmES256, nil + case CurveP384: + return AlgorithmES384, nil + case CurveP521: + return AlgorithmES512, nil + case CurveEd25519: + return AlgorithmEd25519, nil + default: + return AlgorithmInvalid, fmt.Errorf("unsupported curve %q", k.Curve.String()) + } + default: + // Symmetric algorithms are not supported in the current inmplementation. + return AlgorithmInvalid, fmt.Errorf("unexpected key type %q", k.KeyType.String()) + } +} diff --git a/key_test.go b/key_test.go new file mode 100644 index 0000000..e57db62 --- /dev/null +++ b/key_test.go @@ -0,0 +1,641 @@ +package cose + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "testing" + + "github.com/fxamacker/cbor/v2" +) + +func Test_KeyOp(t *testing.T) { + + tvs := []struct { + Name string + Value KeyOp + }{ + {"sign", KeyOpSign}, + {"verify", KeyOpVerify}, + {"encrypt", KeyOpEncrypt}, + {"decrypt", KeyOpDecrypt}, + {"wrapKey", KeyOpWrapKey}, + {"unwrapKey", KeyOpUnwrapKey}, + {"deriveKey", KeyOpDeriveKey}, + {"deriveBits", KeyOpDeriveBits}, + } + + for _, tv := range tvs { + if tv.Name != tv.Value.String() { + t.Errorf( + "String value mismatch: expected %q, got %q", + tv.Name, + tv.Value.String(), + ) + } + + data, err := cbor.Marshal(tv.Name) + if err != nil { + t.Errorf("Unexpected error: %s", err) + return + } + + var ko KeyOp + err = cbor.Unmarshal(data, &ko) + if err != nil { + t.Errorf("Unexpected error: %s", err) + return + } + if tv.Value != ko { + t.Errorf( + "Value mismatch: want %v, got %v", + tv.Value, + ko, + ) + } + + data, err = cbor.Marshal(int(tv.Value)) + if err != nil { + t.Errorf("Unexpected error: %q", err) + return + } + + err = cbor.Unmarshal(data, &ko) + if err != nil { + t.Errorf("Unexpected error: %q", err) + return + } + if tv.Value != ko { + t.Errorf( + "Value mismatch: want %v, got %v", + tv.Value, + ko, + ) + } + } + + var ko KeyOp + + data := []byte{0x20} + err := ko.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown key_ops value -1") + + data = []byte{0x18, 0xff} + err = ko.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown key_ops value 255") + + data = []byte{0x63, 0x66, 0x6f, 0x6f} + err = ko.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown key_ops value \"foo\"") + + data = []byte{0x40} + err = ko.UnmarshalCBOR(data) + assertEqualError(t, err, "invalid key_ops value must be int or string, found []uint8") + + if "MAC create" != KeyOpMACCreate.String() { + t.Errorf("Unexpected value: %q", KeyOpMACCreate.String()) + } + + if "MAC verify" != KeyOpMACVerify.String() { + t.Errorf("Unexpected value: %q", KeyOpMACVerify.String()) + } + + if "unknown key_op value 42" != KeyOp(42).String() { + t.Errorf("Unexpected value: %q", KeyOp(42).String()) + } +} + +func Test_KeyType(t *testing.T) { + var ko KeyType + + data := []byte{0x20} + err := ko.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown key type value -1") + + data = []byte{0x00} + err = ko.UnmarshalCBOR(data) + assertEqualError(t, err, "invalid key type value 0") + + data = []byte{0x03} + err = ko.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown key type value 3") + + data = []byte{0x63, 0x66, 0x6f, 0x6f} + err = ko.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown key type value \"foo\"") + + data = []byte{0x40} + err = ko.UnmarshalCBOR(data) + assertEqualError(t, err, "invalid key type value: must be int or string, found []uint8") +} + +func Test_Curve(t *testing.T) { + var c Curve + + data := []byte{0x20} + err := c.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown curve value -1") + + data = []byte{0x00} + err = c.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown curve value 0") + + data = []byte{0x63, 0x66, 0x6f, 0x6f} + err = c.UnmarshalCBOR(data) + assertEqualError(t, err, "unknown curve value \"foo\"") + + data = []byte{0x40} + err = c.UnmarshalCBOR(data) + assertEqualError(t, err, "invalid curve value: must be int or string, found []uint8") + + if "unknown curve value 42" != Curve(42).String() { + t.Errorf("Unexpected string value %q", Curve(42).String()) + } +} + +func Test_Key_UnmarshalCBOR(t *testing.T) { + tvs := []struct { + Name string + Value []byte + WantErr string + Validate func(k *Key) + }{ + { + Name: "ok OKP", + Value: []byte{ + 0xa5, // map (5) + 0x01, 0x01, // kty: OKP + 0x03, 0x27, // alg: EdDSA w/ Ed25519 + 0x04, // key ops + 0x81, // array (1) + 0x02, // verify + 0x20, 0x06, // curve: Ed25519 + 0x21, 0x58, 0x20, // x-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + WantErr: "", + Validate: func(k *Key) { + assertEqual(t, KeyTypeOKP, k.KeyType) + assertEqual(t, AlgorithmEd25519, k.Algorithm) + assertEqual(t, CurveEd25519, k.Curve) + assertEqual(t, []KeyOp{KeyOpVerify}, k.KeyOps) + assertEqual(t, []byte{ + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + k.X, + ) + assertEqual(t, []byte(nil), k.K) + }, + }, + { + Name: "invalid key type", + Value: []byte{ + 0xa1, // map (2) + 0x01, 0x00, // kty: invalid + }, + WantErr: "invalid key type value 0", + Validate: nil, + }, + { + Name: "missing curve OKP", + Value: []byte{ + 0xa1, // map (2) + 0x01, 0x01, // kty: OKP + }, + WantErr: "missing Curve parameter (required for OKP key type)", + Validate: nil, + }, + { + Name: "missing curve EC2", + Value: []byte{ + 0xa1, // map (2) + 0x01, 0x02, // kty: EC2 + }, + WantErr: "missing Curve parameter (required for EC2 key type)", + Validate: nil, + }, + { + Name: "invalid curve OKP", + Value: []byte{ + 0xa2, // map (2) + 0x01, 0x01, // kty: OKP + 0x20, 0x01, // curve: CurveP256 + }, + WantErr: "OKP curve must be X25519, X448, Ed25519, or Ed448; found \"P-256\"", + Validate: nil, + }, + { + Name: "invalid curve EC2", + Value: []byte{ + 0xa2, // map (2) + 0x01, 0x02, // kty: EC2 + 0x20, 0x06, // curve: CurveEd25519 + }, + WantErr: "EC2 curve must be P-256, P-384, or P-521; found \"Ed25519\"", + Validate: nil, + }, + { + Name: "ok Symmetric", + Value: []byte{ + 0xa4, // map (4) + 0x01, 0x04, // kty: Symmetric + 0x03, 0x38, 0x24, // alg: PS256 + 0x04, // key ops + 0x81, // array (1) + 0x02, // verify + 0x20, 0x58, 0x20, // k: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + WantErr: "", + Validate: func(k *Key) { + assertEqual(t, KeyTypeSymmetric, k.KeyType) + assertEqual(t, AlgorithmPS256, k.Algorithm) + assertEqual(t, int64(0), int64(k.Curve)) + assertEqual(t, []KeyOp{KeyOpVerify}, k.KeyOps) + assertEqual(t, []byte{ + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + k.K, + ) + }, + }, + { + Name: "missing K", + Value: []byte{ + 0xa1, // map (1) + 0x01, 0x04, // kty: Symmetric + }, + WantErr: "missing K parameter (required for Symmetric key type)", + Validate: nil, + }, + { + Name: "wrong algorithm", + Value: []byte{ + 0xa4, // map (3) + 0x01, 0x01, // kty: OKP + 0x03, 0x26, // alg: ECDSA w/ SHA-256 + 0x20, 0x06, // curve: Ed25519 + 0x21, 0x58, 0x20, // x-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + WantErr: "found algorithm \"ES256\" (expected \"EdDSA\")", + Validate: nil, + }, + } + + for _, tv := range tvs { + t.Run(tv.Name, func(t *testing.T) { + var k Key + + err := k.UnmarshalCBOR(tv.Value) + if tv.WantErr != "" { + if err == nil || err.Error() != tv.WantErr { + t.Errorf("Unexpected error: want %q, got %q", tv.WantErr, err) + } + } else { + tv.Validate(&k) + } + }) + } +} + +func Test_Key_MarshalCBOR(t *testing.T) { + k := Key{ + KeyType: KeyTypeOKP, + KeyOps: []KeyOp{KeyOpVerify, KeyOpEncrypt}, + X: []byte{ + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + Algorithm: AlgorithmEd25519, + Curve: CurveEd25519, + } + + data, err := k.MarshalCBOR() + if err != nil { + t.Errorf("Unexpected error: %s", err) + return + } + expected := []byte{ + 0xa5, // map (5) + 0x01, 0x01, // kty: OKP + 0x03, 0x27, // alg: EdDSA w/ Ed25519 + 0x04, // key ops + 0x82, // array (2) + 0x02, 0x03, // verify, encrypt + 0x20, 0x06, // curve: Ed25519 + 0x21, 0x58, 0x20, // x-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + } + if !bytes.Equal(expected, data) { + t.Errorf("Bad marshal: %v", data) + } + + k = Key{ + KeyType: KeyTypeSymmetric, + K: []byte{ + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + } + + data, err = k.MarshalCBOR() + if err != nil { + t.Errorf("Unexpected error: %s", err) + return + } + expected = []byte{ + 0xa2, // map (2) + 0x01, 0x04, // kty: Symmetric + 0x20, 0x58, 0x20, // K: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + } + if !bytes.Equal(expected, data) { + t.Errorf("Bad marshal: %v", data) + } + + k.KeyType = KeyType(42) + _, err = k.MarshalCBOR() + wantErr := "invalid key type: \"unknown key type value 42\"" + if err == nil || err.Error() != wantErr { + t.Errorf("Unexpected error: want %q, got %q", wantErr, err) + } +} + +func Test_Key_Create_and_Validate(t *testing.T) { + x := []byte{ + 0x30, 0xa0, 0x42, 0x4c, 0xd2, 0x1c, 0x29, 0x44, + 0x83, 0x8a, 0x2d, 0x75, 0xc9, 0x2b, 0x37, 0xe7, + 0x6e, 0xa2, 0x0d, 0x9f, 0x00, 0x89, 0x3a, 0x3b, + 0x4e, 0xee, 0x8a, 0x3c, 0x0a, 0xaf, 0xec, 0x3e, + } + + y := []byte{ + 0xe0, 0x4b, 0x65, 0xe9, 0x24, 0x56, 0xd9, 0x88, + 0x8b, 0x52, 0xb3, 0x79, 0xbd, 0xfb, 0xd5, 0x1e, + 0xe8, 0x69, 0xef, 0x1f, 0x0f, 0xc6, 0x5b, 0x66, + 0x59, 0x69, 0x5b, 0x6c, 0xce, 0x08, 0x17, 0x23, + } + + key, err := NewOKPKey(AlgorithmEd25519, x, nil) + requireNoError(t, err) + assertEqual(t, KeyTypeOKP, key.KeyType) + assertEqual(t, x, key.X) + + _, err = NewOKPKey(AlgorithmES256, x, nil) + assertEqualError(t, err, "unsupported algorithm \"ES256\"") + + _, err = NewEC2Key(AlgorithmEd25519, x, y, nil) + assertEqualError(t, err, "unsupported algorithm \"EdDSA\"") + + key, err = NewEC2Key(AlgorithmES256, x, y, nil) + requireNoError(t, err) + assertEqual(t, KeyTypeEC2, key.KeyType) + assertEqual(t, x, key.X) + assertEqual(t, y, key.Y) + + key, err = NewSymmetricKey(x) + requireNoError(t, err) + assertEqual(t, x, key.K) + + key.KeyType = KeyType(7) + err = key.Validate() + assertEqualError(t, err, "unknown key type value 7") + + _, err = NewKeyFromPublic(AlgorithmES256, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assertEqualError(t, err, "ES256: invalid public key") + + _, err = NewKeyFromPublic(AlgorithmEd25519, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assertEqualError(t, err, "EdDSA: invalid public key") + + _, err = NewKeyFromPublic(AlgorithmInvalid, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assertEqualError(t, err, "algorithm not supported") + + _, err = NewKeyFromPrivate(AlgorithmES256, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assertEqualError(t, err, "ES256: invalid private key") + + _, err = NewKeyFromPrivate(AlgorithmEd25519, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assertEqualError(t, err, "EdDSA: invalid private key") + + _, err = NewKeyFromPrivate(AlgorithmInvalid, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assertEqualError(t, err, "algorithm not supported") +} + +func Test_Key_ed25519_signature_round_trip(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + requireNoError(t, err) + + key, err := NewKeyFromPrivate(AlgorithmEd25519, priv) + requireNoError(t, err) + assertEqual(t, AlgorithmEd25519, key.Algorithm) + assertEqual(t, CurveEd25519, key.Curve) + assertEqual(t, pub, key.X) + assertEqual(t, priv[:32], key.D) + + signer, err := key.Signer() + requireNoError(t, err) + + message := []byte("foo bar") + sig, err := signer.Sign(rand.Reader, message) + requireNoError(t, err) + + key, err = NewKeyFromPublic(AlgorithmEd25519, pub) + requireNoError(t, err) + + assertEqual(t, AlgorithmEd25519, key.Algorithm) + assertEqual(t, CurveEd25519, key.Curve) + assertEqual(t, pub, key.X) + + verifier, err := key.Verifier() + requireNoError(t, err) + + err = verifier.Verify(message, sig) + requireNoError(t, err) +} + +func Test_Key_ecdsa_signature_round_trip(t *testing.T) { + for _, tv := range []struct { + EC elliptic.Curve + Curve Curve + Algorithm Algorithm + }{ + {elliptic.P256(), CurveP256, AlgorithmES256}, + {elliptic.P384(), CurveP384, AlgorithmES384}, + {elliptic.P521(), CurveP521, AlgorithmES512}, + } { + t.Run(tv.Curve.String(), func(t *testing.T) { + priv, err := ecdsa.GenerateKey(tv.EC, rand.Reader) + requireNoError(t, err) + + key, err := NewKeyFromPrivate(tv.Algorithm, priv) + requireNoError(t, err) + assertEqual(t, tv.Algorithm, key.Algorithm) + assertEqual(t, tv.Curve, key.Curve) + assertEqual(t, priv.X.Bytes(), key.X) + assertEqual(t, priv.Y.Bytes(), key.Y) + assertEqual(t, priv.D.Bytes(), key.D) + + signer, err := key.Signer() + requireNoError(t, err) + + message := []byte("foo bar") + sig, err := signer.Sign(rand.Reader, message) + requireNoError(t, err) + + pub := priv.Public() + + key, err = NewKeyFromPublic(tv.Algorithm, pub) + requireNoError(t, err) + + assertEqual(t, tv.Algorithm, key.Algorithm) + assertEqual(t, tv.Curve, key.Curve) + assertEqual(t, priv.X.Bytes(), key.X) + assertEqual(t, priv.Y.Bytes(), key.Y) + + verifier, err := key.Verifier() + requireNoError(t, err) + + err = verifier.Verify(message, sig) + requireNoError(t, err) + }) + } +} + +func Test_Key_derive_algorithm(t *testing.T) { + k := Key{ + KeyType: KeyTypeOKP, + Curve: CurveX448, + } + + _, err := k.AlgorithmOrDefault() + assertEqualError(t, err, "unsupported curve \"X448\"") + + k = Key{ + KeyType: KeyTypeOKP, + Curve: CurveEd25519, + } + + alg, err := k.AlgorithmOrDefault() + requireNoError(t, err) + assertEqual(t, AlgorithmEd25519, alg) +} + +func Test_Key_signer_validation(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + requireNoError(t, err) + + key, err := NewKeyFromPublic(AlgorithmEd25519, pub) + requireNoError(t, err) + + _, err = key.Signer() + assertEqualError(t, err, ErrNotPrivKey.Error()) + + key, err = NewKeyFromPrivate(AlgorithmEd25519, priv) + requireNoError(t, err) + + key.KeyType = KeyTypeEC2 + _, err = key.Signer() + assertEqualError(t, err, "EC2 curve must be P-256, P-384, or P-521; found \"Ed25519\"") + + key.Curve = CurveP256 + _, err = key.Signer() + assertEqualError(t, err, "found algorithm \"EdDSA\" (expected \"ES256\")") + + key.KeyType = KeyTypeOKP + key.Algorithm = AlgorithmEd25519 + key.Curve = CurveEd25519 + key.KeyOps = []KeyOp{} + _, err = key.Signer() + assertEqualError(t, err, ErrSignOpNotSupported.Error()) + + key.KeyOps = []KeyOp{KeyOpSign} + _, err = key.Signer() + requireNoError(t, err) + + key.Algorithm = AlgorithmES256 + _, err = key.Signer() + assertEqualError(t, err, "found algorithm \"ES256\" (expected \"EdDSA\")") + + key.Curve = CurveX448 + _, err = key.Signer() + assertEqualError(t, err, "unsupported curve \"X448\"") +} + +func Test_Key_verifier_validation(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + requireNoError(t, err) + + key, err := NewKeyFromPublic(AlgorithmEd25519, pub) + requireNoError(t, err) + + _, err = key.Verifier() + requireNoError(t, err) + + key.KeyType = KeyTypeEC2 + _, err = key.Verifier() + assertEqualError(t, err, "EC2 curve must be P-256, P-384, or P-521; found \"Ed25519\"") + + key.KeyType = KeyTypeOKP + key.KeyOps = []KeyOp{} + _, err = key.Verifier() + assertEqualError(t, err, ErrVerifyOpNotSupported.Error()) + + key.KeyOps = []KeyOp{KeyOpVerify} + _, err = key.Verifier() + requireNoError(t, err) +} + +func Test_Key_crypto_keys(t *testing.T) { + k := Key{ + KeyType: KeyType(7), + } + + _, err := k.PublicKey() + assertEqualError(t, err, "unexpected key type \"unknown key type value 7\"") + _, err = k.PrivateKey() + assertEqualError(t, err, "unexpected key type \"unknown key type value 7\"") + + k = Key{ + KeyType: KeyTypeOKP, + Curve: CurveX448, + } + + _, err = k.PublicKey() + assertEqualError(t, err, "unsupported curve \"X448\"") + _, err = k.PrivateKey() + assertEqualError(t, err, "unsupported curve \"X448\"") +}