diff --git a/cache.go b/cache.go index 8a4a5c87..3d7f5bc0 100644 --- a/cache.go +++ b/cache.go @@ -6,6 +6,7 @@ package cbor import ( "bytes" "errors" + "fmt" "reflect" "sort" "strconv" @@ -84,9 +85,10 @@ func newTypeInfo(t reflect.Type) *typeInfo { } type decodingStructType struct { - fields fields - err error - toArray bool + fields fields + fieldIndicesByName map[string]int + err error + toArray bool } func getDecodingStructType(t reflect.Type) *decodingStructType { @@ -98,12 +100,12 @@ func getDecodingStructType(t reflect.Type) *decodingStructType { toArray := hasToArrayOption(structOptions) - var err error + var errs []error for i := 0; i < len(flds); i++ { if flds[i].keyAsInt { nameAsInt, numErr := strconv.Atoi(flds[i].name) if numErr != nil { - err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")") + errs = append(errs, errors.New("cbor: failed to parse field name \""+flds[i].name+"\" to int ("+numErr.Error()+")")) break } flds[i].nameAsInt = int64(nameAsInt) @@ -112,7 +114,21 @@ func getDecodingStructType(t reflect.Type) *decodingStructType { flds[i].typInfo = getTypeInfo(flds[i].typ) } - structType := &decodingStructType{fields: flds, err: err, toArray: toArray} + fieldIndicesByName := make(map[string]int, len(flds)) + for i, fld := range flds { + if _, ok := fieldIndicesByName[fld.name]; ok { + errs = append(errs, fmt.Errorf("cbor: two or more fields of %v have the same name %q", t, fld.name)) + continue + } + fieldIndicesByName[fld.name] = i + } + + structType := &decodingStructType{ + fields: flds, + fieldIndicesByName: fieldIndicesByName, + err: errors.Join(errs...), + toArray: toArray, + } decodingStructTypeCache.Store(t, structType) return structType } diff --git a/decode.go b/decode.go index 0b44124d..e1e07a59 100644 --- a/decode.go +++ b/decode.go @@ -207,7 +207,12 @@ func (e *UnknownFieldError) Error() string { return fmt.Sprintf("cbor: found unknown field at map element index %d", e.Index) } -// DupMapKeyMode specifies how to enforce duplicate map key. +// DupMapKeyMode specifies how to enforce duplicate map key. Two map keys are considered duplicates if: +// 1. When decoding into a struct, both keys match the same struct field. The keys are also +// considered duplicates if neither matches any field and decoding to interface{} would produce +// equal (==) values for both keys. +// 2. When decoding into a map, both keys are equal (==) when decoded into values of the +// destination map's key type. type DupMapKeyMode int const ( @@ -1893,20 +1898,32 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n count := int(val) // Keeps track of matched struct fields - foundFldIdx := make([]bool, len(structType.fields)) + var foundFldIdx []bool + { + const maxStackFields = 128 + if nfields := len(structType.fields); nfields <= maxStackFields { + // For structs with typical field counts, expect that this can be + // stack-allocated. + var a [maxStackFields]bool + foundFldIdx = a[:nfields] + } else { + foundFldIdx = make([]bool, len(structType.fields)) + } + } // Keeps track of CBOR map keys to detect duplicate map key keyCount := 0 var mapKeys map[interface{}]struct{} - if d.dm.dupMapKey == DupMapKeyEnforcedAPF { - mapKeys = make(map[interface{}]struct{}, len(structType.fields)) - } errOnUnknownField := (d.dm.extraReturnErrors & ExtraDecErrorUnknownField) > 0 +MapEntryLoop: for j := 0; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { var f *field - var k interface{} // Used by duplicate map key detection + + // If duplicate field detection is enabled and the key at index j did not match any + // field, k will hold the map key. + var k interface{} t := d.nextCBORType() if t == cborTypeTextString || (t == cborTypeByteString && d.dm.fieldNameByteString == FieldNameByteStringAllowed) { @@ -1924,30 +1941,61 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n keyBytes, _ = d.parseByteString() } - keyLen := len(keyBytes) - // Find field with exact match - for i := 0; i < len(structType.fields); i++ { + // Check for exact match on field name. + if i, ok := structType.fieldIndicesByName[string(keyBytes)]; ok { fld := structType.fields[i] - if !foundFldIdx[i] && len(fld.name) == keyLen && fld.name == string(keyBytes) { + + if !foundFldIdx[i] { f = fld foundFldIdx[i] = true - break + } else if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + err = &DupMapKeyError{fld.name, j} + d.skip() // skip value + j++ + // skip the rest of the map + for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { + d.skip() + d.skip() + } + return err + } else { + // discard repeated match + d.skip() + continue MapEntryLoop } } + // Find field with case-insensitive match if f == nil && d.dm.fieldNameMatching == FieldNameMatchingPreferCaseSensitive { + keyLen := len(keyBytes) keyString := string(keyBytes) for i := 0; i < len(structType.fields); i++ { fld := structType.fields[i] - if !foundFldIdx[i] && len(fld.name) == keyLen && strings.EqualFold(fld.name, keyString) { - f = fld - foundFldIdx[i] = true + if len(fld.name) == keyLen && strings.EqualFold(fld.name, keyString) { + if !foundFldIdx[i] { + f = fld + foundFldIdx[i] = true + } else if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + err = &DupMapKeyError{keyString, j} + d.skip() // skip value + j++ + // skip the rest of the map + for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { + d.skip() + d.skip() + } + return err + } else { + // discard repeated match + d.skip() + continue MapEntryLoop + } break } } } - if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + if d.dm.dupMapKey == DupMapKeyEnforcedAPF && f == nil { k = string(keyBytes) } } else if t <= cborTypeNegativeInt { // uint/int @@ -1975,14 +2023,30 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n // Find field for i := 0; i < len(structType.fields); i++ { fld := structType.fields[i] - if !foundFldIdx[i] && fld.keyAsInt && fld.nameAsInt == nameAsInt { - f = fld - foundFldIdx[i] = true + if fld.keyAsInt && fld.nameAsInt == nameAsInt { + if !foundFldIdx[i] { + f = fld + foundFldIdx[i] = true + } else if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + err = &DupMapKeyError{nameAsInt, j} + d.skip() // skip value + j++ + // skip the rest of the map + for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { + d.skip() + d.skip() + } + return err + } else { + // discard repeated match + d.skip() + continue MapEntryLoop + } break } } - if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + if d.dm.dupMapKey == DupMapKeyEnforcedAPF && f == nil { k = nameAsInt } } else { @@ -2010,23 +2074,6 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n } } - if d.dm.dupMapKey == DupMapKeyEnforcedAPF { - mapKeys[k] = struct{}{} - newKeyCount := len(mapKeys) - if newKeyCount == keyCount { - err = &DupMapKeyError{k, j} - d.skip() // skip value - j++ - // skip the rest of the map - for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { - d.skip() - d.skip() - } - return err - } - keyCount = newKeyCount - } - if f == nil { if errOnUnknownField { err = &UnknownFieldError{j} @@ -2039,6 +2086,31 @@ func (d *decoder) parseMapToStruct(v reflect.Value, tInfo *typeInfo) error { //n } return err } + + // Two map keys that match the same struct field are immediately considered + // duplicates. This check detects duplicates between two map keys that do + // not match a struct field. If unknown field errors are enabled, then this + // check is never reached. + if d.dm.dupMapKey == DupMapKeyEnforcedAPF { + if mapKeys == nil { + mapKeys = make(map[interface{}]struct{}, 1) + } + mapKeys[k] = struct{}{} + newKeyCount := len(mapKeys) + if newKeyCount == keyCount { + err = &DupMapKeyError{k, j} + d.skip() // skip value + j++ + // skip the rest of the map + for ; (hasSize && j < count) || (!hasSize && !d.foundBreak()); j++ { + d.skip() + d.skip() + } + return err + } + keyCount = newKeyCount + } + d.skip() // Skip value continue }