diff --git a/decode.go b/decode.go index 0c8b398c..e37931f4 100644 --- a/decode.go +++ b/decode.go @@ -208,6 +208,18 @@ func (e *UnknownFieldError) Error() string { return fmt.Sprintf("cbor: found unknown field at map element index %d", e.Index) } +// UnacceptableDataItemError is returned when unmarshaling a CBOR input that contains a data item +// that is not acceptable to a specific CBOR-based application protocol ("invalid or unexpected" as +// described in RFC 8949 Section 5 Paragraph 3). +type UnacceptableDataItemError struct { + CBORType string + Message string +} + +func (e UnacceptableDataItemError) Error() string { + return fmt.Sprintf("cbor: data item of cbor type %s is not accepted by protocol: %s", e.CBORType, e.Message) +} + // DupMapKeyMode specifies how to enforce duplicate map key. Two map keys are considered duplicates if: // 1. When decoding into a struct, both keys match the same struct field. The keys are also // considered duplicates if neither matches any field and decoding to interface{} would produce @@ -496,6 +508,37 @@ func (tttam TimeTagToAnyMode) valid() bool { return tttam >= 0 && tttam < maxTimeTagToAnyMode } +// SimpleValueRegistry is a registry of unmarshaling behaviors for each possible CBOR simple value +// number (0...23 and 32...255). +type SimpleValueRegistry struct { + rejected [256]bool +} + +// WithRejectedSimpleValue registers the given simple value as rejected. If the simple value is +// encountered in a CBOR input during unmarshaling, an UnacceptableDataItemError is returned. +func WithRejectedSimpleValue(sv SimpleValue) func(*SimpleValueRegistry) error { + return func(r *SimpleValueRegistry) error { + if sv >= 24 && sv <= 31 { + return fmt.Errorf("cbor: cannot set analog for reserved simple value %d", sv) + } + r.rejected[sv] = true + return nil + } +} + +// Creates a new SimpleValueRegistry. The registry state is initialized by executing the provided +// functions in order against a registry that is pre-populated with the defaults for all well-formed +// simple value numbers. +func NewSimpleValueRegistryFromDefaults(fns ...func(*SimpleValueRegistry) error) (*SimpleValueRegistry, error) { + var r SimpleValueRegistry + for _, fn := range fns { + if err := fn(&r); err != nil { + return nil, err + } + } + return &r, nil +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -592,6 +635,16 @@ type DecOptions struct { // TimeTagToAnyMode specifies how to decode CBOR tag 0 and 1 into an empty interface (any). // Based on the specified mode, Unmarshal can return a time.Time value or a time string in a specific format. TimeTagToAny TimeTagToAnyMode + + // SimpleValues is an immutable mapping from each CBOR simple value to a corresponding + // unmarshal behavior. If nil, the simple values false, true, null, and undefined are mapped + // to the Go analog values false, true, nil, and nil, respectively, and all other simple + // values N (except the reserved simple values 24 through 31) are mapped to + // cbor.SimpleValue(N). In other words, all well-formed simple values can be decoded. + // + // Users may provide a custom SimpleValueRegistry constructed via + // NewSimpleValueRegistryFromDefaults. + SimpleValues *SimpleValueRegistry } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -661,6 +714,15 @@ const ( maxMaxNestedLevels = 65535 ) +var defaultSimpleValues = func() *SimpleValueRegistry { + registry, err := NewSimpleValueRegistryFromDefaults() + if err != nil { + panic(err) + } + return registry +}() + +//nolint:gocyclo // Each option comes with some manageable boilerplate func (opts DecOptions) decMode() (*decMode, error) { if !opts.DupMapKey.valid() { return nil, errors.New("cbor: invalid DupMapKey " + strconv.Itoa(int(opts.DupMapKey))) @@ -744,6 +806,10 @@ func (opts DecOptions) decMode() (*decMode, error) { if !opts.UnrecognizedTagToAny.valid() { return nil, errors.New("cbor: invalid UnrecognizedTagToAnyMode " + strconv.Itoa(int(opts.UnrecognizedTagToAny))) } + simpleValues := opts.SimpleValues + if simpleValues == nil { + simpleValues = defaultSimpleValues + } if !opts.TimeTagToAny.valid() { return nil, errors.New("cbor: invalid TimeTagToAny " + strconv.Itoa(int(opts.TimeTagToAny))) @@ -769,6 +835,7 @@ func (opts DecOptions) decMode() (*decMode, error) { fieldNameByteString: opts.FieldNameByteString, unrecognizedTagToAny: opts.UnrecognizedTagToAny, timeTagToAny: opts.TimeTagToAny, + simpleValues: simpleValues, } return &dm, nil @@ -841,12 +908,20 @@ type decMode struct { fieldNameByteString FieldNameByteStringMode unrecognizedTagToAny UnrecognizedTagToAnyMode timeTagToAny TimeTagToAnyMode + simpleValues *SimpleValueRegistry } var defaultDecMode, _ = DecOptions{}.decMode() // DecOptions returns user specified options used to create this DecMode. func (dm *decMode) DecOptions() DecOptions { + simpleValues := dm.simpleValues + if simpleValues == defaultSimpleValues { + // Users can't explicitly set this to defaultSimpleValues. It must have been nil in + // the original DecOptions. + simpleValues = nil + } + return DecOptions{ DupMapKey: dm.dupMapKey, TimeTag: dm.timeTag, @@ -867,6 +942,7 @@ func (dm *decMode) DecOptions() DecOptions { FieldNameByteString: dm.fieldNameByteString, UnrecognizedTagToAny: dm.unrecognizedTagToAny, TimeTagToAny: dm.timeTagToAny, + SimpleValues: simpleValues, } } @@ -1189,6 +1265,13 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin f := math.Float64frombits(val) return fillFloat(t, f, v) default: // ai <= 24 + if d.dm.simpleValues.rejected[SimpleValue(val)] { + return &UnacceptableDataItemError{ + CBORType: t.String(), + Message: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized", + } + } + switch ai { case 20, 21: return fillBool(t, ai == 21, v) @@ -1607,6 +1690,12 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli return Tag{tagNum, content}, nil case cborTypePrimitives: _, ai, val := d.getHead() + if ai <= 24 && d.dm.simpleValues.rejected[SimpleValue(val)] { + return nil, &UnacceptableDataItemError{ + CBORType: t.String(), + Message: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized", + } + } if ai < 20 || ai == 24 { return SimpleValue(val), nil } diff --git a/decode_test.go b/decode_test.go index 9bc4c881..afe287a0 100644 --- a/decode_test.go +++ b/decode_test.go @@ -4893,6 +4893,11 @@ func TestUnmarshalToNotNilInterface(t *testing.T) { } func TestDecOptions(t *testing.T) { + simpleValues, err := NewSimpleValueRegistryFromDefaults(WithRejectedSimpleValue(255)) + if err != nil { + t.Fatal(err) + } + opts1 := DecOptions{ DupMapKey: DupMapKeyEnforcedAPF, TimeTag: DecTagRequired, @@ -4913,6 +4918,7 @@ func TestDecOptions(t *testing.T) { FieldNameByteString: FieldNameByteStringAllowed, UnrecognizedTagToAny: UnrecognizedTagContentToAny, TimeTagToAny: TimeTagToRFC3339, + SimpleValues: simpleValues, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -8656,6 +8662,160 @@ func TestUnmarshalWithUnrecognizedTagToAnyModeForSharedTag(t *testing.T) { } } +func TestNewSimpleValueRegistry(t *testing.T) { + for _, tc := range []struct { + name string + opts []func(*SimpleValueRegistry) error + wantErrorMsg string + }{ + { + name: "min reserved", + opts: []func(*SimpleValueRegistry) error{WithRejectedSimpleValue(24)}, + wantErrorMsg: "cbor: cannot set analog for reserved simple value 24", + }, + { + name: "max reserved", + opts: []func(*SimpleValueRegistry) error{WithRejectedSimpleValue(31)}, + wantErrorMsg: "cbor: cannot set analog for reserved simple value 31", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := NewSimpleValueRegistryFromDefaults(tc.opts...) + if err == nil { + t.Fatalf("got nil error, want: %s", tc.wantErrorMsg) + } + if got := err.Error(); got != tc.wantErrorMsg { + t.Errorf("want: %s, got: %s", tc.wantErrorMsg, got) + } + }) + } +} + +func TestUnmarshalSimpleValues(t *testing.T) { + assertNilError := func(t *testing.T, e error) { + if e != nil { + t.Errorf("expected nil error, got: %v", e) + } + } + + assertExactError := func(want error) func(*testing.T, error) { + return func(t *testing.T, got error) { + if reflect.DeepEqual(want, got) { + return + } + t.Errorf("want %#v, got %#v", want, got) + } + } + + for _, tc := range []struct { + name string + fns []func(*SimpleValueRegistry) error + in []byte + into reflect.Type + want interface{} + assertOnError func(t *testing.T, e error) + }{ + { + name: "default false into interface{}", + fns: nil, + in: []byte{0xf4}, + into: typeIntf, + want: false, + assertOnError: assertNilError, + }, + { + name: "default false into bool", + fns: nil, + in: []byte{0xf4}, + into: typeBool, + want: false, + assertOnError: assertNilError, + }, + { + name: "default true into interface{}", + fns: nil, + in: []byte{0xf5}, + into: typeIntf, + want: true, + assertOnError: assertNilError, + }, + { + name: "default true into bool", + fns: nil, + in: []byte{0xf5}, + into: typeBool, + want: true, + assertOnError: assertNilError, + }, + { + name: "default null into interface{}", + fns: nil, + in: []byte{0xf6}, + into: typeIntf, + want: nil, + assertOnError: assertNilError, + }, + { + name: "default undefined into interface{}", + fns: nil, + in: []byte{0xf7}, + into: typeIntf, + want: nil, + assertOnError: assertNilError, + }, + { + name: "reject undefined into interface{}", + fns: []func(*SimpleValueRegistry) error{WithRejectedSimpleValue(23)}, + in: []byte{0xf7}, + into: typeIntf, + want: nil, + assertOnError: assertExactError(&UnacceptableDataItemError{ + CBORType: "primitives", + Message: "simple value 23 is not recognized", + }), + }, + { + name: "reject true into bool", + fns: []func(*SimpleValueRegistry) error{WithRejectedSimpleValue(21)}, + in: []byte{0xf5}, + into: typeBool, + want: false, + assertOnError: assertExactError(&UnacceptableDataItemError{ + CBORType: "primitives", + Message: "simple value 21 is not recognized", + }), + }, + { + name: "default unrecognized into uint64", + fns: nil, + in: []byte{0xf8, 0xc8}, + into: typeUint64, + want: uint64(200), + assertOnError: assertNilError, + }, + } { + t.Run(tc.name, func(t *testing.T) { + r, err := NewSimpleValueRegistryFromDefaults(tc.fns...) + if err != nil { + t.Fatal(err) + } + + decMode, err := DecOptions{SimpleValues: r}.DecMode() + if err != nil { + t.Fatal(err) + } + + dst := reflect.New(tc.into) + err = decMode.Unmarshal(tc.in, dst.Interface()) + tc.assertOnError(t, err) + + if got := dst.Elem().Interface(); !reflect.DeepEqual(tc.want, got) { + t.Errorf("got: %#v\nwant: %#v\n", got, tc.want) + } + }) + } +} + func isCBORNil(data []byte) bool { return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7) }