Skip to content

Commit

Permalink
Merge pull request #481 from benluddy/simplevalue-govalues
Browse files Browse the repository at this point in the history
Add decode option to allow rejecting inputs that contain certain simple values.
  • Loading branch information
fxamacker committed Apr 7, 2024
2 parents 4330f59 + 28078a7 commit 9f099e8
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 0 deletions.
89 changes: 89 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,18 @@ func (e *UnknownFieldError) Error() string {
return fmt.Sprintf("cbor: found unknown field at map element index %d", e.Index)
}

// UnacceptableDataItemError reports that a CBOR input contained a data item that the
// CBOR-based application protocol in use does not accept (the "invalid or unexpected" case
// described in RFC 8949 Section 5 Paragraph 3). It is returned while unmarshaling.
type UnacceptableDataItemError struct {
	// CBORType names the CBOR type of the offending data item.
	CBORType string
	// Message explains why the data item was not accepted.
	Message string
}

// Error implements the error interface. The returned text names the CBOR type of the data
// item followed by the explanatory message.
func (e UnacceptableDataItemError) Error() string {
	const format = "cbor: data item of cbor type %s is not accepted by protocol: %s"
	return fmt.Sprintf(format, e.CBORType, e.Message)
}

// DupMapKeyMode specifies how to enforce duplicate map key. Two map keys are considered duplicates if:
// 1. When decoding into a struct, both keys match the same struct field. The keys are also
// considered duplicates if neither matches any field and decoding to interface{} would produce
Expand Down Expand Up @@ -496,6 +508,37 @@ func (tttam TimeTagToAnyMode) valid() bool {
return tttam >= 0 && tttam < maxTimeTagToAnyMode
}

// SimpleValueRegistry is a registry of unmarshaling behaviors for each possible CBOR simple value
// number (0...23 and 32...255).
//
// The only behavior currently recorded is whether a simple value is rejected. In the zero
// value no simple value is rejected.
type SimpleValueRegistry struct {
// rejected is indexed by simple value number; rejected[n] is true once simple value n has
// been registered as rejected, in which case the decoder returns an
// UnacceptableDataItemError when it encounters n.
rejected [256]bool
}

// WithRejectedSimpleValue returns a registry option that marks the simple value sv as
// rejected. When a rejected simple value is encountered in a CBOR input during unmarshaling,
// an UnacceptableDataItemError is returned.
//
// Applying the option fails with an error, and leaves the registry untouched, when sv is one
// of the reserved simple values 24 through 31.
func WithRejectedSimpleValue(sv SimpleValue) func(*SimpleValueRegistry) error {
	reject := func(registry *SimpleValueRegistry) error {
		switch {
		case sv < 24 || sv > 31:
			// Outside the reserved range: record the rejection.
			registry.rejected[sv] = true
			return nil
		default:
			return fmt.Errorf("cbor: cannot set analog for reserved simple value %d", sv)
		}
	}
	return reject
}

