Skip to content

Commit

Permalink
Add option to set arbitrary simple value to Go value mappings.
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Luddy <bluddy@redhat.com>
  • Loading branch information
benluddy committed Mar 11, 2024
1 parent 4a6e6d1 commit 2f65e3b
Show file tree
Hide file tree
Showing 2 changed files with 384 additions and 26 deletions.
198 changes: 172 additions & 26 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,80 @@ func (uttam UnrecognizedTagToAnyMode) valid() bool {
return uttam >= 0 && uttam < maxUnrecognizedTagToAny
}

// SimpleValueRegistry is an immutable mapping from CBOR simple value number (0...23 and 32...255)
// to Go analog value.
type SimpleValueRegistry struct {
analogs [256]*interface{}
}

// WithSimpleValueAnalog registers a Go analog value for the given simple value. When decoding into
// an empty interface value, the registered analog value is returned. When decoding into a concrete
// type, the type of the registered analog must be directly assignable to the destination's type.
func WithSimpleValueAnalog(sv SimpleValue, analog interface{}) func(*SimpleValueRegistry) error {
return func(r *SimpleValueRegistry) error {
if sv >= 24 && sv <= 31 {
return fmt.Errorf("cbor: cannot set analog for reserved simple value %d", sv)
}
r.analogs[sv] = &analog
return nil
}
}

// WithNoSimpleValueAnalog marks the given simple value as having no registered Go analog.
func WithNoSimpleValueAnalog(sv SimpleValue) func(*SimpleValueRegistry) error {
return func(r *SimpleValueRegistry) error {
if sv >= 24 && sv <= 31 {
return fmt.Errorf("cbor: cannot set analog for reserved simple value %d", sv)
}
r.analogs[sv] = nil
return nil
}
}

// builtinAnalog wraps any of the built-in default simple value analogs. Simple values mapped to one
// of the built-in analogs are decoded more permissively than user-provided analogs, which must be
// directly assignable to a destination value.
type builtinAnalog struct {
v interface{}
}

// WithDefaultSimpleValueAnalogs registers Go analogs for false (false), true (true), null (nil),
// and undefined (nil). For the simple values numbering 0 through 19, inclusive, and 32 through 255,
// inclusive, registers the analog SimpleValue(N), where N is each respective simple value number.
func WithDefaultSimpleValueAnalogs(r *SimpleValueRegistry) error {
var err error
for i := 0; i <= 255 && err == nil; i++ {
sv := SimpleValue(i)
switch {
case sv == 20:
err = WithSimpleValueAnalog(20, builtinAnalog{false})(r)
case sv == 21:
err = WithSimpleValueAnalog(21, builtinAnalog{true})(r)
case sv == 22: // null
err = WithSimpleValueAnalog(22, builtinAnalog{nil})(r)
case sv == 23: // undefined
err = WithSimpleValueAnalog(23, builtinAnalog{nil})(r)
case sv >= 24 && sv <= 31: // reserved
continue
default:
err = WithSimpleValueAnalog(sv, builtinAnalog{sv})(r)
}
}
return err
}

