Skip to content

Commit

Permalink
Merge pull request #434 from fxamacker/fxamacker/refactor-option-vali…
Browse files Browse the repository at this point in the history
…d-check

Refactor valid() to reject negative values for integer modes
  • Loading branch information
fxamacker authored Nov 5, 2023
2 parents 4687659 + ad9dc1c commit ad0cc0c
Show file tree
Hide file tree
Showing 6 changed files with 599 additions and 297 deletions.
14 changes: 7 additions & 7 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ const (
)

func (dmkm DupMapKeyMode) valid() bool {
return dmkm < maxDupMapKeyMode
return dmkm >= 0 && dmkm < maxDupMapKeyMode
}

// IndefLengthMode specifies whether to allow indefinite length items.
Expand All @@ -243,7 +243,7 @@ const (
)

func (m IndefLengthMode) valid() bool {
return m < maxIndefLengthMode
return m >= 0 && m < maxIndefLengthMode
}

// TagsMode specifies whether to allow CBOR tags.
Expand All @@ -260,7 +260,7 @@ const (
)

func (tm TagsMode) valid() bool {
return tm < maxTagsMode
return tm >= 0 && tm < maxTagsMode
}

// IntDecMode specifies which Go int type (int64 or uint64) should
Expand All @@ -282,7 +282,7 @@ const (
)

func (idm IntDecMode) valid() bool {
return idm < maxIntDec
return idm >= 0 && idm < maxIntDec
}

// MapKeyByteStringMode specifies how to decode CBOR byte string (major type 2)
Expand Down Expand Up @@ -312,7 +312,7 @@ const (
)

func (mkbsm MapKeyByteStringMode) valid() bool {
return mkbsm < maxMapKeyByteStringMode
return mkbsm >= 0 && mkbsm < maxMapKeyByteStringMode
}

// ExtraDecErrorCond specifies extra conditions that should be treated as errors.
Expand Down Expand Up @@ -350,7 +350,7 @@ const (
)

func (um UTF8Mode) valid() bool {
return um < maxUTF8Mode
return um >= 0 && um < maxUTF8Mode
}

// FieldNameMatchingMode specifies how string keys in CBOR maps are matched to Go struct field names.
Expand Down Expand Up @@ -1985,7 +1985,7 @@ var (
typeBinaryUnmarshaler = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
)

func fillNil(t cborType, v reflect.Value) error {
func fillNil(_ cborType, v reflect.Value) error {
switch v.Kind() {
case reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr:
v.Set(reflect.Zero(v.Type()))
Expand Down
210 changes: 168 additions & 42 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3294,22 +3294,58 @@ func testRoundTrip(t *testing.T, testCases []roundTripTest, em EncMode, dm DecMo
}

func TestDecModeInvalidTimeTag(t *testing.T) {
wantErrorMsg := "cbor: invalid TimeTag 101"
_, err := DecOptions{TimeTag: 101}.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), wantErrorMsg)
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{TimeTag: -1},
wantErrorMsg: "cbor: invalid TimeTag -1",
},
{
name: "above range of valid modes",
opts: DecOptions{TimeTag: 101},
wantErrorMsg: "cbor: invalid TimeTag 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

func TestDecModeInvalidDuplicateMapKey(t *testing.T) {
wantErrorMsg := "cbor: invalid DupMapKey 101"
_, err := DecOptions{DupMapKey: 101}.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), wantErrorMsg)
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{DupMapKey: -1},
wantErrorMsg: "cbor: invalid DupMapKey -1",
},
{
name: "above range of valid modes",
opts: DecOptions{DupMapKey: 101},
wantErrorMsg: "cbor: invalid DupMapKey 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

Expand Down Expand Up @@ -3437,22 +3473,58 @@ func TestDecModeInvalidMaxArrayElements(t *testing.T) {
}

func TestDecModeInvalidIndefiniteLengthMode(t *testing.T) {
wantErrorMsg := "cbor: invalid IndefLength 101"
_, err := DecOptions{IndefLength: 101}.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), wantErrorMsg)
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{IndefLength: -1},
wantErrorMsg: "cbor: invalid IndefLength -1",
},
{
name: "above range of valid modes",
opts: DecOptions{IndefLength: 101},
wantErrorMsg: "cbor: invalid IndefLength 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

func TestDecModeInvalidTagsMode(t *testing.T) {
wantErrorMsg := "cbor: invalid TagsMd 101"
_, err := DecOptions{TagsMd: 101}.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), wantErrorMsg)
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{TagsMd: -1},
wantErrorMsg: "cbor: invalid TagsMd -1",
},
{
name: "above range of valid modes",
opts: DecOptions{TagsMd: 101},
wantErrorMsg: "cbor: invalid TagsMd 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

Expand Down Expand Up @@ -4605,12 +4677,30 @@ func TestDecTagsMdOption(t *testing.T) {
}

func TestDecModeInvalidIntDec(t *testing.T) {
wantErrorMsg := "cbor: invalid IntDec 101"
_, err := DecOptions{IntDec: 101}.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), wantErrorMsg)
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{IntDec: -1},
wantErrorMsg: "cbor: invalid IntDec -1",
},
{
name: "above range of valid modes",
opts: DecOptions{IntDec: 101},
wantErrorMsg: "cbor: invalid IntDec 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

Expand Down Expand Up @@ -4664,12 +4754,30 @@ func TestIntDec(t *testing.T) {
}

func TestDecModeInvalidMapKeyByteString(t *testing.T) {
wantErrorMsg := "cbor: invalid MapKeyByteString 101"
_, err := DecOptions{MapKeyByteString: 101}.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), wantErrorMsg)
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{MapKeyByteString: -1},
wantErrorMsg: "cbor: invalid MapKeyByteString -1",
},
{
name: "above range of valid modes",
opts: DecOptions{MapKeyByteString: 101},
wantErrorMsg: "cbor: invalid MapKeyByteString 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

Expand Down Expand Up @@ -4844,12 +4952,30 @@ func TestExtraErrorCondUnknownField(t *testing.T) {
}

func TestInvalidUTF8Mode(t *testing.T) {
wantErrorMsg := "cbor: invalid UTF8 2"
_, err := DecOptions{UTF8: 2}.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), wantErrorMsg)
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{UTF8: -1},
wantErrorMsg: "cbor: invalid UTF8 -1",
},
{
name: "above range of valid modes",
opts: DecOptions{UTF8: 101},
wantErrorMsg: "cbor: invalid UTF8 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

Expand Down
Loading

0 comments on commit ad0cc0c

Please sign in to comment.