Skip to content

Commit

Permalink
Merge pull request #433 from benluddy/field-name-matching-decode-option
Browse files Browse the repository at this point in the history
Add FieldNameMatching decode option.
  • Loading branch information
fxamacker authored Nov 4, 2023
2 parents cd38439 + 9e247e0 commit 4687659
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 1 deletion.
31 changes: 30 additions & 1 deletion decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,26 @@ func (um UTF8Mode) valid() bool {
return um < maxUTF8Mode
}

// FieldNameMatchingMode specifies how string keys in CBOR maps are matched to Go struct field names.
type FieldNameMatchingMode int

const (
// FieldNameMatchingPreferCaseSensitive prefers to decode map items into struct fields whose names (or tag
// names) exactly match the item's key. If there is no such field, a map item will be decoded into a field whose
// name is a case-insensitive match for the item's key.
FieldNameMatchingPreferCaseSensitive = iota

// FieldNameMatchingCaseSensitive decodes map items only into a struct field whose name (or tag name) is an
// exact match for the item's key.
FieldNameMatchingCaseSensitive

maxFieldNameMatchingMode
)

func (fnmm FieldNameMatchingMode) valid() bool {
return fnmm >= 0 && fnmm < maxFieldNameMatchingMode
}

// DecOptions specifies decoding options.
type DecOptions struct {
// DupMapKey specifies whether to enforce duplicate map key.
Expand Down Expand Up @@ -402,6 +422,9 @@ type DecOptions struct {
// UTF8 specifies if decoder should decode CBOR Text containing invalid UTF-8.
// By default, unmarshal rejects CBOR text containing invalid UTF-8.
UTF8 UTF8Mode

// FieldNameMatching specifies how string keys in CBOR maps are matched to Go struct field names.
FieldNameMatching FieldNameMatchingMode
}

// DecMode returns DecMode with immutable options and no tags (safe for concurrency).
Expand Down Expand Up @@ -510,6 +533,9 @@ func (opts DecOptions) decMode() (*decMode, error) {
if !opts.UTF8.valid() {
return nil, errors.New("cbor: invalid UTF8 " + strconv.Itoa(int(opts.UTF8)))
}
if !opts.FieldNameMatching.valid() {
return nil, errors.New("cbor: invalid FieldNameMatching " + strconv.Itoa(int(opts.FieldNameMatching)))
}
dm := decMode{
dupMapKey: opts.DupMapKey,
timeTag: opts.TimeTag,
Expand All @@ -523,6 +549,7 @@ func (opts DecOptions) decMode() (*decMode, error) {
extraReturnErrors: opts.ExtraReturnErrors,
defaultMapType: opts.DefaultMapType,
utf8: opts.UTF8,
fieldNameMatching: opts.FieldNameMatching,
}
return &dm, nil
}
Expand Down Expand Up @@ -587,6 +614,7 @@ type decMode struct {
extraReturnErrors ExtraDecErrorCond
defaultMapType reflect.Type
utf8 UTF8Mode
fieldNameMatching FieldNameMatchingMode
}

var defaultDecMode, _ = DecOptions{}.decMode()
Expand All @@ -605,6 +633,7 @@ func (dm *decMode) DecOptions() DecOptions {
MapKeyByteString: dm.mapKeyByteString,
ExtraReturnErrors: dm.extraReturnErrors,
UTF8: dm.utf8,
FieldNameMatching: dm.fieldNameMatching,
}
}

Expand Down Expand Up @@ -1681,7 +1710,7 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n
}
}
// Find field with case-insensitive match
if f == nil {
if f == nil && d.dm.fieldNameMatching == FieldNameMatchingPreferCaseSensitive {
keyString := string(keyBytes)
for i := 0; i < len(structType.fields); i++ {
fld := structType.fields[i]
Expand Down
99 changes: 99 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6038,3 +6038,102 @@ func TestUnmarshalFirstInvalidItem(t *testing.T) {
t.Errorf("UnmarshalFirst(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err)
}
}

func TestDecModeInvalidFieldNameMatchingMode(t *testing.T) {
for _, tc := range []struct {
name string
opts DecOptions
wantErrorMsg string
}{
{
name: "below range of valid modes",
opts: DecOptions{FieldNameMatching: -1},
wantErrorMsg: "cbor: invalid FieldNameMatching -1",
},
{
name: "above range of valid modes",
opts: DecOptions{FieldNameMatching: 101},
wantErrorMsg: "cbor: invalid FieldNameMatching 101",
},
} {
t.Run(tc.name, func(t *testing.T) {
_, err := tc.opts.DecMode()
if err == nil {
t.Errorf("DecMode() didn't return an error")
} else if err.Error() != tc.wantErrorMsg {
t.Errorf("DecMode() returned error %q, want %q", err.Error(), tc.wantErrorMsg)
}
})
}
}

func TestDecodeFieldNameMatching(t *testing.T) {
type s struct {
LowerA int `cbor:"a"`
UpperB int `cbor:"B"`
LowerB int `cbor:"b"`
}

testCases := []struct {
name string
opts DecOptions
cborData []byte
wantValue s
}{
{
name: "case-insensitive match",
cborData: hexDecode("a1614101"), // {"A": 1}
wantValue: s{LowerA: 1},
},
{
name: "ignore case-insensitive match",
opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
cborData: hexDecode("a1614101"), // {"A": 1}
wantValue: s{},
},
{
name: "exact match before case-insensitive match",
cborData: hexDecode("a2616101614102"), // {"a": 1, "A": 2}
wantValue: s{LowerA: 1},
},
{
name: "case-insensitive match before exact match",
cborData: hexDecode("a2614101616102"), // {"A": 1, "a": 2}
wantValue: s{LowerA: 1},
},
{
name: "ignore case-insensitive match before exact match",
opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
cborData: hexDecode("a2614101616102"), // {"A": 1, "a": 2}
wantValue: s{LowerA: 2},
},
{
name: "earliest exact match wins",
opts: DecOptions{FieldNameMatching: FieldNameMatchingCaseSensitive},
cborData: hexDecode("a2616101616102"), // {"a": 1, "a": 2} (invalid)
wantValue: s{LowerA: 1},
},
{
// the field tags themselves are case-insensitive matches for each other
name: "duplicate keys decode to different fields",
cborData: hexDecode("a2614201614202"), // {"B": 1, "B": 2} (invalid)
wantValue: s{UpperB: 1, LowerB: 2},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
decMode, _ := tc.opts.DecMode()

var dst s
err := decMode.Unmarshal(tc.cborData, &dst)
if err != nil {
t.Fatalf("Unmarshal(0x%x) returned unexpected error %v", tc.cborData, err)
}

if !reflect.DeepEqual(dst, tc.wantValue) {
t.Errorf("Unmarshal(0x%x) = %#v, want %#v", tc.cborData, dst, tc.wantValue)
}
})
}
}

0 comments on commit 4687659

Please sign in to comment.