// Creates a new SimpleValueRegistry. The registry state is initialized by executing the provided
// functions in order against an empty registry.
func NewSimpleValueRegistry(fns ...func(*SimpleValueRegistry) error) (*SimpleValueRegistry, error) {
var r SimpleValueRegistry
for _, fn := range fns {
if err := fn(&r); err != nil {
return nil, err
}
}
return &r, nil
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -543,6 +617,12 @@ type DecOptions struct {
// UnrecognizedTagToAny specifies how to decode unrecognized CBOR tag into an empty interface.
// Currently, recognized CBOR tag numbers are 0, 1, 2, 3, or registered by TagSet.
UnrecognizedTagToAny UnrecognizedTagToAnyMode

// SimpleValues is an immutable mapping from CBOR simple value to a Go analog value. If nil,
// the simple values false, true, null, and undefined are mapped to Go analog values false,
// true, nil, and nil, respectively, and all other simple values N (except the reserved
// simple values 24 through 31) are mapped to cbor.SimpleValue(N).
SimpleValues *SimpleValueRegistry
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -612,6 +692,14 @@ const (
maxMaxNestedLevels = 65535
)

var defaultSimpleValues = func() *SimpleValueRegistry {
registry, err := NewSimpleValueRegistry(WithDefaultSimpleValueAnalogs)
if err != nil {
panic(err)
}
return registry
}()

func (opts DecOptions) decMode() (*decMode, error) {
if !opts.DupMapKey.valid() {
return nil, errors.New("cbor: invalid DupMapKey " + strconv.Itoa(int(opts.DupMapKey)))
Expand Down Expand Up @@ -695,6 +783,10 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.UnrecognizedTagToAny.valid() {
return nil, errors.New("cbor: invalid UnrecognizedTagToAnyMode " + strconv.Itoa(int(opts.UnrecognizedTagToAny)))
}
simpleValues := opts.SimpleValues
if simpleValues == nil {
simpleValues = defaultSimpleValues
}

dm := decMode{
dupMapKey: opts.DupMapKey,
Expand All @@ -715,6 +807,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
byteStringToString: opts.ByteStringToString,
fieldNameByteString: opts.FieldNameByteString,
unrecognizedTagToAny: opts.UnrecognizedTagToAny,
simpleValues: simpleValues,
}

return &dm, nil
Expand Down Expand Up @@ -786,12 +879,20 @@ type decMode struct {
byteStringToString ByteStringToStringMode
fieldNameByteString FieldNameByteStringMode
unrecognizedTagToAny UnrecognizedTagToAnyMode
simpleValues *SimpleValueRegistry
}

var defaultDecMode, _ = DecOptions{}.decMode()

// DecOptions returns user specified options used to create this DecMode.
func (dm *decMode) DecOptions() DecOptions {
simpleValues := dm.simpleValues
if simpleValues == defaultSimpleValues {
// Users can't explicitly set this to defaultSimpleValues. It must have been nil in
// the original DecOptions.
simpleValues = nil
}

return DecOptions{
DupMapKey: dm.dupMapKey,
TimeTag: dm.timeTag,
Expand All @@ -811,6 +912,7 @@ func (dm *decMode) DecOptions() DecOptions {
ByteStringToString: dm.byteStringToString,
FieldNameByteString: dm.fieldNameByteString,
UnrecognizedTagToAny: dm.unrecognizedTagToAny,
SimpleValues: simpleValues,
}
}

Expand Down Expand Up @@ -1131,14 +1233,44 @@ func (d *decoder) parseToValue(v reflect.Value, tInfo *typeInfo) error { //nolin
f := math.Float64frombits(val)
return fillFloat(t, f, v)
default: // ai <= 24
switch ai {
case 20, 21:
return fillBool(t, ai == 21, v)
case 22, 23:
return fillNil(t, v)
default:
return fillPositiveInt(t, val, v)
analog := d.dm.simpleValues.analogs[SimpleValue(val)]
if analog == nil {
return &UnmarshalTypeError{
CBORType: t.String(),
GoType: v.Type().String(),
errorMsg: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized",
}
}

// Compatibility mode for simple value decoding using the default analogs.
if ba, ok := (*analog).(builtinAnalog); ok {
return fillBuiltinAnalog(t, ba, v)
}

if *analog == nil {
switch v.Kind() {
case reflect.Ptr, reflect.Func, reflect.Slice, reflect.Map, reflect.Chan, reflect.Interface:
// (reflect.Value) SetZero() was added in Go 1.20.
v.Set(reflect.Zero(v.Type()))
return nil
}
return &UnmarshalTypeError{
CBORType: t.String(),
GoType: v.Type().String(),
errorMsg: fmt.Sprintf("analog %v (%T) for simple value %d is not assignable to a value of this type", *analog, *analog, val),
}
}

av := reflect.ValueOf(*analog)
if !av.Type().AssignableTo(v.Type()) {
return &UnmarshalTypeError{
CBORType: t.String(),
GoType: v.Type().String(),
errorMsg: fmt.Sprintf("analog %v (%T) for simple value %d is not assignable to a value of this type", *analog, *analog, val),
}
}
v.Set(av)
return nil
}

case cborTypeTag:
Expand Down Expand Up @@ -1496,14 +1628,21 @@ func (d *decoder) parse(skipSelfDescribedTag bool) (interface{}, error) { //noli
return Tag{tagNum, content}, nil
case cborTypePrimitives:
_, ai, val := d.getHead()
if ai < 20 || ai == 24 {
return SimpleValue(val), nil
if ai <= 24 {
analog := d.dm.simpleValues.analogs[SimpleValue(val)]
if analog == nil {
return nil, &UnmarshalTypeError{
CBORType: t.String(),
GoType: "interface{}",
errorMsg: "simple value " + strconv.FormatInt(int64(val), 10) + " is not recognized",
}
}
if ba, ok := (*analog).(builtinAnalog); ok {
return ba.v, nil
}
return *analog, nil
}
switch ai {
case 20, 21:
return (ai == 21), nil
case 22, 23:
return nil, nil
case 25:
f := float64(float16.Frombits(uint16(val)).Float32())
return f, nil
Expand Down Expand Up @@ -2312,13 +2451,28 @@ var (
typeByteSlice = reflect.TypeOf([]byte(nil))
)

func fillNil(_ cborType, v reflect.Value) error {
switch v.Kind() {
case reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr:
v.Set(reflect.Zero(v.Type()))
func fillBuiltinAnalog(t cborType, analog builtinAnalog, v reflect.Value) error {
switch analog.v {
case nil:
switch v.Kind() {
case reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr:
v.Set(reflect.Zero(v.Type()))
default:
// no-op
}
return nil
case true, false:
if v.Kind() == reflect.Bool {
v.SetBool(analog.v.(bool))
return nil
}
default:
if sv, ok := (analog.v).(SimpleValue); ok {
return fillPositiveInt(t, uint64(sv), v)
}
}
return nil

return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillPositiveInt(t cborType, val uint64, v reflect.Value) error {
Expand Down Expand Up @@ -2388,14 +2542,6 @@ func fillNegativeInt(t cborType, val int64, v reflect.Value) error {
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillBool(t cborType, val bool, v reflect.Value) error {
if v.Kind() == reflect.Bool {
v.SetBool(val)
return nil
}
return &UnmarshalTypeError{CBORType: t.String(), GoType: v.Type().String()}
}

func fillFloat(t cborType, val float64, v reflect.Value) error {
switch v.Kind() {
case reflect.Float32, reflect.Float64:
Expand Down
Loading

0 comments on commit 2f65e3b

Please sign in to comment.