Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FieldNameMatching decode option. #433

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,26 @@ func (um UTF8Mode) valid() bool {
return um < maxUTF8Mode
}

// FieldNameMatchingMode specifies how string keys in CBOR maps are matched to Go struct field names.
type FieldNameMatchingMode int

const (
// FieldNameMatchingPreferCaseSensitive prefers to decode map items into struct fields whose names (or tag
// names) exactly match the item's key. If there is no such field, a map item will be decoded into a field whose
// name is a case-insensitive match for the item's key.
FieldNameMatchingPreferCaseSensitive = iota

// FieldNameMatchingCaseSensitive decodes map items only into a struct field whose name (or tag name) is an
// exact match for the item's key.
FieldNameMatchingCaseSensitive

maxFieldNameMatchingMode
)

func (fnmm FieldNameMatchingMode) valid() bool {
return fnmm >= 0 && fnmm < maxFieldNameMatchingMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -402,6 +422,9 @@ type DecOptions struct {
// UTF8 specifies if decoder should decode CBOR Text containing invalid UTF-8.
// By default, unmarshal rejects CBOR text containing invalid UTF-8.
UTF8 UTF8Mode

// FieldNameMatching specifies how string keys in CBOR maps are matched to Go struct field names.
FieldNameMatching FieldNameMatchingMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -510,6 +533,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.UTF8.valid() {
return nil, errors.New("cbor: invalid UTF8 " + strconv.Itoa(int(opts.UTF8)))
}
if !opts.FieldNameMatching.valid() {
return nil, errors.New("cbor: invalid FieldNameMatching " + strconv.Itoa(int(opts.FieldNameMatching)))
}
dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -523,6 +549,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
}
return &dm, nil
}
Expand Down Expand Up @@ -587,6 +614,7 @@ type decMode struct {
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -605,6 +633,7 @@ func (dm *decMode) DecOptions() DecOptions {
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
}
}

Expand Down Expand Up @@ -1681,7 +1710,7 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n
}
}
// Find field with case-insensitive match
if f == nil {
if f == nil && d.dm.fieldNameMatching == FieldNameMatchingPreferCaseSensitive {
keyString := string(keyBytes)
for i := 0; i < len(structType.fields); i++ {
fld := structType.fields[i]
Expand Down
99 changes: 99 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6038,3 +6038,102 @@ func TestUnmarshalFirstInvalidItem(t *testing.T) {
t.Errorf("UnmarshalFirst(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err)
}
}

func TestDecModeInvalidFieldNameMatchingMode(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{FieldNameMatching: -1},
wantErrorMsg: "cbor: invalid FieldNameMatching -1",
},
{
name: "above range of valid modes",
opts: DecOptions{FieldNameMatching: 101},
wantErrorMsg: "cbor: invalid FieldNameMatching 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

func TestDecodeFieldNameMatching(t *testing.T) {
type s struct {
LowerA int `cbor:"a"`
UpperB int `cbor:"B"`
LowerB int `cbor:"b"`
}

testCases := []struct {
name string
opts DecOptions
cborData []byte
wantValue s
}{
{
name: "case-insensitive match",
cborData: hexDecode("a1614101"), // {"A": 1}
wantValue: s{LowerA: 1},
},
{
name: "ignore case-insensitive match",
opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
cborData: hexDecode("a1614101"), // {"A": 1}
wantValue: s{},
},
{
name: "exact match before case-insensitive match",
cborData: hexDecode("a2616101614102"), // {"a": 1, "A": 2}
wantValue: s{LowerA: 1},
},
{
name: "case-insensitive match before exact match",
cborData: hexDecode("a2614101616102"), // {"A": 1, "a": 2}
wantValue: s{LowerA: 1},
},
{
name: "ignore case-insensitive match before exact match",
opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
cborData: hexDecode("a2614101616102"), // {"A": 1, "a": 2}
wantValue: s{LowerA: 2},
},
{
name: "earliest exact match wins",
opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
cborData: hexDecode("a2616101616102"), // {"a": 1, "a": 2} (invalid)
wantValue: s{LowerA: 1},
},
{
// the field tags themselves are case-insensitive matches for each other
name: "duplicate keys decode to different fields",
cborData: hexDecode("a2614201614202"), // {"B": 1, "B": 2} (invalid)
wantValue: s{UpperB: 1, LowerB: 2},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
decMode, _ := tc.opts.DecMode()

var dst s
err := decMode.Unmarshal(tc.cborData, &dst)
if err != nil {
t.Fatalf("Unmarshal(0x%x) returned unexpected error %v", tc.cborData, err)
}

if !reflect.DeepEqual(dst, tc.wantValue) {
t.Errorf("Unmarshal(0x%x) = %#v, want %#v", tc.cborData, dst, tc.wantValue)
}
})
}
}
Loading