Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure IV and Partial IV are not both present #66

Merged
merged 4 commits into from
May 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ const (
HeaderLabelCritical int64 = 2
HeaderLabelContentType int64 = 3
HeaderLabelKeyID int64 = 4
HeaderLabelIV int64 = 5
HeaderLabelPartialIV int64 = 6
HeaderLabelCounterSignature int64 = 7
HeaderLabelCounterSignature0 int64 = 9
HeaderLabelX5Bag int64 = 32
Expand Down Expand Up @@ -43,6 +45,9 @@ func (h ProtectedHeader) MarshalCBOR() ([]byte, error) {
if err = h.ensureCritical(); err != nil {
return nil, err
}
if err = ensureHeaderIV(h); err != nil {
return nil, fmt.Errorf("protected header: %w", err)
}
encoded, err = encMode.Marshal(map[interface{}]interface{}(h))
if err != nil {
return nil, err
Expand Down Expand Up @@ -83,6 +88,9 @@ func (h *ProtectedHeader) UnmarshalCBOR(data []byte) error {
if err := candidate.ensureCritical(); err != nil {
return err
}
if err := ensureHeaderIV(candidate); err != nil {
return fmt.Errorf("protected header: %w", err)
}

// cast to type Algorithm if `alg` presents
if alg, err := candidate.Algorithm(); err == nil {
Expand Down Expand Up @@ -170,6 +178,9 @@ func (h UnprotectedHeader) MarshalCBOR() ([]byte, error) {
if err := validateHeaderLabel(h); err != nil {
return nil, err
}
if err := ensureHeaderIV(h); err != nil {
return nil, fmt.Errorf("unprotected header: %w", err)
}
return encMode.Marshal(map[interface{}]interface{}(h))
}

Expand All @@ -196,6 +207,9 @@ func (h *UnprotectedHeader) UnmarshalCBOR(data []byte) error {
if err := decMode.Unmarshal(data, &header); err != nil {
return err
}
if err := ensureHeaderIV(header); err != nil {
return fmt.Errorf("unprotected header: %w", err)
}
*h = header
return nil
}
Expand Down Expand Up @@ -253,6 +267,23 @@ type Headers struct {
Unprotected UnprotectedHeader
}

// marshal encoded both headers.
// It returns RawProtected and RawUnprotected if those are set.
func (h *Headers) marshal() (cbor.RawMessage, cbor.RawMessage, error) {
if err := h.ensureIV(); err != nil {
return nil, nil, err
}
protected, err := h.MarshalProtected()
if err != nil {
return nil, nil, err
}
unprotected, err := h.MarshalUnprotected()
if err != nil {
return nil, nil, err
}
return protected, unprotected, nil
}

// MarshalProtected encodes the protected header.
// RawProtected is returned if it is not set to nil.
func (h *Headers) MarshalProtected() ([]byte, error) {
Expand Down Expand Up @@ -280,6 +311,9 @@ func (h *Headers) UnmarshalFromRaw() error {
if err := decMode.Unmarshal(h.RawUnprotected, &h.Unprotected); err != nil {
return fmt.Errorf("cbor: invalid unprotected header: %w", err)
}
if err := h.ensureIV(); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -331,6 +365,38 @@ func (h *Headers) ensureVerificationAlgorithm(alg Algorithm, external []byte) er
return err
}

// ensureIV ensures IV and Partial IV are not both present
// in the protected and unprotected headers.
// It does not check if they are both present within one header,
// as it will be checked later on.
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
func (h *Headers) ensureIV() error {
if hasLabel(h.Protected, HeaderLabelIV) && hasLabel(h.Unprotected, HeaderLabelPartialIV) {
return errors.New("IV (protected) and PartialIV (unprotected) parameters must not both be present")
}
if hasLabel(h.Protected, HeaderLabelPartialIV) && hasLabel(h.Unprotected, HeaderLabelIV) {
return errors.New("IV (unprotected) and PartialIV (protected) parameters must not both be present")
}
return nil
}

// hasLabel returns true if h contains label.
func hasLabel(h map[interface{}]interface{}, label interface{}) bool {
_, ok := h[label]
return ok
}

// ensureHeaderIV ensures IV and Partial IV are not both present in the header.
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
func ensureHeaderIV(h map[interface{}]interface{}) error {
if hasLabel(h, HeaderLabelIV) && hasLabel(h, HeaderLabelPartialIV) {
return errors.New("IV and PartialIV parameters must not both be present")
}
return nil
}

// validateHeaderLabel validates if all header labels are integers or strings.
//
// label = int / tstr
Expand Down
30 changes: 30 additions & 0 deletions headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ func TestProtectedHeader_MarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "iv and partial iv present",
h: ProtectedHeader{
HeaderLabelIV: "foo",
HeaderLabelPartialIV: "bar",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -265,6 +273,13 @@ func TestProtectedHeader_UnmarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "iv and partial iv present",
data: []byte{
0x4b, 0xa2, 0x5, 0x63, 0x66, 0x6f, 0x6f, 0x6, 0x63, 0x62, 0x61, 0x72,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -517,6 +532,14 @@ func TestUnprotectedHeader_MarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "iv and partial iv present",
h: UnprotectedHeader{
HeaderLabelIV: "foo",
HeaderLabelPartialIV: "bar",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -617,6 +640,13 @@ func TestUnprotectedHeader_UnmarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "iv and partial iv present",
data: []byte{
0xa2, 0x5, 0x63, 0x66, 0x6f, 0x6f, 0x6, 0x63, 0x62, 0x61, 0x72,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
12 changes: 2 additions & 10 deletions sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ func (s *Signature) MarshalCBOR() ([]byte, error) {
if len(s.Signature) == 0 {
return nil, ErrEmptySignature
}
protected, err := s.Headers.MarshalProtected()
if err != nil {
return nil, err
}
unprotected, err := s.Headers.MarshalUnprotected()
protected, unprotected, err := s.Headers.marshal()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -329,11 +325,7 @@ func (m *SignMessage) MarshalCBOR() ([]byte, error) {
if len(m.Signatures) == 0 {
return nil, ErrNoSignatures
}
protected, err := m.Headers.MarshalProtected()
if err != nil {
return nil, err
}
unprotected, err := m.Headers.MarshalUnprotected()
protected, unprotected, err := m.Headers.marshal()
if err != nil {
return nil, err
}
Expand Down
6 changes: 1 addition & 5 deletions sign1.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@ func (m *Sign1Message) MarshalCBOR() ([]byte, error) {
if len(m.Signature) == 0 {
return nil, ErrEmptySignature
}
protected, err := m.Headers.MarshalProtected()
if err != nil {
return nil, err
}
unprotected, err := m.Headers.MarshalUnprotected()
protected, unprotected, err := m.Headers.marshal()
if err != nil {
return nil, err
}
Expand Down
58 changes: 58 additions & 0 deletions sign1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,40 @@ func TestSign1Message_MarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "protected has IV and unprotected has PartialIV error",
m: &Sign1Message{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
HeaderLabelIV: "",
},
Unprotected: UnprotectedHeader{
HeaderLabelPartialIV: "",
},
},
Payload: []byte("foo"),
Signature: []byte("bar"),
},
wantErr: true,
},
{
name: "protected has PartialIV and unprotected has IV error",
m: &Sign1Message{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
HeaderLabelPartialIV: "",
},
Unprotected: UnprotectedHeader{
HeaderLabelIV: "",
},
},
Payload: []byte("foo"),
Signature: []byte("bar"),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -324,6 +358,30 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "protected has IV and unprotected has PartialIV",
data: []byte{
0xd2, // tag
0x84,
0x46, 0xa1, 0x5, 0x63, 0x66, 0x6f, 0x6f, // protected
0xa1, 0x6, 0x63, 0x62, 0x61, 0x72, // unprotected
0xf6, // payload
0x43, 0x62, 0x61, 0x72, // signature
},
wantErr: true,
},
{
name: "protected has PartialIV and unprotected has IV",
data: []byte{
0xd2, // tag
0x84,
0x46, 0xa1, 0x6, 0x63, 0x66, 0x6f, 0x6f, // protected
0xa1, 0x5, 0x63, 0x62, 0x61, 0x72, // unprotected
0xf6, // payload
0x43, 0x62, 0x61, 0x72, // signature
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
52 changes: 52 additions & 0 deletions sign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,38 @@ func TestSignature_MarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "protected has IV and unprotected has PartialIV error",
s: &Signature{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
HeaderLabelIV: "",
},
Unprotected: UnprotectedHeader{
HeaderLabelPartialIV: "",
},
},
Signature: []byte("bar"),
},
wantErr: true,
},
{
name: "protected has PartialIV and unprotected has IV error",
s: &Signature{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
HeaderLabelPartialIV: "",
},
Unprotected: UnprotectedHeader{
HeaderLabelIV: "",
},
},
Signature: []byte("bar"),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -227,6 +259,26 @@ func TestSignature_UnmarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "protected has IV and unprotected has PartialIV",
data: []byte{
0x83,
0x46, 0xa1, 0x5, 0x63, 0x66, 0x6f, 0x6f, // protected
0xa1, 0x6, 0x63, 0x62, 0x61, 0x72, // unprotected
0x43, 0x62, 0x61, 0x72, // signature
},
wantErr: true,
},
{
name: "protected has PartialIV and unprotected has IV",
data: []byte{
0x83,
0x46, 0xa1, 0x6, 0x63, 0x66, 0x6f, 0x6f, // protected
0xa1, 0x5, 0x63, 0x62, 0x61, 0x72, // unprotected
0x43, 0x62, 0x61, 0x72, // signature
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down