diff --git a/bench_test.go b/bench_test.go index 3dd312f0..6f0878fa 100644 --- a/bench_test.go +++ b/bench_test.go @@ -5,6 +5,7 @@ package cbor import ( "bytes" + "fmt" "io" "reflect" "testing" @@ -743,3 +744,209 @@ func BenchmarkMarshalCOSEMACWithTag(b *testing.B) { } } } + +func BenchmarkUnmarshalMapToStruct(b *testing.B) { + type S struct { + A, B, C, D, E, F, G, H, I, J, K, L, M bool + } + + var ( + allKnownFields = hexDecode("ad6141f56142f56143f56144f56145f56146f56147f56148f56149f5614af5614bf5614cf5614df5") // {"A": true, ... "M": true } + allKnownDuplicateFields = hexDecode("ad6141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f5") // {"A": true, "A": true, "A": true, ...} + allUnknownFields = hexDecode("ad614ef5614ff56150f56151f56152f56153f56154f56155f56156f56157f56158f56159f5615af5") // {"N": true, ... "Z": true } + allUnknownDuplicateFields = hexDecode("ad614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5") // {"N": true, "N": true, "N": true, ...} + ) + + type ManyFields struct { + AA, AB, AC, AD, AE, AF, AG, AH, AI, AJ, AK, AL, AM, AN, AO, AP, AQ, AR, AS, AT, AU, AV, AW, AX, AY, AZ bool + BA, BB, BC, BD, BE, BF, BG, BH, BI, BJ, BK, BL, BM, BN, BO, BP, BQ, BR, BS, BT, BU, BV, BW, BX, BY, BZ bool + CA, CB, CC, CD, CE, CF, CG, CH, CI, CJ, CK, CL, CM, CN, CO, CP, CQ, CR, CS, CT, CU, CV, CW, CX, CY, CZ bool + DA, DB, DC, DD, DE, DF, DG, DH, DI, DJ, DK, DL, DM, DN, DO, DP, DQ, DR, DS, DT, DU, DV, DW, DX, DY, DZ bool + } + var manyFieldsOneKeyPerField []byte + { + // An EncOption that accepts a function to sort or shuffle keys might be useful for + // cases like this. Here we are manually encoding the fields in reverse order to + // target worst-case key-to-field matching. + rt := reflect.TypeOf(ManyFields{}) + var buf bytes.Buffer + if rt.NumField() > 255 { + b.Fatalf("invalid test assumption: ManyFields expected to have no more than 255 fields, has %d", rt.NumField()) + } + buf.WriteByte(0xb8) + buf.WriteByte(byte(rt.NumField())) + for i := rt.NumField() - 1; i >= 0; i-- { // backwards + f := rt.Field(i) + if len(f.Name) > 23 { + b.Fatalf("invalid test assumption: field name %q longer than 23 bytes", f.Name) + } + buf.WriteByte(byte(0x60 + len(f.Name))) + buf.WriteString(f.Name) + buf.WriteByte(0xf5) // true + } + manyFieldsOneKeyPerField = buf.Bytes() + } + + type input struct { + name string + data []byte + into interface{} + reject bool + } + + for _, tc := range []struct { + name string + opts DecOptions + inputs []input + }{ + { + name: "default options", + opts: DecOptions{}, + inputs: []input{ + { + name: "all known fields", + data: allKnownFields, + into: S{}, + reject: false, + }, + { + name: "all known duplicate fields", + data: allKnownDuplicateFields, + into: S{}, + reject: false, + }, + { + name: "all unknown fields", + data: allUnknownFields, + into: S{}, + reject: false, + }, + { + name: "all unknown duplicate fields", + data: allUnknownDuplicateFields, + into: S{}, + reject: false, + }, + { + name: "many fields one key per field", + data: manyFieldsOneKeyPerField, + into: ManyFields{}, + reject: false, + }, + }, + }, + { + name: "reject unknown", + opts: DecOptions{ExtraReturnErrors: ExtraDecErrorUnknownField}, + inputs: []input{ + { + name: "all known fields", + data: allKnownFields, + into: S{}, + reject: false, + }, + { + name: "all known duplicate fields", + data: allKnownDuplicateFields, + into: S{}, + reject: false, + }, + { + name: "all unknown fields", + data: allUnknownFields, + into: S{}, + reject: true, + }, + { + name: "all unknown duplicate fields", + data: allUnknownDuplicateFields, + into: S{}, + reject: true, + }, + }, + }, + { + name: "reject duplicate", + opts: DecOptions{DupMapKey: DupMapKeyEnforcedAPF}, + inputs: []input{ + { + name: "all known fields", + data: allKnownFields, + into: S{}, + reject: false, + }, + { + name: "all known duplicate fields", + data: allKnownDuplicateFields, + into: S{}, + reject: true, + }, + { + name: "all unknown fields", + data: allUnknownFields, + into: S{}, + reject: false, + }, + { + name: "all unknown duplicate fields", + data: allUnknownDuplicateFields, + into: S{}, + reject: true, + }, + }, + }, + { + name: "reject unknown and duplicate", + opts: DecOptions{ + DupMapKey: DupMapKeyEnforcedAPF, + ExtraReturnErrors: ExtraDecErrorUnknownField, + }, + inputs: []input{ + { + name: "all known fields", + data: allKnownFields, + into: S{}, + reject: false, + }, + { + name: "all known duplicate fields", + data: allKnownDuplicateFields, + into: S{}, + reject: true, + }, + { + name: "all unknown fields", + data: allUnknownFields, + into: S{}, + reject: true, + }, + { + name: "all unknown duplicate fields", + data: allUnknownDuplicateFields, + into: S{}, + reject: true, + }, + }, + }, + } { + for _, in := range tc.inputs { + b.Run(fmt.Sprintf("%s/%s", tc.name, in.name), func(b *testing.B) { + dm, err := tc.opts.DecMode() + if err != nil { + b.Fatal(err) + } + + dst := reflect.New(reflect.TypeOf(in.into)).Interface() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := dm.Unmarshal(in.data, dst); !in.reject && err != nil { + b.Fatalf("unexpected error: %v", err) + } else if in.reject && err == nil { + b.Fatal("expected non-nil error") + } + } + }) + } + } +} diff --git a/cache.go b/cache.go index 8a4a5c87..cd83d8a9 100644 --- a/cache.go +++ b/cache.go @@ -6,6 +6,7 @@ package cbor import ( "bytes" "errors" + "fmt" "reflect" "sort" "strconv" @@ -84,9 +85,25 @@ func newTypeInfo(t reflect.Type) *typeInfo { } type decodingStructType struct { - fields fields - err error - toArray bool + fields fields + fieldIndicesByName map[string]int + err error + toArray bool +} + +// The stdlib errors.Join was introduced in Go 1.20, and we still support Go 1.17, so instead, +// here's a very basic implementation of an aggregated error. +type multierror []error + +func (m multierror) Error() string { + var sb strings.Builder + for i, err := range m { + sb.WriteString(err.Error()) + if i < len(m)-1 { + sb.WriteString(", ") + } + } + return sb.String() } func getDecodingStructType(t reflect.Type) *decodingStructType { @@ -98,12 +115,12 @@ func getDecodingStructType(t reflect.Type) *decodingStructType { toArray := hasToArrayOption(structOptions) - var err error + var errs []error for i := 0; i < len(flds); i++ { if flds[i].keyAsInt { nameAsInt, numErr := strconv.Atoi(flds[i].name) if numErr != nil { - err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")") + errs = append(errs, errors.New("cbor: failed to parse field name \""+flds[i].name+"\" to int ("+numErr.Error()+")")) break } flds[i].nameAsInt = int64(nameAsInt) @@ -112,7 +129,36 @@ func getDecodingStructType(t reflect.Type) *decodingStructType { flds[i].typInfo = getTypeInfo(flds[i].typ) } - structType := &decodingStructType{fields: flds, err: err, toArray: toArray} + fieldIndicesByName := make(map[string]int, len(flds)) + for i, fld := range flds { + if _, ok := fieldIndicesByName[fld.name]; ok { + errs = append(errs, fmt.Errorf("cbor: two or more fields of %v have the same name %q", t, fld.name)) + continue + } + fieldIndicesByName[fld.name] = i + } + + var err error + { + var multi multierror + for _, each := range errs { + if each != nil { + multi = append(multi, each) + } + } + if len(multi) == 1 { + err = multi[0] + } else if len(multi) > 1 { + err = multi + } + } + + structType := &decodingStructType{ + fields: flds, + fieldIndicesByName: fieldIndicesByName, + err: err, + toArray: toArray, + } decodingStructTypeCache.Store(t, structType) return structType } diff --git a/decode.go b/decode.go index 0b44124d..e1e07a59 100644 --- a/decode.go +++ b/decode.go @@ -207,7 +207,12 @@ func (e *UnknownFieldError) Error() string { return fmt.Sprintf("cbor: found unknown field at map element index %d", e.Index) } -// DupMapKeyMode specifies how to enforce duplicate map key. +// DupMapKeyMode specifies how to enforce duplicate map key. Two map keys are considered duplicates if: +// 1. When decoding into a struct, both keys match the same struct field. The keys are also +// considered duplicates if neither matches any field and decoding to interface{} would produce +// equal (==) values for both keys. +// 2. When decoding into a map, both keys are equal (==) when decoded into values of the +// destination map's key type. type DupMapKeyMode int const ( @@ -1893,20 +1898,32 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n count := int(val) // Keeps track of matched struct fields - foundFldIdx := make([]bool, len(structType.fields)) + var foundFldIdx []bool + { + const maxStackFields = 128 + if nfields := len(structType.fields); nfields <= maxStackFields { + // For structs with typical field counts, expect that this can be + // stack-allocated. + var a [maxStackFields]bool + foundFldIdx = a[:nfields] + } else { + foundFldIdx = make([]bool, len(structType.fields)) + } + } // Keeps track of CBOR map keys to detect duplicate map key keyCount := 0 var mapKeys map[interface{}]struct{} - if d.dm.dupMapKey == DupMapKeyEnforcedAPF { - mapKeys = make(map[interface{}]struct{}, len(structType.fields)) - } errOnUnknownField := (d.dm.extraReturnErrors & ExtraDecErrorUnknownField) > 0 +MapEntryLoop: for j := 0; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { var f *field - var k interface{} // Used by duplicate map key detection + + // If duplicate field detection is enabled and the key at index j did not match any + // field, k will hold the map key. + var k interface{} t := d.nextCBORType() if t == cborTypeTextString || (t == cborTypeByteString && d.dm.fieldNameByteString == FieldNameByteStringAllowed) { @@ -1924,30 +1941,61 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n keyBytes, _ = d.parseByteString() } - keyLen := len(keyBytes) - // Find field with exact match - for i := 0; i < len(structType.fields); i++ { + // Check for exact match on field name. + if i, ok := structType.fieldIndicesByName[string(keyBytes)]; ok { fld := structType.fields[i] - if !foundFldIdx[i] && len(fld.name) == keyLen && fld.name == string(keyBytes) { + + if !foundFldIdx[i] { f = fld foundFldIdx[i] = true - break + } else if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + err = &DupMapKeyError{fld.name, j} + d.skip() // skip value + j++ + // skip the rest of the map + for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { + d.skip() + d.skip() + } + return err + } else { + // discard repeated match + d.skip() + continue MapEntryLoop } } + // Find field with case-insensitive match if f == nil && d.dm.fieldNameMatching == FieldNameMatchingPreferCaseSensitive { + keyLen := len(keyBytes) keyString := string(keyBytes) for i := 0; i < len(structType.fields); i++ { fld := structType.fields[i] - if !foundFldIdx[i] && len(fld.name) == keyLen && strings.EqualFold(fld.name, keyString) { - f = fld - foundFldIdx[i] = true + if len(fld.name) == keyLen && strings.EqualFold(fld.name, keyString) { + if !foundFldIdx[i] { + f = fld + foundFldIdx[i] = true + } else if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + err = &DupMapKeyError{keyString, j} + d.skip() // skip value + j++ + // skip the rest of the map + for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { + d.skip() + d.skip() + } + return err + } else { + // discard repeated match + d.skip() + continue MapEntryLoop + } break } } } - if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + if d.dm.dupMapKey == DupMapKeyEnforcedAPF && f == nil { k = string(keyBytes) } } else if t <= cborTypeNegativeInt { // uint/int @@ -1975,14 +2023,30 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n // Find field for i := 0; i < len(structType.fields); i++ { fld := structType.fields[i] - if !foundFldIdx[i] && fld.keyAsInt && fld.nameAsInt == nameAsInt { - f = fld - foundFldIdx[i] = true + if fld.keyAsInt && fld.nameAsInt == nameAsInt { + if !foundFldIdx[i] { + f = fld + foundFldIdx[i] = true + } else if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + err = &DupMapKeyError{nameAsInt, j} + d.skip() // skip value + j++ + // skip the rest of the map + for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { + d.skip() + d.skip() + } + return err + } else { + // discard repeated match + d.skip() + continue MapEntryLoop + } break } } - if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + if d.dm.dupMapKey == DupMapKeyEnforcedAPF && f == nil { k = nameAsInt } } else { @@ -2010,23 +2074,6 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n } } - if d.dm.dupMapKey == DupMapKeyEnforcedAPF { - mapKeys[k] = struct{}{} - newKeyCount := len(mapKeys) - if newKeyCount == keyCount { - err = &DupMapKeyError{k, j} - d.skip() // skip value - j++ - // skip the rest of the map - for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { - d.skip() - d.skip() - } - return err - } - keyCount = newKeyCount - } - if f == nil { if errOnUnknownField { err = &UnknownFieldError{j} @@ -2039,6 +2086,31 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n } return err } + + // Two map keys that match the same struct field are immediately considered + // duplicates. This check detects duplicates between two map keys that do + // not match a struct field. If unknown field errors are enabled, then this + // check is never reached. + if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + if mapKeys == nil { + mapKeys = make(map[interface{}]struct{}, 1) + } + mapKeys[k] = struct{}{} + newKeyCount := len(mapKeys) + if newKeyCount == keyCount { + err = &DupMapKeyError{k, j} + d.skip() // skip value + j++ + // skip the rest of the map + for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { + d.skip() + d.skip() + } + return err + } + keyCount = newKeyCount + } + d.skip() // Skip value continue } diff --git a/decode_test.go b/decode_test.go index 88629349..5fcea786 100644 --- a/decode_test.go +++ b/decode_test.go @@ -5476,33 +5476,82 @@ func TestUnmarshalDupMapKeyToStruct(t *testing.T) { C string `cbor:"c"` D string `cbor:"d"` E string `cbor:"e"` - } - data := hexDecode("a6616161416162614261636143616161466164614461656145") // {"a": "A", "b": "B", "c": "C", "a": "F", "d": "D", "e": "E"} - // Duplicate key doesn't overwrite previous value (default). - wantS := s{A: "A", B: "B", C: "C", D: "D", E: "E"} - var s1 s - if err := Unmarshal(data, &s1); err != nil { - t.Errorf("Unmarshal(0x%x) returned error %v", data, err) - } - if !reflect.DeepEqual(s1, wantS) { - t.Errorf("Unmarshal(0x%x) = %+v (%T), want %+v (%T)", data, s1, s1, wantS, wantS) + I string `cbor:"1,keyasint"` } - // Duplicate key triggers error. - wantS = s{A: "A", B: "B", C: "C"} - wantErrorMsg := "cbor: found duplicate map key \"a\" at map element index 3" - dm, _ := DecOptions{DupMapKey: DupMapKeyEnforcedAPF}.DecMode() - var s2 s - if err := dm.Unmarshal(data, &s2); err == nil { - t.Errorf("Unmarshal(0x%x, %s) didn't return an error", data, reflect.TypeOf(s2)) - } else if _, ok := err.(*DupMapKeyError); !ok { - t.Errorf("Unmarshal(0x%x) returned wrong error type %T, want (*DupMapKeyError)", data, err) - } else if !strings.Contains(err.Error(), wantErrorMsg) { - t.Errorf("Unmarshal(0x%x) returned error %q, want error containing %q", data, err.Error(), wantErrorMsg) - } - if !reflect.DeepEqual(s2, wantS) { - t.Errorf("Unmarshal(0x%x) = %+v (%T), want %+v (%T)", data, s2, s2, wantS, wantS) + for _, tc := range []struct { + name string + opts DecOptions + data []byte + want s + wantErr *DupMapKeyError + }{ + { + name: "duplicate key does not overwrite previous value", + data: hexDecode("a6616161416162614261636143616161466164614461656145"), // {"a": "A", "b": "B", "c": "C", "a": "F", "d": "D", "e": "E"} + want: s{A: "A", B: "B", C: "C", D: "D", E: "E"}, + }, + { + name: "duplicate key triggers error", + opts: DecOptions{DupMapKey: DupMapKeyEnforcedAPF}, + data: hexDecode("a6616161416162614261636143616161466164614461656145"), // {"a": "A", "b": "B", "c": "C", "a": "F", "d": "D", "e": "E"} + want: s{A: "A", B: "B", C: "C"}, + wantErr: &DupMapKeyError{Key: "a", Index: 3}, + }, + { + name: "duplicate keys of comparable but disallowed cbor types skips remaining entries and returns error", + opts: DecOptions{DupMapKey: DupMapKeyEnforcedAPF}, + data: hexDecode("a7616161416162614261636143d903e70100d903e701016164614461656145"), // {"a": "A", "b": "B", "c": "C", 999(1): 0, 999(1): 1, "d": "D", "e": "E"} + want: s{A: "A", B: "B", C: "C"}, + wantErr: &DupMapKeyError{Key: Tag{Number: 999, Content: uint64(1)}, Index: 4}, + }, + { + name: "mixed-case duplicate key does not overwrite previous value", + data: hexDecode("a6616161416162614261636143614161466164614461656145"), // {"a": "A", "b": "B", "c": "C", "A": "F", "d": "D", "e": "E"} + want: s{A: "A", B: "B", C: "C", D: "D", E: "E"}, + }, + { + name: "mixed-case duplicate key triggers error", + opts: DecOptions{DupMapKey: DupMapKeyEnforcedAPF}, + data: hexDecode("a6616161416162614261636143614161466164614461656145"), // {"a": "A", "b": "B", "c": "C", "A": "F", "d": "D", "e": "E"} + want: s{A: "A", B: "B", C: "C"}, + wantErr: &DupMapKeyError{Key: "A", Index: 3}, + }, + { + name: "keyasint duplicate key does not overwrite previous value", + data: hexDecode("a36131616901614961616141"), // {"1": "i", 1: "I", "a": "A"} + want: s{I: "i", A: "A"}, + }, + { + name: "keyasint duplicate key triggers error", + opts: DecOptions{DupMapKey: DupMapKeyEnforcedAPF}, + data: hexDecode("a36131616901614961616141"), // {"1": "i", 1: "I", "a": "A"} + want: s{I: "i"}, + wantErr: &DupMapKeyError{Key: int64(1), Index: 1}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + dm, err := tc.opts.DecMode() + if err != nil { + t.Fatal(err) + } + + var s1 s + if err := dm.Unmarshal(tc.data, &s1); err != nil { + if !reflect.DeepEqual(err, tc.wantErr) { + t.Errorf("got error: %v, wanted: %v", err, tc.wantErr) + } + } else { + if tc.wantErr != nil { + t.Errorf("got nil error, wanted: %v", tc.wantErr) + } + } + + if !reflect.DeepEqual(s1, tc.want) { + t.Errorf("Unmarshal(0x%x) = %+v (%T), want %+v (%T)", tc.data, s1, s1, tc.want, tc.want) + } + }) } } @@ -6714,6 +6763,12 @@ func TestExtraErrorCondUnknownField(t *testing.T) { dm: dmUnknownFieldError, wantObj: s{A: "a", B: "b", C: ""}, }, + { + name: "duplicate map keys matching known field with ExtraDecErrorUnknownField", + data: hexDecode("a26141616161416141"), // map[string]string{"A": "a", "A": "A"} + dm: dmUnknownFieldError, + wantObj: s{A: "a"}, + }, { name: "CBOR map unknown field", data: hexDecode("a461416161614261626143616361446164"), // map[string]string{"A": "a", "B": "b", "C": "c", "D": "d"} @@ -8038,9 +8093,9 @@ func TestDecodeFieldNameMatching(t *testing.T) { }, { // the field tags themselves are case-insensitive matches for each other - name: "duplicate keys decode to different fields", + name: "duplicate key does not fall back to case-insensitive match", data: hexDecode("a2614201614202"), // {"B": 1, "B": 2} (invalid) - wantValue: s{UpperB: 1, LowerB: 2}, + wantValue: s{UpperB: 1}, }, }