Skip to content

Commit

Permalink
Merge pull request #492 from benluddy/dupmapkey-same-struct-field
Browse files Browse the repository at this point in the history
Treat map keys matching the same struct field as duplicates.
  • Loading branch information
fxamacker authored Feb 19, 2024
2 parents 7959607 + 1d28086 commit afbafc4
Show file tree
Hide file tree
Showing 4 changed files with 448 additions and 68 deletions.
207 changes: 207 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cbor

import (
"bytes"
"fmt"
"io"
"reflect"
"testing"
Expand Down Expand Up @@ -743,3 +744,209 @@ func BenchmarkMarshalCOSEMACWithTag(b *testing.B) {
}
}
}

func BenchmarkUnmarshalMapToStruct(b *testing.B) {
type S struct {
A, B, C, D, E, F, G, H, I, J, K, L, M bool
}

var (
allKnownFields = hexDecode("ad6141f56142f56143f56144f56145f56146f56147f56148f56149f5614af5614bf5614cf5614df5") // {"A": true, ... "M": true }
allKnownDuplicateFields = hexDecode("ad6141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f5") // {"A": true, "A": true, "A": true, ...}
allUnknownFields = hexDecode("ad614ef5614ff56150f56151f56152f56153f56154f56155f56156f56157f56158f56159f5615af5") // {"N": true, ... "Z": true }
allUnknownDuplicateFields = hexDecode("ad614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5") // {"N": true, "N": true, "N": true, ...}
)

type ManyFields struct {
AA, AB, AC, AD, AE, AF, AG, AH, AI, AJ, AK, AL, AM, AN, AO, AP, AQ, AR, AS, AT, AU, AV, AW, AX, AY, AZ bool
BA, BB, BC, BD, BE, BF, BG, BH, BI, BJ, BK, BL, BM, BN, BO, BP, BQ, BR, BS, BT, BU, BV, BW, BX, BY, BZ bool
CA, CB, CC, CD, CE, CF, CG, CH, CI, CJ, CK, CL, CM, CN, CO, CP, CQ, CR, CS, CT, CU, CV, CW, CX, CY, CZ bool
DA, DB, DC, DD, DE, DF, DG, DH, DI, DJ, DK, DL, DM, DN, DO, DP, DQ, DR, DS, DT, DU, DV, DW, DX, DY, DZ bool
}
var manyFieldsOneKeyPerField []byte
{
// An EncOption that accepts a function to sort or shuffle keys might be useful for
// cases like this. Here we are manually encoding the fields in reverse order to
// target worst-case key-to-field matching.
rt := reflect.TypeOf(ManyFields{})
var buf bytes.Buffer
if rt.NumField() > 255 {
b.Fatalf("invalid test assumption: ManyFields expected to have no more than 255 fields, has %d", rt.NumField())
}
buf.WriteByte(0xb8)
buf.WriteByte(byte(rt.NumField()))
for i := rt.NumField() - 1; i >= 0; i-- { // backwards
f := rt.Field(i)
if len(f.Name) > 23 {
b.Fatalf("invalid test assumption: field name %q longer than 23 bytes", f.Name)
}
buf.WriteByte(byte(0x60 + len(f.Name)))
buf.WriteString(f.Name)
buf.WriteByte(0xf5) // true
}
manyFieldsOneKeyPerField = buf.Bytes()
}

type input struct {
name string
data []byte
into interface{}
reject bool
}

for _, tc := range []struct {
name string
opts DecOptions
inputs []input
}{
{
name: "default options",
opts: DecOptions{},
inputs: []input{
{
name: "all known fields",
data: allKnownFields,
into: S{},
reject: false,
},
{
name: "all known duplicate fields",
data: allKnownDuplicateFields,
into: S{},
reject: false,
},
{
name: "all unknown fields",
data: allUnknownFields,
into: S{},
reject: false,
},
{
name: "all unknown duplicate fields",
data: allUnknownDuplicateFields,
into: S{},
reject: false,
},
{
name: "many fields one key per field",
data: manyFieldsOneKeyPerField,
into: ManyFields{},
reject: false,
},
},
},
{
name: "reject unknown",
opts: DecOptions{ExtraReturnErrors: ExtraDecErrorUnknownField},
inputs: []input{
{
name: "all known fields",
data: allKnownFields,
into: S{},
reject: false,
},
{
name: "all known duplicate fields",
data: allKnownDuplicateFields,
into: S{},
reject: false,
},
{
name: "all unknown fields",
data: allUnknownFields,
into: S{},
reject: true,
},
{
name: "all unknown duplicate fields",
data: allUnknownDuplicateFields,
into: S{},
reject: true,
},
},
},
{
name: "reject duplicate",
opts: DecOptions{DupMapKey: DupMapKeyEnforcedAPF},
inputs: []input{
{
name: "all known fields",
data: allKnownFields,
into: S{},
reject: false,
},
{
name: "all known duplicate fields",
data: allKnownDuplicateFields,
into: S{},
reject: true,
},
{
name: "all unknown fields",
data: allUnknownFields,
into: S{},
reject: false,
},
{
name: "all unknown duplicate fields",
data: allUnknownDuplicateFields,
into: S{},
reject: true,
},
},
},
{
name: "reject unknown and duplicate",
opts: DecOptions{
DupMapKey: DupMapKeyEnforcedAPF,
ExtraReturnErrors: ExtraDecErrorUnknownField,
},
inputs: []input{
{
name: "all known fields",
data: allKnownFields,
into: S{},
reject: false,
},
{
name: "all known duplicate fields",
data: allKnownDuplicateFields,
into: S{},
reject: true,
},
{
name: "all unknown fields",
data: allUnknownFields,
into: S{},
reject: true,
},
{
name: "all unknown duplicate fields",
data: allUnknownDuplicateFields,
into: S{},
reject: true,
},
},
},
} {
for _, in := range tc.inputs {
b.Run(fmt.Sprintf("%s/%s", tc.name, in.name), func(b *testing.B) {
dm, err := tc.opts.DecMode()
if err != nil {
b.Fatal(err)
}

dst := reflect.New(reflect.TypeOf(in.into)).Interface()

b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := dm.Unmarshal(in.data, dst); !in.reject && err != nil {
b.Fatalf("unexpected error: %v", err)
} else if in.reject && err == nil {
b.Fatal("expected non-nil error")
}
}
})
}
}
}
58 changes: 52 additions & 6 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package cbor
import (
"bytes"
"errors"
"fmt"
"reflect"
"sort"
"strconv"
Expand Down Expand Up @@ -84,9 +85,25 @@ func newTypeInfo(t reflect.Type) *typeInfo {
}

type decodingStructType struct {
fields fields
err error
toArray bool
fields fields
fieldIndicesByName map[string]int
err error
toArray bool
}

// The stdlib errors.Join was introduced in Go 1.20, and we still support Go 1.17, so instead,
// here's a very basic implementation of an aggregated error.
type multierror []error

func (m multierror) Error() string {
var sb strings.Builder
for i, err := range m {
sb.WriteString(err.Error())
if i < len(m)-1 {
sb.WriteString(", ")
}
}
return sb.String()
}

func getDecodingStructType(t reflect.Type) *decodingStructType {
Expand All @@ -98,12 +115,12 @@ func getDecodingStructType(t reflect.Type) *decodingStructType {

toArray := hasToArrayOption(structOptions)

var err error
var errs []error
for i := 0; i < len(flds); i++ {
if flds[i].keyAsInt {
nameAsInt, numErr := strconv.Atoi(flds[i].name)
if numErr != nil {
err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
errs = append(errs, errors.New("cbor: failed to parse field name \""+flds[i].name+"\" to int ("+numErr.Error()+")"))
break
}
flds[i].nameAsInt = int64(nameAsInt)
Expand All @@ -112,7 +129,36 @@ func getDecodingStructType(t reflect.Type) *decodingStructType {
flds[i].typInfo = getTypeInfo(flds[i].typ)
}

structType := &decodingStructType{fields: flds, err: err, toArray: toArray}
fieldIndicesByName := make(map[string]int, len(flds))
for i, fld := range flds {
if _, ok := fieldIndicesByName[fld.name]; ok {
errs = append(errs, fmt.Errorf("cbor: two or more fields of %v have the same name %q", t, fld.name))
continue
}
fieldIndicesByName[fld.name] = i
}

var err error
{
var multi multierror
for _, each := range errs {
if each != nil {
multi = append(multi, each)
}
}
if len(multi) == 1 {
err = multi[0]
} else if len(multi) > 1 {
err = multi
}
}

structType := &decodingStructType{
fields: flds,
fieldIndicesByName: fieldIndicesByName,
err: err,
toArray: toArray,
}
decodingStructTypeCache.Store(t, structType)
return structType
}
Expand Down
Loading

0 comments on commit afbafc4

Please sign in to comment.