Skip to content

Commit

Permalink
Add option to set arbitrary simple value to Go value mappings.
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Luddy <bluddy@redhat.com>
  • Loading branch information
benluddy committed Mar 26, 2024
1 parent 571b811 commit 58cd838
Show file tree
Hide file tree
Showing 2 changed files with 425 additions and 26 deletions.
226 changes: 200 additions & 26 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,18 @@ func (e *UnknownFieldError) Error() string {
return fmt.Sprintf("cbor: found unknown field at map element index %d", e.Index)
}

// UnacceptableDataItemError is returned when unmarshaling a CBOR input that contains a data item that
// is not acceptable to a specific CBOR-based application protocol (as described in RFC 8949 Section
// 5 Paragraph 3).
type UnacceptableDataItemError struct {
CBORType string
Message string
}

func (e UnacceptableDataItemError) Error() string {
return fmt.Sprintf("cbor: protocol does not accept data items of cbor type %s: %s", e.CBORType, e.Message)
}

// DupMapKeyMode specifies how to enforce duplicate map key. Two map keys are considered duplicates if:
// 1. When decoding into a struct, both keys match the same struct field. The keys are also
// considered duplicates if neither matches any field and decoding to interface{} would produce
Expand Down Expand Up @@ -472,6 +484,94 @@ func (uttam UnrecognizedTagToAnyMode) valid() bool {
return uttam >= 0 && uttam < maxUnrecognizedTagToAny
}

// SimpleValueRegistry is an immutable mapping from CBOR simple value number (0...23 and 32...255)
// to Go analog value.
type SimpleValueRegistry struct {
analogs [256]*interface{}
}

// WithSimpleValueAnalog registers a Go analog value for the given simple value. When decoding into
// an empty interface value, the registered analog value is returned. When decoding into a concrete
// type, the type of the registered analog must be directly assignable to the destination's type.
func WithSimpleValueAnalog(sv SimpleValue, analog interface{}) func(*SimpleValueRegistry) error {
return func(r *SimpleValueRegistry) error {
if sv >= 24 && sv <= 31 {
return fmt.Errorf("cbor: cannot set analog for reserved simple value %d", sv)
}
r.analogs[sv] = &analog
return nil
}
}

// WithNoSimpleValueAnalog marks the given simple value as having no registered Go analog. If
// encountered during unmarshaling, an UnacceptableDataItemError will be returned.
func WithNoSimpleValueAnalog(sv SimpleValue) func(*SimpleValueRegistry) error {
return func(r *SimpleValueRegistry) error {
if sv >= 24 && sv <= 31 {
return fmt.Errorf("cbor: cannot set analog for reserved simple value %d", sv)
}
r.analogs[sv] = nil
return nil
}
}

// builtinAnalog wraps any of the built-in default simple value analogs. Simple values mapped to one
// of the built-in analogs are decoded more permissively than user-provided analogs, which must be
// directly assignable to a destination value.
type builtinAnalog struct {
v interface{}
}

// Creates a new SimpleValueRegistry. The registry state is initialized by executing the provided
// functions in order against an empty registry. Any simple value without a registered analog will
// produce an UnacceptableDataItemError if encountered in the input while unmarshaling.
func NewSimpleValueRegistry(fns ...func(*SimpleValueRegistry) error) (*SimpleValueRegistry, error) {
var r SimpleValueRegistry
for _, fn := range fns {
if err := fn(&r); err != nil {
return nil, err
}
}
return &r, nil
}

// Creates a new SimpleValueRegistry. The registry state is initialized by executing the provided
// functions in order against a registry that is pre-populated with the library defaults.
func NewSimpleValueRegistryFromDefaults(fns ...func(*SimpleValueRegistry) error) (*SimpleValueRegistry, error) {
var r SimpleValueRegistry
for _, fn := range append([]func(*SimpleValueRegistry) error{withDefaultSimpleValueAnalogs}, fns...) {
if err := fn(&r); err != nil {
return nil, err
}
}
return &r, nil
}