// NewSimpleValueRegistryFromDefaults creates a new SimpleValueRegistry. The registry starts
// out with the defaults for all well-formed simple value numbers (nothing is rejected), and
// the provided functions are then executed against it in the order given.
//
// If any function returns an error, that error is returned immediately together with a nil
// registry, and the remaining functions are not executed.
func NewSimpleValueRegistryFromDefaults(fns ...func(*SimpleValueRegistry) error) (*SimpleValueRegistry, error) {
	registry := new(SimpleValueRegistry)
	for i := range fns {
		err := fns[i](registry)
		if err != nil {
			return nil, err
		}
	}
	return registry, nil
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -592,6 +635,16 @@ type DecOptions struct {
// TimeTagToAnyMode specifies how to decode CBOR tag 0 and 1 into an empty interface (any).
// Based on the specified mode, Unmarshal can return a time.Time value or a time string in a specific format.
TimeTagToAny TimeTagToAnyMode

// SimpleValues is an immutable mapping from each CBOR simple value to a corresponding
// unmarshal behavior. If nil, the simple values false, true, null, and undefined are mapped
// to the Go analog values false, true, nil, and nil, respectively, and all other simple
// values N (except the reserved simple values 24 through 31) are mapped to
// cbor.SimpleValue(N). In other words, all well-formed simple values can be decoded.
//
// Users may provide a custom SimpleValueRegistry constructed via
// NewSimpleValueRegistryFromDefaults.
SimpleValues *SimpleValueRegistry
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -661,6 +714,15 @@ const (
maxMaxNestedLevels = 65535
)

// defaultSimpleValues is the registry substituted when DecOptions.SimpleValues is nil. It is
// built with no options applied; a construction error here panics.
var defaultSimpleValues = func() *SimpleValueRegistry {
	defaults, err := NewSimpleValueRegistryFromDefaults()
	if err == nil {
		return defaults
	}
	panic(err)
}()

//nolint:gocyclo // Each option comes with some manageable boilerplate
func (opts DecOptions) decMode() (*decMode, error) {
if !opts.DupMapKey.valid() {
return nil, errors.New("cbor: invalid DupMapKey " + strconv.Itoa(int(opts.DupMapKey)))
Expand Down Expand Up @@ -744,6 +806,10 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.UnrecognizedTagToAny.valid() {
return nil, errors.New("cbor: invalid UnrecognizedTagToAnyMode " + strconv.Itoa(int(opts.UnrecognizedTagToAny)))
}
simpleValues := opts.SimpleValues
if simpleValues == nil {
simpleValues = defaultSimpleValues
}

if !opts.TimeTagToAny.valid() {
return nil, errors.New("cbor: invalid TimeTagToAny " + strconv.Itoa(int(opts.TimeTagToAny)))
Expand All @@ -769,6 +835,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
timeTagToAny: opts.TimeTagToAny,
simpleValues: simpleValues,
}

return &dm, nil
Expand Down Expand Up @@ -841,12 +908,20 @@ type decMode struct {
fieldNameByteString FieldNameByteStringMode
unrecognizedTagToAny UnrecognizedTagToAnyMode
timeTagToAny TimeTagToAnyMode
simpleValues *SimpleValueRegistry
}

var defaultDecMode, _ = DecOptions{}.decMode()

// DecOptions returns user specified options used to create this DecMode.
func (dm *decMode) DecOptions() DecOptions {
simpleValues := dm.simpleValues
if simpleValues == defaultSimpleValues {
// Users can't explicitly set this to defaultSimpleValues. It must have been nil in
// the original DecOptions.
simpleValues = nil
}

return DecOptions{
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
Expand All @@ -867,6 +942,7 @@ func (dm *decMode) DecOptions() DecOptions {
FieldNameByteString: dm.fieldNameByteString,
UnrecognizedTagToAny: dm.unrecognizedTagToAny,
TimeTagToAny: dm.timeTagToAny,
SimpleValues: simpleValues,
}
}

Expand Down Expand Up @@ -1189,6 +1265,13 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
f := math.Float64frombits(val)
return fillFloat(t, f, v)
default: // ai <= 24
if d.dm.simpleValues.rejected[SimpleValue(val)] {
return &UnacceptableDataItemError{
CBORType: t.String(),
Message: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized",
}
}

switch ai {
case 20, 21:
return fillBool(t, ai == 21, v)
Expand Down Expand Up @@ -1607,6 +1690,12 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return Tag{tagNum, content}, nil
case cborTypePrimitives:
_, ai, val := d.getHead()
if ai <= 24 && d.dm.simpleValues.rejected[SimpleValue(val)] {
return nil, &UnacceptableDataItemError{
CBORType: t.String(),
Message: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized",
}
}
if ai < 20 || ai == 24 {
return SimpleValue(val), nil
}
Expand Down
160 changes: 160 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4893,6 +4893,11 @@ func TestUnmarshalToNotNilInterface(t *testing.T) {
}

func TestDecOptions(t *testing.T) {
simpleValues, err := NewSimpleValueRegistryFromDefaults(WithRejectedSimpleValue(255))
if err != nil {
t.Fatal(err)
}

opts1 := DecOptions{
DupMapKey: DupMapKeyEnforcedAPF,
TimeTag: DecTagRequired,
Expand All @@ -4913,6 +4918,7 @@ func TestDecOptions(t *testing.T) {
FieldNameByteString: FieldNameByteStringAllowed,
UnrecognizedTagToAny: UnrecognizedTagContentToAny,
TimeTagToAny: TimeTagToRFC3339,
SimpleValues: simpleValues,
}
ov := reflect.ValueOf(opts1)
for i := 0; i < ov.NumField(); i++ {
Expand Down Expand Up @@ -8656,6 +8662,160 @@ func TestUnmarshalWithUnrecognizedTagToAnyModeForSharedTag(t *testing.T) {
}
}

// TestNewSimpleValueRegistry checks that building a registry fails, with the expected error
// message, when an option tries to reject a reserved simple value. The two cases cover both
// ends of the reserved range (24 and 31).
func TestNewSimpleValueRegistry(t *testing.T) {
	type testCase struct {
		name         string
		opts         []func(*SimpleValueRegistry) error
		wantErrorMsg string
	}

	cases := []testCase{
		{
			name:         "min reserved",
			opts:         []func(*SimpleValueRegistry) error{WithRejectedSimpleValue(24)},
			wantErrorMsg: "cbor: cannot set analog for reserved simple value 24",
		},
		{
			name:         "max reserved",
			opts:         []func(*SimpleValueRegistry) error{WithRejectedSimpleValue(31)},
			wantErrorMsg: "cbor: cannot set analog for reserved simple value 31",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := NewSimpleValueRegistryFromDefaults(tc.opts...)
			if err == nil {
				t.Fatalf("got nil error, want: %s", tc.wantErrorMsg)
			}

			got := err.Error()
			if got == tc.wantErrorMsg {
				return
			}
			t.Errorf("want: %s, got: %s", tc.wantErrorMsg, got)
		})
	}
}

// TestUnmarshalSimpleValues is a table-driven test of how simple values are unmarshaled under
// a SimpleValueRegistry. Each case builds a registry from tc.fns, creates a DecMode that uses
// it, unmarshals tc.in into a fresh value of type tc.into, and then checks both the returned
// error (via tc.assertOnError) and the decoded destination value (against tc.want).
func TestUnmarshalSimpleValues(t *testing.T) {
// assertNilError fails the test if unmarshaling returned any error.
assertNilError := func(t *testing.T, e error) {
if e != nil {
t.Errorf("expected nil error, got: %v", e)
}
}

// assertExactError returns a checker that fails the test unless the unmarshal error is
// deeply equal to want (same dynamic type and same field values).
assertExactError := func(want error) func(*testing.T, error) {
return func(t *testing.T, got error) {
if reflect.DeepEqual(want, got) {
return
}
t.Errorf("want %#v, got %#v", want, got)
}
}

for _, tc := range []struct {
name string
fns []func(*SimpleValueRegistry) error // options applied to the registry; nil means defaults
in []byte // CBOR input
into reflect.Type // type of the unmarshal destination
want interface{} // expected destination value after Unmarshal
assertOnError func(t *testing.T, e error)
}{
{
name: "default false into interface{}",
fns: nil,
in: []byte{0xf4},
into: typeIntf,
want: false,
assertOnError: assertNilError,
},
{
name: "default false into bool",
fns: nil,
in: []byte{0xf4},
into: typeBool,
want: false,
assertOnError: assertNilError,
},
{
name: "default true into interface{}",
fns: nil,
in: []byte{0xf5},
into: typeIntf,
want: true,
assertOnError: assertNilError,
},
{
name: "default true into bool",
fns: nil,
in: []byte{0xf5},
into: typeBool,
want: true,
assertOnError: assertNilError,
},
{
name: "default null into interface{}",
fns: nil,
in: []byte{0xf6},
into: typeIntf,
want: nil,
assertOnError: assertNilError,
},
{
name: "default undefined into interface{}",
fns: nil,
in: []byte{0xf7},
into: typeIntf,
want: nil,
assertOnError: assertNilError,
},
{
// With a rejected simple value the destination is expected to keep its zero value.
name: "reject undefined into interface{}",
fns: []func(*SimpleValueRegistry) error{WithRejectedSimpleValue(23)},
in: []byte{0xf7},
into: typeIntf,
want: nil,
assertOnError: assertExactError(&UnacceptableDataItemError{
CBORType: "primitives",
Message: "simple value 23 is not recognized",
}),
},
{
name: "reject true into bool",
fns: []func(*SimpleValueRegistry) error{WithRejectedSimpleValue(21)},
in: []byte{0xf5},
into: typeBool,
want: false,
assertOnError: assertExactError(&UnacceptableDataItemError{
CBORType: "primitives",
Message: "simple value 21 is not recognized",
}),
},
{
// Simple value 200 has no rejection registered and decodes into the numeric
// destination as 200.
name: "default unrecognized into uint64",
fns: nil,
in: []byte{0xf8, 0xc8},
into: typeUint64,
want: uint64(200),
assertOnError: assertNilError,
},
} {
t.Run(tc.name, func(t *testing.T) {
r, err := NewSimpleValueRegistryFromDefaults(tc.fns...)
if err != nil {
t.Fatal(err)
}

decMode, err := DecOptions{SimpleValues: r}.DecMode()
if err != nil {
t.Fatal(err)
}

// Unmarshal into a newly allocated value of the requested type.
dst := reflect.New(tc.into)
err = decMode.Unmarshal(tc.in, dst.Interface())
tc.assertOnError(t, err)

// The destination is checked in every case, including those that expect an error.
if got := dst.Elem().Interface(); !reflect.DeepEqual(tc.want, got) {
t.Errorf("got: %#v\nwant: %#v\n", got, tc.want)
}
})
}
}

// isCBORNil reports whether data begins with the byte 0xf6 or 0xf7, the encodings that the
// tests in this file use for CBOR null and undefined. Empty input yields false; only the
// first byte is inspected.
func isCBORNil(data []byte) bool {
	if len(data) == 0 {
		return false
	}
	switch data[0] {
	case 0xf6, 0xf7:
		return true
	default:
		return false
	}
}
Expand Down

0 comments on commit 9f099e8

Please sign in to comment.