From 1195e9a7b9c3fe9a4cf7aa92da08cf68d251f34e Mon Sep 17 00:00:00 2001 From: Yogesh Deshpande Date: Tue, 21 Feb 2023 06:50:50 -0500 Subject: [PATCH 1/2] Enhance individual UT in go-cose library to handle specific errors Signed-off-by: Yogesh Deshpande --- cbor_test.go | 17 ++-- ecdsa_test.go | 15 ++-- headers_test.go | 164 +++++++++++++++++++++--------------- sign1_test.go | 105 +++++++++++++---------- sign_test.go | 220 +++++++++++++++++++++++++++--------------------- signer_test.go | 17 ++-- 6 files changed, 305 insertions(+), 233 deletions(-) diff --git a/cbor_test.go b/cbor_test.go index 2fdd6b9..b96b13b 100644 --- a/cbor_test.go +++ b/cbor_test.go @@ -20,7 +20,7 @@ func Test_byteString_UnmarshalCBOR(t *testing.T) { name string data []byte want byteString - wantErr bool + wantErr string }{ { name: "valid string", @@ -40,33 +40,36 @@ func Test_byteString_UnmarshalCBOR(t *testing.T) { { name: "undefined string", data: []byte{0xf7}, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "nil CBOR data", data: nil, - wantErr: true, + wantErr: "EOF", }, { name: "empty CBOR data", data: []byte{}, - wantErr: true, + wantErr: "EOF", }, { name: "tagged string", data: []byte{0xc2, 0x40}, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "array of bytes", // issue #46 data: []byte{0x82, 0x00, 0x1}, - wantErr: true, + wantErr: "cbor: require bstr type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got byteString - if err := got.UnmarshalCBOR(tt.data); (err != nil) != tt.wantErr { + err := got.UnmarshalCBOR(tt.data) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("byteString.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil && (tt.wantErr != "") { t.Errorf("byteString.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) } if !bytes.Equal(got, tt.want) { diff --git a/ecdsa_test.go b/ecdsa_test.go index 4a32519..feaa673 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -20,31 +20,31 @@ func TestI2OSP(t *testing.T) { x *big.Int buf []byte want []byte - wantErr bool + wantErr string }{ { name: "negative int", x: big.NewInt(-1), buf: make([]byte, 2), - wantErr: true, + wantErr: "I2OSP: negative integer", }, { name: "integer too large #1", x: big.NewInt(1), buf: make([]byte, 0), - wantErr: true, + wantErr: "I2OSP: integer too large", }, { name: "integer too large #2", x: big.NewInt(256), buf: make([]byte, 0), - wantErr: true, + wantErr: "I2OSP: integer too large", }, { name: "integer too large #3", x: big.NewInt(1 << 24), buf: make([]byte, 3), - wantErr: true, + wantErr: "I2OSP: integer too large", }, { name: "zero length string", @@ -98,11 +98,12 @@ func TestI2OSP(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := I2OSP(tt.x, tt.buf) - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { t.Errorf("I2OSP() error = %v, wantErr %v", err, tt.wantErr) return } - if got := tt.buf; !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + + if got := tt.buf; (tt.wantErr == "") && !reflect.DeepEqual(got, tt.want) { t.Errorf("I2OSP() = %v, want %v", got, tt.want) } }) diff --git a/headers_test.go b/headers_test.go index f0f6e11..7981028 100644 --- a/headers_test.go +++ b/headers_test.go @@ -10,7 +10,7 @@ func TestProtectedHeader_MarshalCBOR(t *testing.T) { name string h ProtectedHeader want []byte - wantErr bool + wantErr string }{ { name: "valid header", @@ -78,21 +78,21 @@ func TestProtectedHeader_MarshalCBOR(t *testing.T) { value int }{}: 42, }, - wantErr: true, + wantErr: "protected header: header label: require int / tstr type", }, { name: "empty critical", h: ProtectedHeader{ HeaderLabelCritical: []interface{}{}, }, - wantErr: true, + wantErr: "protected header: header parameter: crit: empty crit header", }, { name: "invalid critical", h: ProtectedHeader{ HeaderLabelCritical: 42, }, - wantErr: true, + wantErr: "protected header: header parameter: crit: invalid crit header", }, { name: "missing header marked as critical", @@ -101,14 +101,14 @@ func TestProtectedHeader_MarshalCBOR(t *testing.T) { HeaderLabelContentType, }, }, - wantErr: true, + wantErr: "protected header: header parameter: crit: missing critical header: 3", }, { name: "critical header contains non-label element", h: ProtectedHeader{ HeaderLabelCritical: []interface{}{[]uint8{}}, }, - wantErr: true, + wantErr: "protected header: header parameter: crit: require int / tstr type, got '[]uint8': []", }, { name: "duplicated key", @@ -116,14 +116,14 @@ func TestProtectedHeader_MarshalCBOR(t *testing.T) { int8(42): "foo", int64(42): "bar", }, - wantErr: true, + wantErr: "protected header: header label: duplicated label: 42", }, { name: "un-marshalable content", h: ProtectedHeader{ "foo": make(chan bool), }, - wantErr: true, + wantErr: "cbor: unsupported type: chan bool", }, { name: "iv and partial iv present", @@ -131,41 +131,44 @@ func TestProtectedHeader_MarshalCBOR(t *testing.T) { HeaderLabelIV: []byte("foo"), HeaderLabelPartialIV: []byte("bar"), }, - wantErr: true, + wantErr: "protected header: header parameter: IV and PartialIV: parameters must not both be present", }, { name: "content type is string", h: ProtectedHeader{ HeaderLabelContentType: []byte("foo"), }, - wantErr: true, + wantErr: "protected header: header parameter: content type: require tstr / uint type", }, { name: "content type is negative int8", h: ProtectedHeader{ HeaderLabelContentType: int8(-1), }, - wantErr: true, + wantErr: "protected header: header parameter: content type: require tstr / uint type", }, { name: "content type is negative int16", h: ProtectedHeader{ HeaderLabelContentType: int16(-1), }, - wantErr: true, + wantErr: "protected header: header parameter: content type: require tstr / uint type", }, { name: "content type is negative int32", h: ProtectedHeader{ HeaderLabelContentType: int32(-1), }, - wantErr: true, + wantErr: "protected header: header parameter: content type: require tstr / uint type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.h.MarshalCBOR() - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("ProtectedHeader.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && tt.wantErr != "" { t.Errorf("ProtectedHeader.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -181,7 +184,7 @@ func TestProtectedHeader_UnmarshalCBOR(t *testing.T) { name string data []byte want ProtectedHeader - wantErr bool + wantErr string }{ { name: "valid header", @@ -216,117 +219,121 @@ func TestProtectedHeader_UnmarshalCBOR(t *testing.T) { { name: "nil CBOR data", data: nil, - wantErr: true, + wantErr: "EOF", }, { name: "empty CBOR data", data: []byte{}, - wantErr: true, + wantErr: "EOF", }, { name: "bad CBOR data", data: []byte{0x00, 0x01, 0x02, 0x04}, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "nil bstr", data: []byte{0xf6}, - wantErr: true, + wantErr: "cbor: nil protected header", }, { name: "non-map header", data: []byte{0x41, 0x00}, - wantErr: true, + wantErr: "cbor: protected header: require map type", }, { name: "invalid header label type: bstr type", data: []byte{ 0x43, 0xa1, 0x40, 0x00, }, - wantErr: true, + wantErr: "cbor: header label: require int / tstr type", }, { name: "invalid header label type: major type 7: simple value", // issue #38 data: []byte{ 0x43, 0xa1, 0xf3, 0x00, }, - wantErr: true, + wantErr: "cbor: header label: require int / tstr type", }, { name: "empty critical", data: []byte{ 0x43, 0xa1, 0x02, 0x80, }, - wantErr: true, + wantErr: "protected header: header parameter: crit: empty crit header", }, { name: "invalid critical", data: []byte{ 0x43, 0xa1, 0x02, 0x00, }, - wantErr: true, + wantErr: "protected header: header parameter: crit: invalid crit header", }, { name: "missing header marked as critical", data: []byte{ 0x44, 0xa1, 0x02, 0x81, 0x03, }, - wantErr: true, + wantErr: "protected header: header parameter: crit: missing critical header: 3", }, { name: "critical header contains non-label element", data: []byte{ 0x44, 0xa1, 0x2, 0x81, 0x40, }, - wantErr: true, + wantErr: "protected header: header parameter: crit: require int / tstr type, got '[]uint8': []", }, { name: "duplicated key", data: []byte{ 0x45, 0xa2, 0x01, 0x00, 0x01, 0x00, }, - wantErr: true, + wantErr: "cbor: found duplicate map key \"1\" at map element index 1", }, { name: "incomplete CBOR data", data: []byte{ 0x45, }, - wantErr: true, + wantErr: "unexpected EOF", }, { name: "invalid map value", data: []byte{ 0x46, 0xa1, 0x00, 0xa1, 0x00, 0x4f, 0x01, }, - wantErr: true, + wantErr: "unexpected EOF", }, { name: "int map key too large", data: []byte{ 0x4b, 0xa1, 0x3b, 0x83, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, }, - wantErr: true, + wantErr: "cbor: header label: int key must not be higher than 1<<63 - 1", }, { name: "header as a byte array", data: []byte{ 0x80, }, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "iv and partial iv present", data: []byte{ 0x4b, 0xa2, 0x5, 0x63, 0x66, 0x6f, 0x6f, 0x6, 0x63, 0x62, 0x61, 0x72, }, - wantErr: true, + wantErr: "protected header: header parameter: IV: require bstr type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got ProtectedHeader - if err := got.UnmarshalCBOR(tt.data); (err != nil) != tt.wantErr { + err := got.UnmarshalCBOR(tt.data) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("ProtectedHeader.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && tt.wantErr != "" { t.Errorf("ProtectedHeader.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -430,7 +437,7 @@ func TestProtectedHeader_Critical(t *testing.T) { name string h ProtectedHeader want []interface{} - wantErr bool + wantErr string }{ { name: "valid header", @@ -470,20 +477,23 @@ func TestProtectedHeader_Critical(t *testing.T) { h: ProtectedHeader{ HeaderLabelCritical: []interface{}{}, }, - wantErr: true, + wantErr: "empty crit header", }, { name: "invalid critical", h: ProtectedHeader{ HeaderLabelCritical: 42, }, - wantErr: true, + wantErr: "invalid crit header", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.h.Critical() - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("ProtectedHeader.Critical() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && tt.wantErr != "" { t.Errorf("ProtectedHeader.Critical() error = %v, wantErr %v", err, tt.wantErr) return } @@ -499,7 +509,7 @@ func TestUnprotectedHeader_MarshalCBOR(t *testing.T) { name string h UnprotectedHeader want []byte - wantErr bool + wantErr string }{ { name: "valid header", @@ -557,7 +567,7 @@ func TestUnprotectedHeader_MarshalCBOR(t *testing.T) { value int }{}: 42, }, - wantErr: true, + wantErr: "unprotected header: header label: require int / tstr type", }, { name: "duplicated key", @@ -565,14 +575,14 @@ func TestUnprotectedHeader_MarshalCBOR(t *testing.T) { int8(42): "foo", int64(42): "bar", }, - wantErr: true, + wantErr: "unprotected header: header label: duplicated label: 42", }, { name: "un-marshalable content", h: UnprotectedHeader{ "foo": make(chan bool), }, - wantErr: true, + wantErr: "cbor: unsupported type: chan bool", }, { name: "iv and partial iv present", @@ -580,20 +590,23 @@ func TestUnprotectedHeader_MarshalCBOR(t *testing.T) { HeaderLabelIV: []byte("foo"), HeaderLabelPartialIV: []byte("bar"), }, - wantErr: true, + wantErr: "unprotected header: header parameter: IV and PartialIV: parameters must not both be present", }, { name: "critical present", h: UnprotectedHeader{ HeaderLabelCritical: []string{"foo"}, }, - wantErr: true, + wantErr: "unprotected header: header parameter: crit: not allowed", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.h.MarshalCBOR() - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("UnprotectedHeader.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && tt.wantErr != "" { t.Errorf("UnprotectedHeader.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -609,7 +622,7 @@ func TestUnprotectedHeader_UnmarshalCBOR(t *testing.T) { name string data []byte want UnprotectedHeader - wantErr bool + wantErr string }{ { name: "valid header", @@ -630,71 +643,71 @@ func TestUnprotectedHeader_UnmarshalCBOR(t *testing.T) { { name: "nil CBOR data", data: nil, - wantErr: true, + wantErr: "cbor: nil unprotected header", }, { name: "empty CBOR data", data: []byte{}, - wantErr: true, + wantErr: "cbor: unprotected header: missing type", }, { name: "bad CBOR data", data: []byte{0x00, 0x01, 0x02, 0x04}, - wantErr: true, + wantErr: "cbor: unprotected header: require map type", }, { name: "non-map header", data: []byte{0x00}, - wantErr: true, + wantErr: "cbor: unprotected header: require map type", }, { name: "invalid header label type: bstr type", data: []byte{ 0xa1, 0x40, 0x00, }, - wantErr: true, + wantErr: "cbor: header label: require int / tstr type", }, { name: "invalid header label type: major type 7: simple value", // issue #38 data: []byte{ 0xa1, 0xf3, 0x00, }, - wantErr: true, + wantErr: "cbor: header label: require int / tstr type", }, { name: "duplicated key", data: []byte{ 0xa2, 0x01, 0x00, 0x01, 0x00, }, - wantErr: true, + wantErr: "cbor: found duplicate map key \"1\" at map element index 1", }, { name: "incomplete CBOR data", data: []byte{ 0xa5, }, - wantErr: true, + wantErr: "unexpected EOF", }, { name: "invalid map value", data: []byte{ 0xa1, 0x00, 0xa1, 0x00, 0x4f, 0x01, }, - wantErr: true, + wantErr: "unexpected EOF", }, { name: "int map key too large", data: []byte{ 0xa1, 0x3b, 0x83, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, }, - wantErr: true, + wantErr: "cbor: header label: int key must not be higher than 1<<63 - 1", }, { name: "iv and partial iv present", data: []byte{ 0xa2, 0x5, 0x63, 0x66, 0x6f, 0x6f, 0x6, 0x63, 0x62, 0x61, 0x72, }, - wantErr: true, + wantErr: "unprotected header: header parameter: IV: require bstr type", }, { name: "critical present", @@ -702,13 +715,17 @@ func TestUnprotectedHeader_UnmarshalCBOR(t *testing.T) { 0xa1, // map 0x02, 0x82, 0x03, 0x63, 0x66, 0x6f, 0x6f, // crit }, - wantErr: true, + wantErr: "unprotected header: header parameter: crit: not allowed", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got UnprotectedHeader - if err := got.UnmarshalCBOR(tt.data); (err != nil) != tt.wantErr { + err := got.UnmarshalCBOR(tt.data) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("UnprotectedHeader.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && tt.wantErr != "" { t.Errorf("UnprotectedHeader.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -724,7 +741,7 @@ func TestHeaders_MarshalProtected(t *testing.T) { name string h Headers want []byte - wantErr bool + wantErr string }{ { name: "pre-marshaled protected header", @@ -771,13 +788,16 @@ func TestHeaders_MarshalProtected(t *testing.T) { HeaderLabelKeyID: 42, }, }, - wantErr: true, + wantErr: "protected header: header parameter: alg: require int / tstr type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.h.MarshalProtected() - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("Headers.MarshalProtected() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && tt.wantErr != "" { t.Errorf("Headers.MarshalProtected() error = %v, wantErr %v", err, tt.wantErr) return } @@ -793,7 +813,7 @@ func TestHeaders_MarshalUnprotected(t *testing.T) { name string h Headers want []byte - wantErr bool + wantErr string }{ { name: "pre-marshaled protected header", @@ -843,13 +863,16 @@ func TestHeaders_MarshalUnprotected(t *testing.T) { HeaderLabelKeyID: make(chan bool), }, }, - wantErr: true, + wantErr: "unprotected header: header parameter: kid: require bstr type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.h.MarshalUnprotected() - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("Headers.MarshalUnprotected() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && tt.wantErr != "" { t.Errorf("Headers.MarshalUnprotected() error = %v, wantErr %v", err, tt.wantErr) return } @@ -865,21 +888,21 @@ func TestHeaders_UnmarshalFromRaw(t *testing.T) { name string h Headers want Headers - wantErr bool + wantErr string }{ { name: "nil raw protected header", h: Headers{ RawUnprotected: []byte{0xa1, 0x04, 0x18, 0x2a}, }, - wantErr: true, + wantErr: "cbor: invalid protected header: EOF", }, { name: "nil raw unprotected header", h: Headers{ RawProtected: []byte{0x43, 0xa1, 0x01, 0x26}, }, - wantErr: true, + wantErr: "cbor: invalid unprotected header: EOF", }, { name: "valid raw header", @@ -925,8 +948,13 @@ func TestHeaders_UnmarshalFromRaw(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.h - if err := got.UnmarshalFromRaw(); (err != nil) != tt.wantErr { + err := got.UnmarshalFromRaw() + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("Headers.UnmarshalFromRaw() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && tt.wantErr != "" { t.Errorf("Headers.UnmarshalFromRaw() error = %v, wantErr %v", err, tt.wantErr) + return } }) } diff --git a/sign1_test.go b/sign1_test.go index 5a45fcf..b713c54 100644 --- a/sign1_test.go +++ b/sign1_test.go @@ -12,7 +12,7 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { name string m *Sign1Message want []byte - wantErr bool + wantErr string }{ { name: "valid message", @@ -40,7 +40,7 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { { name: "nil message", m: nil, - wantErr: true, + wantErr: "cbor: MarshalCBOR on nil Sign1Message pointer", }, { name: "nil payload", @@ -79,7 +79,7 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { Payload: []byte("foo"), Signature: nil, }, - wantErr: true, + wantErr: "empty signature", }, { name: "empty signature", @@ -95,7 +95,7 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { Payload: nil, Signature: []byte{}, }, - wantErr: true, + wantErr: "empty signature", }, { name: "invalid protected header", @@ -111,7 +111,7 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { Payload: []byte("foo"), Signature: []byte("bar"), }, - wantErr: true, + wantErr: "protected header: header parameter: alg: require int / tstr type", }, { name: "invalid unprotected header", @@ -127,7 +127,7 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { Payload: []byte("foo"), Signature: []byte("bar"), }, - wantErr: true, + wantErr: "cbor: unsupported type: chan bool", }, { name: "protected has IV and unprotected has PartialIV error", @@ -144,7 +144,7 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { Payload: []byte("foo"), Signature: []byte("bar"), }, - wantErr: true, + wantErr: "IV (protected) and PartialIV (unprotected) parameters must not both be present", }, { name: "protected has PartialIV and unprotected has IV error", @@ -161,13 +161,17 @@ func TestSign1Message_MarshalCBOR(t *testing.T) { Payload: []byte("foo"), Signature: []byte("bar"), }, - wantErr: true, + wantErr: "IV (unprotected) and PartialIV (protected) parameters must not both be present", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.m.MarshalCBOR() - if (err != nil) != tt.wantErr { + + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("Sign1Message.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { t.Errorf("Sign1Message.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -193,7 +197,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { name string data []byte want Sign1Message - wantErr bool + wantErr string }{ { name: "valid message", @@ -248,17 +252,17 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { { name: "nil CBOR data", data: nil, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign1_Tagged object", }, { name: "empty CBOR data", data: []byte{}, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign1_Tagged object", }, { name: "invalid message with valid prefix", // issue #29 data: []byte{0xd2, 0x84, 0xf7, 0xf7, 0xf7, 0xf7}, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "tagged signature", // issue #30 @@ -268,7 +272,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0xf6, // nil payload 0xcb, 0xa1, 0x00, // tagged signature }, - wantErr: true, + wantErr: "cbor: CBOR tag isn't allowed", }, { name: "nil signature", @@ -278,7 +282,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0xf6, // payload 0xf6, // nil signature }, - wantErr: true, + wantErr: "empty signature", }, { name: "empty signature", @@ -288,7 +292,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0xf6, // payload 0x40, // empty signature }, - wantErr: true, + wantErr: "empty signature", }, { name: "mismatch tag", @@ -298,14 +302,14 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0xf6, // payload 0x41, 0x00, // signature }, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign1_Tagged object", }, { name: "mismatch type", data: []byte{ 0xd2, 0x40, }, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign1_Tagged object", }, { name: "smaller array size", @@ -314,7 +318,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0xf6, // payload }, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign1_Tagged object", }, { name: "larger array size", @@ -325,7 +329,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0x41, 0x00, // signature 0x40, }, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign1_Tagged object", }, { name: "undefined payload", @@ -335,7 +339,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0xf7, // undefined payload 0x41, 0x00, // signature }, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "payload as a byte array", @@ -345,7 +349,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0x80, // payload 0x41, 0x00, // signature }, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "signature as a byte array", @@ -355,7 +359,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0xf6, // nil payload 0x81, 0x00, // signature }, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "protected has IV and unprotected has PartialIV", @@ -367,7 +371,7 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0xf6, // payload 0x43, 0x62, 0x61, 0x72, // signature }, - wantErr: true, + wantErr: "cbor: invalid protected header: protected header: header parameter: IV: require bstr type", }, { name: "protected has PartialIV and unprotected has IV", @@ -379,13 +383,17 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) { 0xf6, // payload 0x43, 0x62, 0x61, 0x72, // signature }, - wantErr: true, + wantErr: "cbor: invalid protected header: protected header: header parameter: Partial IV: require bstr type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got Sign1Message - if err := got.UnmarshalCBOR(tt.data); (err != nil) != tt.wantErr { + err := got.UnmarshalCBOR(tt.data) + if (err != nil) && (err.Error() != tt.wantErr) { + t.Errorf("Sign1Message.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { t.Errorf("Sign1Message.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -416,7 +424,7 @@ func TestSign1Message_Sign(t *testing.T) { msg *Sign1Message externalOnSign []byte externalOnVerify []byte - wantErr bool + wantErr string check func(t *testing.T, m *Sign1Message) }{ { @@ -493,7 +501,7 @@ func TestSign1Message_Sign(t *testing.T) { }, Payload: nil, }, - wantErr: true, + wantErr: "missing payload", }, { name: "mismatch algorithm", @@ -505,7 +513,7 @@ func TestSign1Message_Sign(t *testing.T) { }, Payload: []byte("hello world"), }, - wantErr: true, + wantErr: "algorithm mismatch: signer ES256: header ES512", }, { name: "missing algorithm", @@ -530,7 +538,7 @@ func TestSign1Message_Sign(t *testing.T) { }, Payload: []byte("hello world"), }, - wantErr: true, + wantErr: "algorithm not found", }, { name: "missing algorithm with externally supplied data", @@ -552,22 +560,25 @@ func TestSign1Message_Sign(t *testing.T) { Payload: []byte("hello world"), Signature: []byte("foobar"), }, - wantErr: true, + wantErr: "Sign1Message signature already has signature bytes", }, { name: "nil message", msg: nil, - wantErr: true, + wantErr: "signing nil Sign1Message", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.msg.Sign(rand.Reader, tt.externalOnSign, signer) - if (err != nil) != tt.wantErr { - t.Errorf("Sign1Message.Sign() error = %v, wantErr %v", err, tt.wantErr) - return - } + if err != nil { + if err.Error() != tt.wantErr { + t.Errorf("Sign1Message.Sign() error = %v, wantErr %v", err, tt.wantErr) + } + return + } else if tt.wantErr != "" { + t.Errorf("Sign1Message.Sign() error = %v, wantErr %v", err, tt.wantErr) return } if tt.check != nil { @@ -710,7 +721,7 @@ func TestSign1Message_Verify(t *testing.T) { externalOnSign []byte externalOnVerify []byte tamper func(m *Sign1Message) *Sign1Message - wantErr bool + wantErr string }{ { name: "round trip on valid message", @@ -719,7 +730,7 @@ func TestSign1Message_Verify(t *testing.T) { name: "external mismatch", externalOnSign: []byte("foo"), externalOnVerify: []byte("bar"), - wantErr: true, + wantErr: "verification error", }, { name: "mixed nil / empty external", @@ -731,7 +742,7 @@ func TestSign1Message_Verify(t *testing.T) { tamper: func(m *Sign1Message) *Sign1Message { return nil }, - wantErr: true, + wantErr: "verifying nil Sign1Message", }, { name: "strip signature", @@ -739,7 +750,7 @@ func TestSign1Message_Verify(t *testing.T) { m.Signature = nil return m }, - wantErr: true, + wantErr: "empty signature", }, { name: "empty signature", @@ -747,7 +758,7 @@ func TestSign1Message_Verify(t *testing.T) { m.Signature = []byte{} return m }, - wantErr: true, + wantErr: "empty signature", }, { name: "tamper protected header", @@ -755,7 +766,7 @@ func TestSign1Message_Verify(t *testing.T) { m.Headers.Protected["foo"] = "bar" return m }, - wantErr: true, + wantErr: "verification error", }, { name: "tamper unprotected header", @@ -763,7 +774,6 @@ func TestSign1Message_Verify(t *testing.T) { m.Headers.Unprotected["foo"] = "bar" return m }, - wantErr: false, // allowed }, { name: "tamper payload", @@ -771,7 +781,7 @@ func TestSign1Message_Verify(t *testing.T) { m.Payload = []byte("foobar") return m }, - wantErr: true, + wantErr: "verification error", }, { name: "tamper signature", @@ -779,7 +789,7 @@ func TestSign1Message_Verify(t *testing.T) { m.Signature[0]++ return m }, - wantErr: true, + wantErr: "verification error", }, } for _, tt := range tests { @@ -807,7 +817,10 @@ func TestSign1Message_Verify(t *testing.T) { } // verify message - if err := msg.Verify(tt.externalOnVerify, verifier); (err != nil) != tt.wantErr { + err := msg.Verify(tt.externalOnVerify, verifier) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("Sign1Message.Verify() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil && (tt.wantErr != "") { t.Errorf("Sign1Message.Verify() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/sign_test.go b/sign_test.go index ee865ec..c1639c4 100644 --- a/sign_test.go +++ b/sign_test.go @@ -14,7 +14,7 @@ func TestSignature_MarshalCBOR(t *testing.T) { name string s *Signature want []byte - wantErr bool + wantErr string }{ { name: "valid message", @@ -39,7 +39,7 @@ func TestSignature_MarshalCBOR(t *testing.T) { { name: "nil signature", s: nil, - wantErr: true, + wantErr: "cbor: MarshalCBOR on nil Signature pointer", }, { name: "nil signature", @@ -54,7 +54,7 @@ func TestSignature_MarshalCBOR(t *testing.T) { }, Signature: nil, }, - wantErr: true, + wantErr: "empty signature", }, { name: "empty signature", @@ -69,7 +69,7 @@ func TestSignature_MarshalCBOR(t *testing.T) { }, Signature: []byte{}, }, - wantErr: true, + wantErr: "empty signature", }, { name: "invalid protected header", @@ -84,7 +84,7 @@ func TestSignature_MarshalCBOR(t *testing.T) { }, Signature: []byte("bar"), }, - wantErr: true, + wantErr: "protected header: header parameter: alg: require int / tstr type", }, { name: "invalid unprotected header", @@ -99,7 +99,7 @@ func TestSignature_MarshalCBOR(t *testing.T) { }, Signature: []byte("bar"), }, - wantErr: true, + wantErr: "cbor: unsupported type: chan bool", }, { name: "protected has IV and unprotected has PartialIV error", @@ -115,7 +115,7 @@ func TestSignature_MarshalCBOR(t *testing.T) { }, Signature: []byte("bar"), }, - wantErr: true, + wantErr: "IV (protected) and PartialIV (unprotected) parameters must not both be present", }, { name: "protected has PartialIV and unprotected has IV error", @@ -131,13 +131,16 @@ func TestSignature_MarshalCBOR(t *testing.T) { }, Signature: []byte("bar"), }, - wantErr: true, + wantErr: "IV (unprotected) and PartialIV (protected) parameters must not both be present", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.s.MarshalCBOR() - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("Signature.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { t.Errorf("Signature.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -163,7 +166,7 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { name string data []byte want Signature - wantErr bool + wantErr string }{ { name: "valid signature struct", @@ -190,12 +193,12 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { { name: "nil CBOR data", data: nil, - wantErr: true, + wantErr: "cbor: invalid Signature object", }, { name: "empty CBOR data", data: []byte{}, - wantErr: true, + wantErr: "cbor: invalid Signature object", }, { name: "tagged signature", // issue #30 @@ -204,7 +207,7 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0xcb, 0xa1, 0x00, // tagged signature }, - wantErr: true, + wantErr: "cbor: CBOR tag isn't allowed", }, { name: "nil signature", @@ -213,7 +216,7 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0xf6, // nil signature }, - wantErr: true, + wantErr: "empty signature", }, { name: "empty signature", @@ -222,14 +225,14 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0x40, // empty signature }, - wantErr: true, + wantErr: "empty signature", }, { name: "mismatch type", data: []byte{ 0x40, }, - wantErr: true, + wantErr: "cbor: invalid Signature object", }, { name: "smaller array size", @@ -237,7 +240,7 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { 0x82, 0x40, 0xa0, // empty headers }, - wantErr: true, + wantErr: "cbor: invalid Signature object", }, { name: "larger array size", @@ -247,7 +250,7 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { 0x41, 0x00, // signature 0x40, }, - wantErr: true, + wantErr: "cbor: invalid Signature object", }, { name: "signature as a byte array", @@ -256,7 +259,7 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0x81, 0x00, // signature }, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "protected has IV and unprotected has PartialIV", @@ -266,7 +269,7 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { 0xa1, 0x6, 0x63, 0x62, 0x61, 0x72, // unprotected 0x43, 0x62, 0x61, 0x72, // signature }, - wantErr: true, + wantErr: "cbor: invalid protected header: protected header: header parameter: IV: require bstr type", }, { name: "protected has PartialIV and unprotected has IV", @@ -276,13 +279,17 @@ func TestSignature_UnmarshalCBOR(t *testing.T) { 0xa1, 0x5, 0x63, 0x62, 0x61, 0x72, // unprotected 0x43, 0x62, 0x61, 0x72, // signature }, - wantErr: true, + wantErr: "cbor: invalid protected header: protected header: header parameter: Partial IV: require bstr type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got Signature - if err := got.UnmarshalCBOR(tt.data); (err != nil) != tt.wantErr { + err := got.UnmarshalCBOR(tt.data) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("Signature.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { t.Errorf("Signature.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -317,7 +324,7 @@ func TestSignature_Sign(t *testing.T) { sig *Signature onSign args onVerify args - wantErr bool + wantErr string check func(t *testing.T, s *Signature) }{ { @@ -429,7 +436,7 @@ func TestSignature_Sign(t *testing.T) { protected: []byte{0x40}, payload: nil, }, - wantErr: true, + wantErr: "missing payload", }, { name: "mismatch algorithm", @@ -448,7 +455,7 @@ func TestSignature_Sign(t *testing.T) { protected: []byte{0x40}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "algorithm mismatch: signer ES256: header ES512", }, { name: "missing algorithm", @@ -486,7 +493,7 @@ func TestSignature_Sign(t *testing.T) { protected: []byte{0x40}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "algorithm not found", }, { name: "missing algorithm with externally supplied data", @@ -521,7 +528,7 @@ func TestSignature_Sign(t *testing.T) { protected: []byte{0x40}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "Signature already has signature bytes", }, { name: "nil signature", @@ -534,7 +541,7 @@ func TestSignature_Sign(t *testing.T) { protected: []byte{0x40}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "signing nil Signature", }, { name: "nil body protected header", @@ -556,7 +563,7 @@ func TestSignature_Sign(t *testing.T) { protected: nil, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "invalid body protected headers", }, { name: "empty body protected header", @@ -578,7 +585,7 @@ func TestSignature_Sign(t *testing.T) { protected: []byte{}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "invalid body protected headers", }, { name: "invalid protected header", @@ -600,19 +607,22 @@ func TestSignature_Sign(t *testing.T) { protected: []byte{0xa0}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "invalid body protected headers", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.sig.Sign(rand.Reader, signer, tt.onSign.protected, tt.onSign.payload, tt.onSign.external) - if (err != nil) != tt.wantErr { - t.Errorf("Signature.Sign() error = %v, wantErr %v", err, tt.wantErr) - return - } if err != nil { + if err.Error() != tt.wantErr { + t.Errorf("Signature.Sign() error = %v, wantErr %v", err, tt.wantErr) + } + return + } else if tt.wantErr != "" { + t.Errorf("Signature.Sign() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.check != nil { tt.check(t, tt.sig) } @@ -791,7 +801,7 @@ func TestSignature_Verify(t *testing.T) { onSign args onVerify args tamper func(s *Signature) *Signature - wantErr bool + wantErr string }{ { name: "round trip on valid message", @@ -840,7 +850,7 @@ func TestSignature_Verify(t *testing.T) { protected: nil, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "invalid body protected headers", }, { name: "empty body protected header", @@ -852,7 +862,7 @@ func TestSignature_Verify(t *testing.T) { protected: []byte{}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "invalid body protected headers", }, { name: "invalid body protected header", @@ -864,7 +874,7 @@ func TestSignature_Verify(t *testing.T) { protected: []byte{0xa0}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "invalid body protected headers", }, { name: "body protected header mismatch", @@ -876,7 +886,7 @@ func TestSignature_Verify(t *testing.T) { protected: []byte{0x43, 0xa1, 0x00, 0x01}, payload: []byte("hello world"), }, - wantErr: true, + wantErr: "verification error", }, { name: "nil payload", @@ -888,7 +898,7 @@ func TestSignature_Verify(t *testing.T) { protected: []byte{0x40}, payload: nil, }, - wantErr: true, + wantErr: "missing payload", }, { name: "payload mismatch", @@ -900,7 +910,7 @@ func TestSignature_Verify(t *testing.T) { protected: []byte{0x40}, payload: []byte("foobar"), }, - wantErr: true, + wantErr: "verification error", }, { name: "external mismatch", @@ -914,7 +924,7 @@ func TestSignature_Verify(t *testing.T) { payload: []byte("hello world"), external: []byte("bar"), }, - wantErr: true, + wantErr: "verification error", }, { name: "nil signature struct", @@ -929,7 +939,7 @@ func TestSignature_Verify(t *testing.T) { tamper: func(s *Signature) *Signature { return nil }, - wantErr: true, + wantErr: "verifying nil Signature", }, { name: "strip signature", @@ -945,7 +955,7 @@ func TestSignature_Verify(t *testing.T) { s.Signature = nil return s }, - wantErr: true, + wantErr: "empty signature", }, { name: "empty signature", @@ -961,7 +971,7 @@ func TestSignature_Verify(t *testing.T) { s.Signature = []byte{} return s }, - wantErr: true, + wantErr: "empty signature", }, { name: "tamper protected header", @@ -977,7 +987,7 @@ func TestSignature_Verify(t *testing.T) { s.Headers.Protected["foo"] = "bar" return s }, - wantErr: true, + wantErr: "verification error", }, { name: "tamper unprotected header", @@ -993,7 +1003,6 @@ func TestSignature_Verify(t *testing.T) { s.Headers.Unprotected["foo"] = "bar" return s }, - wantErr: false, // allowed }, { name: "tamper signature", @@ -1009,7 +1018,7 @@ func TestSignature_Verify(t *testing.T) { s.Signature[0]++ return s }, - wantErr: true, + wantErr: "verification error", }, } for _, tt := range tests { @@ -1025,7 +1034,8 @@ func TestSignature_Verify(t *testing.T) { }, }, } - if err := sig.Sign(rand.Reader, signer, tt.onSign.protected, tt.onSign.payload, tt.onSign.external); err != nil { + err := sig.Sign(rand.Reader, signer, tt.onSign.protected, tt.onSign.payload, tt.onSign.external) + if err != nil && (err.Error() != tt.wantErr) { t.Errorf("Signature.Sign() error = %v", err) return } @@ -1036,7 +1046,10 @@ func TestSignature_Verify(t *testing.T) { } // verify signature - if err := sig.Verify(verifier, tt.onVerify.protected, tt.onVerify.payload, tt.onVerify.external); (err != nil) != tt.wantErr { + err = sig.Verify(verifier, tt.onVerify.protected, tt.onVerify.payload, tt.onVerify.external) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("Signature.Verify() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil && (tt.wantErr != "") { t.Errorf("Signature.Verify() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -1048,7 +1061,7 @@ func TestSignMessage_MarshalCBOR(t *testing.T) { name string m *SignMessage want []byte - wantErr bool + wantErr string }{ { name: "valid message with multiple signatures", @@ -1164,7 +1177,7 @@ func TestSignMessage_MarshalCBOR(t *testing.T) { Payload: []byte("hello world"), Signatures: nil, }, - wantErr: true, + wantErr: "no signatures attached", }, { name: "empty signatures", @@ -1180,12 +1193,12 @@ func TestSignMessage_MarshalCBOR(t *testing.T) { Payload: []byte("hello world"), Signatures: []*Signature{}, }, - wantErr: true, + wantErr: "no signatures attached", }, { name: "nil message", m: nil, - wantErr: true, + wantErr: "cbor: MarshalCBOR on nil SignMessage pointer", }, { name: "nil payload", @@ -1253,7 +1266,7 @@ func TestSignMessage_MarshalCBOR(t *testing.T) { }, }, }, - wantErr: true, + wantErr: "protected header: header parameter: alg: require int / tstr type", }, { name: "invalid unprotected header", @@ -1278,13 +1291,16 @@ func TestSignMessage_MarshalCBOR(t *testing.T) { }, }, }, - wantErr: true, + wantErr: "cbor: unsupported type: chan bool", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := tt.m.MarshalCBOR() - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("SignMessage.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { t.Errorf("SignMessage.MarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -1313,7 +1329,7 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { name string data []byte want SignMessage - wantErr bool + wantErr string }{ { name: "valid message with multiple signatures", @@ -1511,7 +1527,7 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { 0xf6, // nil payload 0xf6, // signatures }, - wantErr: true, + wantErr: "no signatures attached", }, { name: "empty signatures", @@ -1527,7 +1543,7 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { 0xf6, // nil payload 0x80, // signatures }, - wantErr: true, + wantErr: "no signatures attached", }, { name: "tagged signature", // issue #30 @@ -1546,17 +1562,17 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0xcb, 0xa1, 0x00, // tagged signature }, - wantErr: true, + wantErr: "cbor: CBOR tag isn't allowed", }, { name: "nil CBOR data", data: nil, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign_Tagged object", }, { name: "empty CBOR data", data: []byte{}, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign_Tagged object", }, { name: "mismatch tag", @@ -1569,14 +1585,14 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0x41, 0x00, // signature }, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign_Tagged object", }, { name: "mismatch type", data: []byte{ 0xd8, 0x62, 0x40, }, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign_Tagged object", }, { name: "smaller array size", @@ -1585,7 +1601,7 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0xf6, // nil payload }, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign_Tagged object", }, { name: "larger array size", @@ -1599,7 +1615,7 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { 0x41, 0x00, // signature 0x40, }, - wantErr: true, + wantErr: "cbor: invalid COSE_Sign_Tagged object", }, { name: "undefined payload", @@ -1612,7 +1628,7 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0x41, 0x00, // signature }, - wantErr: true, + wantErr: "cbor: require bstr type", }, { name: "payload as a byte array", @@ -1625,13 +1641,17 @@ func TestSignMessage_UnmarshalCBOR(t *testing.T) { 0x40, 0xa0, // empty headers 0x41, 0x00, // signature }, - wantErr: true, + wantErr: "cbor: require bstr type", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got SignMessage - if err := got.UnmarshalCBOR(tt.data); (err != nil) != tt.wantErr { + err := got.UnmarshalCBOR(tt.data) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("SignMessage.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { t.Errorf("SignMessage.UnmarshalCBOR() error = %v, wantErr %v", err, tt.wantErr) return } @@ -1669,7 +1689,7 @@ func TestSignMessage_Sign(t *testing.T) { msg *SignMessage externalOnSign []byte externalOnVerify []byte - wantErr bool + wantErr string check func(t *testing.T, m *SignMessage) }{ { @@ -1833,7 +1853,7 @@ func TestSignMessage_Sign(t *testing.T) { }, }, }, - wantErr: true, + wantErr: "missing payload", }, { name: "mismatch algorithm", @@ -1856,7 +1876,7 @@ func TestSignMessage_Sign(t *testing.T) { }, }, }, - wantErr: true, + wantErr: "algorithm mismatch: signer ES256: header ES512", }, { name: "plain message", @@ -1887,12 +1907,12 @@ func TestSignMessage_Sign(t *testing.T) { }, }, }, - wantErr: true, + wantErr: "Signature already has signature bytes", }, { name: "nil message", msg: nil, - wantErr: true, + wantErr: "signing nil SignMessage", }, { name: "too few signers", @@ -1900,7 +1920,7 @@ func TestSignMessage_Sign(t *testing.T) { Payload: []byte("hello world"), Signatures: []*Signature{{}, {}, {}}, }, - wantErr: true, + wantErr: "2 signers for 3 signatures", }, { name: "too many signers", @@ -1908,7 +1928,7 @@ func TestSignMessage_Sign(t *testing.T) { Payload: []byte("hello world"), Signatures: []*Signature{{}}, }, - wantErr: true, + wantErr: "2 signers for 1 signatures", }, { name: "empty signatures", @@ -1916,7 +1936,7 @@ func TestSignMessage_Sign(t *testing.T) { Payload: []byte("hello world"), Signatures: []*Signature{}, }, - wantErr: true, + wantErr: "no signatures attached", }, { name: "nil signatures", @@ -1924,17 +1944,19 @@ func TestSignMessage_Sign(t *testing.T) { Payload: []byte("hello world"), Signatures: nil, }, - wantErr: true, + wantErr: "no signatures attached", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.msg.Sign(rand.Reader, tt.externalOnSign, signers...) - if (err != nil) != tt.wantErr { - t.Errorf("SignMessage.Sign() error = %v, wantErr %v", err, tt.wantErr) - return - } if err != nil { + if err.Error() != tt.wantErr { + t.Errorf("SignMessage.Sign() error = %v, wantErr %v", err, tt.wantErr) + } + return + } else if tt.wantErr != "" { + t.Errorf("SignMessage.Sign() error = %v, wantErr %v", err, tt.wantErr) return } if tt.check != nil { @@ -1986,7 +2008,7 @@ func TestSignMessage_Verify(t *testing.T) { externalOnVerify []byte verifiers []Verifier tamper func(m *SignMessage) *SignMessage - wantErr bool + wantErr string }{ { name: "round trip on valid message", @@ -1997,7 +2019,7 @@ func TestSignMessage_Verify(t *testing.T) { externalOnSign: []byte("foo"), externalOnVerify: []byte("bar"), verifiers: verifiers, - wantErr: true, + wantErr: "verification error", }, { name: "mixed nil / empty external", @@ -2011,7 +2033,7 @@ func TestSignMessage_Verify(t *testing.T) { tamper: func(m *SignMessage) *SignMessage { return nil }, - wantErr: true, + wantErr: "verifying nil SignMessage", }, { name: "strip signatures", @@ -2020,7 +2042,7 @@ func TestSignMessage_Verify(t *testing.T) { m.Signatures = nil return m }, - wantErr: true, + wantErr: "no signatures attached", }, { name: "empty signatures", @@ -2029,7 +2051,7 @@ func TestSignMessage_Verify(t *testing.T) { m.Signatures = []*Signature{} return m }, - wantErr: true, + wantErr: "no signatures attached", }, { name: "tamper protected header", @@ -2038,7 +2060,7 @@ func TestSignMessage_Verify(t *testing.T) { m.Headers.Protected["foo"] = "bar" return m }, - wantErr: true, + wantErr: "verification error", }, { name: "tamper unprotected header", @@ -2047,7 +2069,6 @@ func TestSignMessage_Verify(t *testing.T) { m.Headers.Unprotected["foo"] = "bar" return m }, - wantErr: false, // allowed }, { name: "tamper payload", @@ -2056,7 +2077,7 @@ func TestSignMessage_Verify(t *testing.T) { m.Payload = []byte("foobar") return m }, - wantErr: true, + wantErr: "verification error", }, { name: "tamper signature", @@ -2065,18 +2086,18 @@ func TestSignMessage_Verify(t *testing.T) { m.Signatures[1].Signature[0]++ return m }, - wantErr: true, + wantErr: "verification error", }, { name: "no verifiers", verifiers: nil, - wantErr: true, + wantErr: "0 verifiers for 2 signatures", }, { name: "too few verifiers", verifiers: verifiers[:1], - wantErr: true, + wantErr: "1 verifiers for 2 signatures", }, { name: "too many verifiers", @@ -2085,12 +2106,12 @@ func TestSignMessage_Verify(t *testing.T) { m.Signatures = m.Signatures[:1] return m }, - wantErr: true, + wantErr: "2 verifiers for 1 signatures", }, { name: "verifier mismatch", verifiers: []Verifier{verifiers[1], verifiers[0]}, - wantErr: true, + wantErr: "algorithm mismatch: verifier ES512: header ES256", }, } for _, tt := range tests { @@ -2137,7 +2158,10 @@ func TestSignMessage_Verify(t *testing.T) { } // verify message - if err := msg.Verify(tt.externalOnVerify, tt.verifiers...); (err != nil) != tt.wantErr { + err := msg.Verify(tt.externalOnVerify, tt.verifiers...) + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("SignMessage.Verify() error = %v, wantErr %v", err, tt.wantErr) + } else if err == nil && (tt.wantErr != "") { t.Errorf("SignMessage.Verify() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/signer_test.go b/signer_test.go index 5bd6c32..e9475b9 100644 --- a/signer_test.go +++ b/signer_test.go @@ -48,7 +48,7 @@ func TestNewSigner(t *testing.T) { alg Algorithm key crypto.Signer want Signer - wantErr bool + wantErr string }{ { name: "ecdsa key signer", @@ -73,7 +73,7 @@ func TestNewSigner(t *testing.T) { name: "ecdsa key mismatch", alg: AlgorithmES256, key: rsaKey, - wantErr: true, + wantErr: "ES256: invalid public key", }, { name: "ed25519 signer", @@ -87,7 +87,7 @@ func TestNewSigner(t *testing.T) { name: "ed25519 key mismatch", alg: AlgorithmEd25519, key: rsaKey, - wantErr: true, + wantErr: "EdDSA: invalid public key", }, { name: "rsa signer", @@ -102,24 +102,27 @@ func TestNewSigner(t *testing.T) { name: "rsa key mismatch", alg: AlgorithmPS256, key: ecdsaKey, - wantErr: true, + wantErr: "PS256: invalid public key", }, { name: "rsa key under minimum entropy", alg: AlgorithmPS256, key: rsaKeyLowEntropy, - wantErr: true, + wantErr: "RSA key must be at least 2048 bits long", }, { name: "unknown algorithm", alg: 0, - wantErr: true, + wantErr: "algorithm not supported", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := NewSigner(tt.alg, tt.key) - if (err != nil) != tt.wantErr { + if err != nil && (err.Error() != tt.wantErr) { + t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) + return + } else if err == nil && (tt.wantErr != "") { t.Errorf("NewSigner() error = %v, wantErr %v", err, tt.wantErr) return } From 08fc5f0a8051bf9462251ae5a603b3fa9cf4c29a Mon Sep 17 00:00:00 2001 From: Yogesh Deshpande Date: Fri, 10 Mar 2023 14:10:56 -0500 Subject: [PATCH 2/2] Incorporating Review Comments Signed-off-by: Yogesh Deshpande --- ecdsa_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ecdsa_test.go b/ecdsa_test.go index feaa673..308ca6a 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -101,6 +101,9 @@ func TestI2OSP(t *testing.T) { if err != nil && (err.Error() != tt.wantErr) { t.Errorf("I2OSP() error = %v, wantErr %v", err, tt.wantErr) return + } else if err == nil && (tt.wantErr != "") { + t.Errorf("I2OSP() error = %v, wantErr %v", err, tt.wantErr) + return } if got := tt.buf; (tt.wantErr == "") && !reflect.DeepEqual(got, tt.want) {