// withDefaultSimpleValueAnalogs registers Go analogs for false (false), true (true), null (nil),
// and undefined (nil). For the simple values numbering 0 through 19, inclusive, and 32 through 255,
// inclusive, registers the analog SimpleValue(N), where N is each respective simple value number.
func withDefaultSimpleValueAnalogs(r *SimpleValueRegistry) error {
var err error
for i := 0; i <= 255 && err == nil; i++ {
sv := SimpleValue(i)
switch {
case sv == 20:
err = WithSimpleValueAnalog(20, builtinAnalog{false})(r)
case sv == 21:
err = WithSimpleValueAnalog(21, builtinAnalog{true})(r)
case sv == 22: // null
err = WithSimpleValueAnalog(22, builtinAnalog{nil})(r)
case sv == 23: // undefined
err = WithSimpleValueAnalog(23, builtinAnalog{nil})(r)
case sv >= 24 && sv <= 31: // reserved
continue
default:
err = WithSimpleValueAnalog(sv, builtinAnalog{sv})(r)
}
}
return err
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -564,6 +664,16 @@ type DecOptions struct {
// UnrecognizedTagToAny specifies how to decode unrecognized CBOR tag into an empty interface.
// Currently, recognized CBOR tag numbers are 0, 1, 2, 3, or registered by TagSet.
UnrecognizedTagToAny UnrecognizedTagToAnyMode

// SimpleValues is an immutable mapping from CBOR simple value to a Go analog value. If nil,
// the simple values false, true, null, and undefined are mapped to Go analog values false,
// true, nil, and nil, respectively, and all other simple values N (except the reserved
// simple values 24 through 31) are mapped to cbor.SimpleValue(N). In other words, all
// well-formed simple values can be decoded.
//
// Users may construct a custom SimpleValueRegistry via NewSimpleValueRegistry or
// NewSimpleValueRegistryFromDefaults.
SimpleValues *SimpleValueRegistry
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -633,6 +743,14 @@ const (
maxMaxNestedLevels = 65535
)

var defaultSimpleValues = func() *SimpleValueRegistry {
registry, err := NewSimpleValueRegistry(withDefaultSimpleValueAnalogs)
if err != nil {
panic(err)
}
return registry
}()

func (opts DecOptions) decMode() (*decMode, error) {
if !opts.DupMapKey.valid() {
return nil, errors.New("cbor: invalid DupMapKey " + strconv.Itoa(int(opts.DupMapKey)))
Expand Down Expand Up @@ -716,6 +834,10 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.UnrecognizedTagToAny.valid() {
return nil, errors.New("cbor: invalid UnrecognizedTagToAnyMode " + strconv.Itoa(int(opts.UnrecognizedTagToAny)))
}
simpleValues := opts.SimpleValues
if simpleValues == nil {
simpleValues = defaultSimpleValues
}

dm := decMode{
dupMapKey: opts.DupMapKey,
Expand All @@ -736,6 +858,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
byteStringToString: opts.ByteStringToString,
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
simpleValues: simpleValues,
}

return &dm, nil
Expand Down Expand Up @@ -807,12 +930,20 @@ type decMode struct {
byteStringToString ByteStringToStringMode
fieldNameByteString FieldNameByteStringMode
unrecognizedTagToAny UnrecognizedTagToAnyMode
simpleValues *SimpleValueRegistry
}

var defaultDecMode, _ = DecOptions{}.decMode()

// DecOptions returns user specified options used to create this DecMode.
func (dm *decMode) DecOptions() DecOptions {
simpleValues := dm.simpleValues
if simpleValues == defaultSimpleValues {
// Users can't explicitly set this to defaultSimpleValues. It must have been nil in
// the original DecOptions.
simpleValues = nil
}

return DecOptions{
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
Expand All @@ -832,6 +963,7 @@ func (dm *decMode) DecOptions() DecOptions {
ByteStringToString: dm.byteStringToString,
FieldNameByteString: dm.fieldNameByteString,
UnrecognizedTagToAny: dm.unrecognizedTagToAny,
SimpleValues: simpleValues,
}
}

Expand Down Expand Up @@ -1154,14 +1286,43 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
f := math.Float64frombits(val)
return fillFloat(t, f, v)
default: // ai <= 24
switch ai {
case 20, 21:
return fillBool(t, ai == 21, v)
case 22, 23:
return fillNil(t, v)
default:
return fillPositiveInt(t, val, v)
analog := d.dm.simpleValues.analogs[SimpleValue(val)]
if analog == nil {
return &UnacceptableDataItemError{
CBORType: t.String(),
Message: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized",
}
}

// Compatibility mode for simple value decoding using the default analogs.
if ba, ok := (*analog).(builtinAnalog); ok {
return fillBuiltinAnalog(t, ba, v)
}

if *analog == nil {
switch v.Kind() {
case reflect.Ptr, reflect.Func, reflect.Slice, reflect.Map, reflect.Chan, reflect.Interface:
// (reflect.Value) SetZero() was added in Go 1.20.
v.Set(reflect.Zero(v.Type()))
return nil
}
return &UnmarshalTypeError{
CBORType: t.String(),
GoType: v.Type().String(),
errorMsg: fmt.Sprintf("analog %v (%T) for simple value %d is not assignable to a value of this type", *analog, *analog, val),
}
}

av := reflect.ValueOf(*analog)
if !av.Type().AssignableTo(v.Type()) {
return &UnmarshalTypeError{
CBORType: t.String(),
GoType: v.Type().String(),
errorMsg: fmt.Sprintf("analog %v (%T) for simple value %d is not assignable to a value of this type", *analog, *analog, val),
}
}
v.Set(av)
return nil
}

case cborTypeTag:
Expand Down Expand Up @@ -1554,14 +1715,20 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return Tag{tagNum, content}, nil
case cborTypePrimitives:
_, ai, val := d.getHead()
if ai < 20 || ai == 24 {
return SimpleValue(val), nil
if ai <= 24 {
analog := d.dm.simpleValues.analogs[SimpleValue(val)]
if analog == nil {
return nil, &UnacceptableDataItemError{
CBORType: t.String(),
Message: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized",
}
}
if ba, ok := (*analog).(builtinAnalog); ok {
return ba.v, nil
}
return *analog, nil
}
switch ai {
case 20, 21:
return (ai == 21), nil
case 22, 23:
return nil, nil
case 25:
f := float64(float16.Frombits(uint16(val)).Float32())
return f, nil
Expand Down Expand Up @@ -2370,13 +2537,28 @@ var (
typeByteSlice = reflect.TypeOf([]byte(nil))
)

func fillNil(_ cborType, v reflect.Value) error {
switch v.Kind() {
case reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr:
v.Set(reflect.Zero(v.Type()))
func fillBuiltinAnalog(t cborType, analog builtinAnalog, v reflect.Value) error {
switch analog.v {
case nil:
switch v.Kind() {
case reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr:
v.Set(reflect.Zero(v.Type()))
default:
// no-op
}
return nil
case true, false:
if v.Kind() == reflect.Bool {
v.SetBool(analog.v.(bool))
return nil
}
default:
if sv, ok := (analog.v).(SimpleValue); ok {
return fillPositiveInt(t, uint64(sv), v)
}
}
return nil

return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillPositiveInt(t cborType, val uint64, v reflect.Value) error {
Expand Down Expand Up @@ -2446,14 +2628,6 @@ func fillNegativeInt(t cborType, val int64, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillBool(t cborType, val bool, v reflect.Value) error {
if v.Kind() == reflect.Bool {
v.SetBool(val)
return nil
}
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillFloat(t cborType, val float64, v reflect.Value) error {
switch v.Kind() {
case reflect.Float32, reflect.Float64:
Expand Down
Loading

0 comments on commit 58cd838

Please sign in to comment.