From 58cd838d4e1cd21a30ecb4866872e30590bc1576 Mon Sep 17 00:00:00 2001 From: Ben Luddy Date: Tue, 30 Jan 2024 13:52:27 -0500 Subject: [PATCH] Add option to set arbitrary simple value to Go value mappings. Signed-off-by: Ben Luddy --- decode.go | 226 +++++++++++++++++++++++++++++++++++++++++++------ decode_test.go | 225 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 425 insertions(+), 26 deletions(-) diff --git a/decode.go b/decode.go index c2b5e008..f4828993 100644 --- a/decode.go +++ b/decode.go @@ -208,6 +208,18 @@ func (e *UnknownFieldError) Error() string { return fmt.Sprintf("cbor: found unknown field at map element index %d", e.Index) } +// UnacceptableDataItemError is returned when unmarshaling a CBOR input that contains a data item that +// is not acceptable to a specific CBOR-based application protocol (as described in RFC 8949 Section +// 5 Paragraph 3). +type UnacceptableDataItemError struct { + CBORType string + Message string +} + +func (e UnacceptableDataItemError) Error() string { + return fmt.Sprintf("cbor: protocol does not accept data items of cbor type %s: %s", e.CBORType, e.Message) +} + // DupMapKeyMode specifies how to enforce duplicate map key. Two map keys are considered duplicates if: // 1. When decoding into a struct, both keys match the same struct field. The keys are also // considered duplicates if neither matches any field and decoding to interface{} would produce @@ -472,6 +484,94 @@ func (uttam UnrecognizedTagToAnyMode) valid() bool { return uttam >= 0 && uttam < maxUnrecognizedTagToAny } +// SimpleValueRegistry is an immutable mapping from CBOR simple value number (0...23 and 32...255) +// to Go analog value. +type SimpleValueRegistry struct { + analogs [256]*interface{} +} + +// WithSimpleValueAnalog registers a Go analog value for the given simple value. When decoding into +// an empty interface value, the registered analog value is returned. When decoding into a concrete +// type, the type of the registered analog must be directly assignable to the destination's type. +func WithSimpleValueAnalog(sv SimpleValue, analog interface{}) func(*SimpleValueRegistry) error { + return func(r *SimpleValueRegistry) error { + if sv >= 24 && sv <= 31 { + return fmt.Errorf("cbor: cannot set analog for reserved simple value %d", sv) + } + r.analogs[sv] = &analog + return nil + } +} + +// WithNoSimpleValueAnalog marks the given simple value as having no registered Go analog. If +// encountered during unmarshaling, an UnacceptableDataItemError will be returned. +func WithNoSimpleValueAnalog(sv SimpleValue) func(*SimpleValueRegistry) error { + return func(r *SimpleValueRegistry) error { + if sv >= 24 && sv <= 31 { + return fmt.Errorf("cbor: cannot set analog for reserved simple value %d", sv) + } + r.analogs[sv] = nil + return nil + } +} + +// builtinAnalog wraps any of the built-in default simple value analogs. Simple values mapped to one +// of the built-in analogs are decoded more permissively than user-provided analogs, which must be +// directly assignable to a destination value. +type builtinAnalog struct { + v interface{} +} + +// Creates a new SimpleValueRegistry. The registry state is initialized by executing the provided +// functions in order against an empty registry. Any simple value without a registered analog will +// produce an UnacceptableDataItemError if encountered in the input while unmarshaling. +func NewSimpleValueRegistry(fns ...func(*SimpleValueRegistry) error) (*SimpleValueRegistry, error) { + var r SimpleValueRegistry + for _, fn := range fns { + if err := fn(&r); err != nil { + return nil, err + } + } + return &r, nil +} + +// Creates a new SimpleValueRegistry. The registry state is initialized by executing the provided +// functions in order against a registry that is pre-populated with the library defaults. +func NewSimpleValueRegistryFromDefaults(fns ...func(*SimpleValueRegistry) error) (*SimpleValueRegistry, error) { + var r SimpleValueRegistry + for _, fn := range append([]func(*SimpleValueRegistry) error{withDefaultSimpleValueAnalogs}, fns...) { + if err := fn(&r); err != nil { + return nil, err + } + } + return &r, nil +} + +// withDefaultSimpleValueAnalogs registers Go analogs for false (false), true (true), null (nil), +// and undefined (nil). For the simple values numbering 0 through 19, inclusive, and 32 through 255, +// inclusive, registers the analog SimpleValue(N), where N is each respective simple value number. +func withDefaultSimpleValueAnalogs(r *SimpleValueRegistry) error { + var err error + for i := 0; i <= 255 && err == nil; i++ { + sv := SimpleValue(i) + switch { + case sv == 20: + err = WithSimpleValueAnalog(20, builtinAnalog{false})(r) + case sv == 21: + err = WithSimpleValueAnalog(21, builtinAnalog{true})(r) + case sv == 22: // null + err = WithSimpleValueAnalog(22, builtinAnalog{nil})(r) + case sv == 23: // undefined + err = WithSimpleValueAnalog(23, builtinAnalog{nil})(r) + case sv >= 24 && sv <= 31: // reserved + continue + default: + err = WithSimpleValueAnalog(sv, builtinAnalog{sv})(r) + } + } + return err +} + // DecOptions specifies decoding options. type DecOptions struct { // DupMapKey specifies whether to enforce duplicate map key. @@ -564,6 +664,16 @@ type DecOptions struct { // UnrecognizedTagToAny specifies how to decode unrecognized CBOR tag into an empty interface. // Currently, recognized CBOR tag numbers are 0, 1, 2, 3, or registered by TagSet. UnrecognizedTagToAny UnrecognizedTagToAnyMode + + // SimpleValues is an immutable mapping from CBOR simple value to a Go analog value. If nil, + // the simple values false, true, null, and undefined are mapped to Go analog values false, + // true, nil, and nil, respectively, and all other simple values N (except the reserved + // simple values 24 through 31) are mapped to cbor.SimpleValue(N). In other words, all + // well-formed simple values can be decoded. + // + // Users may construct a custom SimpleValueRegistry via NewSimpleValueRegistry or + // NewSimpleValueRegistryFromDefaults. + SimpleValues *SimpleValueRegistry } // DecMode returns DecMode with immutable options and no tags (safe for concurrency). @@ -633,6 +743,14 @@ const ( maxMaxNestedLevels = 65535 ) +var defaultSimpleValues = func() *SimpleValueRegistry { + registry, err := NewSimpleValueRegistry(withDefaultSimpleValueAnalogs) + if err != nil { + panic(err) + } + return registry +}() + func (opts DecOptions) decMode() (*decMode, error) { if !opts.DupMapKey.valid() { return nil, errors.New("cbor: invalid DupMapKey " + strconv.Itoa(int(opts.DupMapKey))) @@ -716,6 +834,10 @@ func (opts DecOptions) decMode() (*decMode, error) { if !opts.UnrecognizedTagToAny.valid() { return nil, errors.New("cbor: invalid UnrecognizedTagToAnyMode " + strconv.Itoa(int(opts.UnrecognizedTagToAny))) } + simpleValues := opts.SimpleValues + if simpleValues == nil { + simpleValues = defaultSimpleValues + } dm := decMode{ dupMapKey: opts.DupMapKey, @@ -736,6 +858,7 @@ func (opts DecOptions) decMode() (*decMode, error) { byteStringToString: opts.ByteStringToString, fieldNameByteString: opts.FieldNameByteString, unrecognizedTagToAny: opts.UnrecognizedTagToAny, + simpleValues: simpleValues, } return &dm, nil @@ -807,12 +930,20 @@ type decMode struct { byteStringToString ByteStringToStringMode fieldNameByteString FieldNameByteStringMode unrecognizedTagToAny UnrecognizedTagToAnyMode + simpleValues *SimpleValueRegistry } var defaultDecMode, _ = DecOptions{}.decMode() // DecOptions returns user specified options used to create this DecMode. func (dm *decMode) DecOptions() DecOptions { + simpleValues := dm.simpleValues + if simpleValues == defaultSimpleValues { + // Users can't explicitly set this to defaultSimpleValues. It must have been nil in + // the original DecOptions. + simpleValues = nil + } + return DecOptions{ DupMapKey: dm.dupMapKey, TimeTag: dm.timeTag, @@ -832,6 +963,7 @@ func (dm *decMode) DecOptions() DecOptions { ByteStringToString: dm.byteStringToString, FieldNameByteString: dm.fieldNameByteString, UnrecognizedTagToAny: dm.unrecognizedTagToAny, + SimpleValues: simpleValues, } } @@ -1154,14 +1286,43 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin f := math.Float64frombits(val) return fillFloat(t, f, v) default: // ai <= 24 - switch ai { - case 20, 21: - return fillBool(t, ai == 21, v) - case 22, 23: - return fillNil(t, v) - default: - return fillPositiveInt(t, val, v) + analog := d.dm.simpleValues.analogs[SimpleValue(val)] + if analog == nil { + return &UnacceptableDataItemError{ + CBORType: t.String(), + Message: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized", + } + } + + // Compatibility mode for simple value decoding using the default analogs. + if ba, ok := (*analog).(builtinAnalog); ok { + return fillBuiltinAnalog(t, ba, v) + } + + if *analog == nil { + switch v.Kind() { + case reflect.Ptr, reflect.Func, reflect.Slice, reflect.Map, reflect.Chan, reflect.Interface: + // (reflect.Value) SetZero() was added in Go 1.20. + v.Set(reflect.Zero(v.Type())) + return nil + } + return &UnmarshalTypeError{ + CBORType: t.String(), + GoType: v.Type().String(), + errorMsg: fmt.Sprintf("analog %v (%T) for simple value %d is not assignable to a value of this type", *analog, *analog, val), + } + } + + av := reflect.ValueOf(*analog) + if !av.Type().AssignableTo(v.Type()) { + return &UnmarshalTypeError{ + CBORType: t.String(), + GoType: v.Type().String(), + errorMsg: fmt.Sprintf("analog %v (%T) for simple value %d is not assignable to a value of this type", *analog, *analog, val), + } } + v.Set(av) + return nil } case cborTypeTag: @@ -1554,14 +1715,20 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli return Tag{tagNum, content}, nil case cborTypePrimitives: _, ai, val := d.getHead() - if ai < 20 || ai == 24 { - return SimpleValue(val), nil + if ai <= 24 { + analog := d.dm.simpleValues.analogs[SimpleValue(val)] + if analog == nil { + return nil, &UnacceptableDataItemError{ + CBORType: t.String(), + Message: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized", + } + } + if ba, ok := (*analog).(builtinAnalog); ok { + return ba.v, nil + } + return *analog, nil } switch ai { - case 20, 21: - return (ai == 21), nil - case 22, 23: - return nil, nil case 25: f := float64(float16.Frombits(uint16(val)).Float32()) return f, nil @@ -2370,13 +2537,28 @@ var ( typeByteSlice = reflect.TypeOf([]byte(nil)) ) -func fillNil(_ cborType, v reflect.Value) error { - switch v.Kind() { - case reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr: - v.Set(reflect.Zero(v.Type())) +func fillBuiltinAnalog(t cborType, analog builtinAnalog, v reflect.Value) error { + switch analog.v { + case nil: + switch v.Kind() { + case reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr: + v.Set(reflect.Zero(v.Type())) + default: + // no-op + } return nil + case true, false: + if v.Kind() == reflect.Bool { + v.SetBool(analog.v.(bool)) + return nil + } + default: + if sv, ok := (analog.v).(SimpleValue); ok { + return fillPositiveInt(t, uint64(sv), v) + } } - return nil + + return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } func fillPositiveInt(t cborType, val uint64, v reflect.Value) error { @@ -2446,14 +2628,6 @@ func fillNegativeInt(t cborType, val int64, v reflect.Value) error { return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} } -func fillBool(t cborType, val bool, v reflect.Value) error { - if v.Kind() == reflect.Bool { - v.SetBool(val) - return nil - } - return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()} -} - func fillFloat(t cborType, val float64, v reflect.Value) error { switch v.Kind() { case reflect.Float32, reflect.Float64: diff --git a/decode_test.go b/decode_test.go index 54da82f3..58f297b2 100644 --- a/decode_test.go +++ b/decode_test.go @@ -4893,6 +4893,11 @@ func TestUnmarshalToNotNilInterface(t *testing.T) { } func TestDecOptions(t *testing.T) { + simpleValues, err := NewSimpleValueRegistry(WithSimpleValueAnalog(255, false)) + if err != nil { + t.Fatal(err) + } + opts1 := DecOptions{ DupMapKey: DupMapKeyEnforcedAPF, TimeTag: DecTagRequired, @@ -4912,6 +4917,7 @@ func TestDecOptions(t *testing.T) { ByteStringToString: ByteStringToStringAllowed, FieldNameByteString: FieldNameByteStringAllowed, UnrecognizedTagToAny: UnrecognizedTagContentToAny, + SimpleValues: simpleValues, } ov := reflect.ValueOf(opts1) for i := 0; i < ov.NumField(); i++ { @@ -8655,6 +8661,225 @@ func TestUnmarshalWithUnrecognizedTagToAnyModeForSharedTag(t *testing.T) { } } +func TestNewSimpleValueRegistry(t *testing.T) { + for _, tc := range []struct { + name string + opts []func(*SimpleValueRegistry) error + wantErrorMsg string + }{ + { + name: "min reserved", + opts: []func(*SimpleValueRegistry) error{WithSimpleValueAnalog(24, true)}, + wantErrorMsg: "cbor: cannot set analog for reserved simple value 24", + }, + { + name: "max reserved", + opts: []func(*SimpleValueRegistry) error{WithSimpleValueAnalog(31, true)}, + wantErrorMsg: "cbor: cannot set analog for reserved simple value 31", + }, + } { + t.Run(tc.name, func(t *testing.T) { + _, err := NewSimpleValueRegistry(tc.opts...) + if err == nil { + t.Fatalf("got nil error, want: %s", tc.wantErrorMsg) + } + if got := err.Error(); got != tc.wantErrorMsg { + t.Errorf("want: %s, got: %s", tc.wantErrorMsg, got) + } + }) + } +} + +func TestUnmarshalSimpleValues(t *testing.T) { + type namedBool bool + + assertNilError := func(t *testing.T, e error) { + if e != nil { + t.Errorf("expected nil error, got: %v", e) + } + } + + assertExactError := func(want error) func(*testing.T, error) { + return func(t *testing.T, got error) { + if reflect.DeepEqual(want, got) { + return + } + t.Errorf("want %#v, got %#v", want, got) + } + } + + for _, tc := range []struct { + name string + builder func(...func(*SimpleValueRegistry) error) (*SimpleValueRegistry, error) + fns []func(*SimpleValueRegistry) error + in []byte + into reflect.Type + want interface{} + assertOnError func(t *testing.T, e error) + }{ + { + name: "default false into interface{}", + builder: NewSimpleValueRegistryFromDefaults, + fns: nil, + in: []byte{0xf4}, + into: typeIntf, + want: false, + assertOnError: assertNilError, + }, + { + name: "default false into bool", + builder: NewSimpleValueRegistryFromDefaults, + fns: nil, + in: []byte{0xf4}, + into: typeBool, + want: false, + assertOnError: assertNilError, + }, + { + name: "default true into interface{}", + builder: NewSimpleValueRegistryFromDefaults, + fns: nil, + in: []byte{0xf5}, + into: typeIntf, + want: true, + assertOnError: assertNilError, + }, + { + name: "default true into bool", + builder: NewSimpleValueRegistryFromDefaults, + fns: nil, + in: []byte{0xf5}, + into: typeBool, + want: true, + assertOnError: assertNilError, + }, + { + name: "provided true into namedBool", + builder: NewSimpleValueRegistry, + fns: []func(*SimpleValueRegistry) error{WithSimpleValueAnalog(21, true)}, + in: []byte{0xf5}, + into: reflect.TypeOf(namedBool(false)), + want: namedBool(false), + assertOnError: assertExactError(&UnmarshalTypeError{ + CBORType: "primitives", + GoType: reflect.TypeOf(namedBool(false)).String(), + errorMsg: "analog true (bool) for simple value 21 is not assignable to a value of this type", + }), + }, + { + name: "provided namedBool(true) into namedBool", + builder: NewSimpleValueRegistry, + fns: []func(*SimpleValueRegistry) error{WithSimpleValueAnalog(21, namedBool(true))}, + in: []byte{0xf5}, + into: reflect.TypeOf(namedBool(false)), + want: namedBool(true), + assertOnError: assertNilError, + }, + { + name: "provided namedBool(true) into interface{}", + builder: NewSimpleValueRegistry, + fns: []func(*SimpleValueRegistry) error{WithSimpleValueAnalog(21, namedBool(true))}, + in: []byte{0xf5}, + into: typeIntf, + want: namedBool(true), + assertOnError: assertNilError, + }, + { + name: "default null into interface{}", + builder: NewSimpleValueRegistryFromDefaults, + fns: nil, + in: []byte{0xf6}, + into: typeIntf, + want: nil, + assertOnError: assertNilError, + }, + { + name: "provided untyped nil into []int32", + builder: NewSimpleValueRegistry, + fns: []func(*SimpleValueRegistry) error{WithSimpleValueAnalog(200, nil)}, + in: []byte{0xf8, 0xc8}, + into: reflect.TypeOf(([]int32)(nil)), + want: ([]int32)(nil), + assertOnError: assertNilError, + }, + { + name: "provided untyped nil into string", + builder: NewSimpleValueRegistry, + fns: []func(*SimpleValueRegistry) error{WithSimpleValueAnalog(200, nil)}, + in: []byte{0xf8, 0xc8}, + into: typeString, + want: "", + assertOnError: assertExactError(&UnmarshalTypeError{ + CBORType: "primitives", + GoType: "string", + errorMsg: "analog () for simple value 200 is not assignable to a value of this type", + }), + }, + { + name: "default undefined into interface{}", + builder: NewSimpleValueRegistryFromDefaults, + fns: nil, + in: []byte{0xf7}, + into: typeIntf, + want: nil, + assertOnError: assertNilError, + }, + { + name: "cleared default undefined into interface{}", + builder: NewSimpleValueRegistryFromDefaults, + fns: []func(*SimpleValueRegistry) error{WithNoSimpleValueAnalog(23)}, + in: []byte{0xf7}, + into: typeIntf, + want: nil, + assertOnError: assertExactError(&UnacceptableDataItemError{ + CBORType: "primitives", + Message: "simple value 23 is not recognized", + }), + }, + { + name: "default unrecognized into uint64", + builder: NewSimpleValueRegistryFromDefaults, + fns: nil, + in: []byte{0xf8, 0xc8}, + into: typeUint64, + want: uint64(200), + assertOnError: assertNilError, + }, + { + name: "empty unrecognized into uint64", + builder: NewSimpleValueRegistry, + fns: nil, + in: []byte{0xf8, 0xc8}, + into: typeUint64, + want: uint64(0), + assertOnError: assertExactError(&UnacceptableDataItemError{ + CBORType: "primitives", + Message: "simple value 200 is not recognized", + }), + }, + } { + t.Run(tc.name, func(t *testing.T) { + r, err := tc.builder(tc.fns...) + if err != nil { + t.Fatal(err) + } + + decMode, err := DecOptions{SimpleValues: r}.DecMode() + if err != nil { + t.Fatal(err) + } + + dst := reflect.New(tc.into) + err = decMode.Unmarshal(tc.in, dst.Interface()) + tc.assertOnError(t, err) + + if got := dst.Elem().Interface(); !reflect.DeepEqual(tc.want, got) { + t.Errorf("got: %#v\nwant: %#v\n", got, tc.want) + } + }) + } +} + func isCBORNil(data []byte) bool { return len(data) > 0 && (data[0] == 0xf6 || data[0] == 0xf7) }