Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Treat map keys matching the same struct field as duplicates. #492

Merged
merged 2 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cbor

import (
"bytes"
"fmt"
"io"
"reflect"
"testing"
Expand Down Expand Up @@ -743,3 +744,209 @@ func BenchmarkMarshalCOSEMACWithTag(b *testing.B) {
}
}
}

func BenchmarkUnmarshalMapToStruct(b *testing.B) {
type S struct {
A, B, C, D, E, F, G, H, I, J, K, L, M bool
}

var (
allKnownFields = hexDecode("ad6141f56142f56143f56144f56145f56146f56147f56148f56149f5614af5614bf5614cf5614df5") // {"A": true, ... "M": true }
allKnownDuplicateFields = hexDecode("ad6141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f56141f5") // {"A": true, "A": true, "A": true, ...}
allUnknownFields = hexDecode("ad614ef5614ff56150f56151f56152f56153f56154f56155f56156f56157f56158f56159f5615af5") // {"N": true, ... "Z": true }
allUnknownDuplicateFields = hexDecode("ad614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5614ef5") // {"N": true, "N": true, "N": true, ...}
)

type ManyFields struct {
AA, AB, AC, AD, AE, AF, AG, AH, AI, AJ, AK, AL, AM, AN, AO, AP, AQ, AR, AS, AT, AU, AV, AW, AX, AY, AZ bool
BA, BB, BC, BD, BE, BF, BG, BH, BI, BJ, BK, BL, BM, BN, BO, BP, BQ, BR, BS, BT, BU, BV, BW, BX, BY, BZ bool
CA, CB, CC, CD, CE, CF, CG, CH, CI, CJ, CK, CL, CM, CN, CO, CP, CQ, CR, CS, CT, CU, CV, CW, CX, CY, CZ bool
DA, DB, DC, DD, DE, DF, DG, DH, DI, DJ, DK, DL, DM, DN, DO, DP, DQ, DR, DS, DT, DU, DV, DW, DX, DY, DZ bool
}
var manyFieldsOneKeyPerField []byte
{
// An EncOption that accepts a function to sort or shuffle keys might be useful for
// cases like this. Here we are manually encoding the fields in reverse order to
// target worst-case key-to-field matching.
rt := reflect.TypeOf(ManyFields{})
var buf bytes.Buffer
if rt.NumField() > 255 {
b.Fatalf("invalid test assumption: ManyFields expected to have no more than 255 fields, has %d", rt.NumField())
}
buf.WriteByte(0xb8)
buf.WriteByte(byte(rt.NumField()))
for i := rt.NumField() - 1; i >= 0; i-- { // backwards
f := rt.Field(i)
if len(f.Name) > 23 {
b.Fatalf("invalid test assumption: field name %q longer than 23 bytes", f.Name)
}
buf.WriteByte(byte(0x60 + len(f.Name)))
buf.WriteString(f.Name)
buf.WriteByte(0xf5) // true
}
manyFieldsOneKeyPerField = buf.Bytes()
}

type input struct {
name string
data []byte
into interface{}
reject bool
}

for _, tc := range []struct {
name string
opts DecOptions
inputs []input
}{
{
name: "default options",
opts: DecOptions{},
inputs: []input{
{
name: "all known fields",
data: allKnownFields,
into: S{},
reject: false,
},
{
name: "all known duplicate fields",
data: allKnownDuplicateFields,
into: S{},
reject: false,
},
{
name: "all unknown fields",
data: allUnknownFields,
into: S{},
reject: false,
},
{
name: "all unknown duplicate fields",
data: allUnknownDuplicateFields,
into: S{},
reject: false,
},
{
name: "many fields one key per field",
data: manyFieldsOneKeyPerField,
into: ManyFields{},
reject: false,
},
},
},
{
name: "reject unknown",
opts: DecOptions{ExtraReturnErrors: ExtraDecErrorUnknownField},
inputs: []input{
{
name: "all known fields",
data: allKnownFields,
into: S{},
reject: false,
},
{
name: "all known duplicate fields",
data: allKnownDuplicateFields,
into: S{},
reject: false,
},
{
name: "all unknown fields",
data: allUnknownFields,
into: S{},
reject: true,
},
{
name: "all unknown duplicate fields",
data: allUnknownDuplicateFields,
into: S{},
reject: true,
},
},
},
{
name: "reject duplicate",
opts: DecOptions{DupMapKey: DupMapKeyEnforcedAPF},
inputs: []input{
{
name: "all known fields",
data: allKnownFields,
into: S{},
reject: false,
},
{
name: "all known duplicate fields",
data: allKnownDuplicateFields,
into: S{},
reject: true,
},
{
name: "all unknown fields",
data: allUnknownFields,
into: S{},
reject: false,
},
{
name: "all unknown duplicate fields",
data: allUnknownDuplicateFields,
into: S{},
reject: true,
},
},
},
{
name: "reject unknown and duplicate",
opts: DecOptions{
DupMapKey: DupMapKeyEnforcedAPF,
ExtraReturnErrors: ExtraDecErrorUnknownField,
},
inputs: []input{
{
name: "all known fields",
data: allKnownFields,
into: S{},
reject: false,
},
{
name: "all known duplicate fields",
data: allKnownDuplicateFields,
into: S{},
reject: true,
},
{
name: "all unknown fields",
data: allUnknownFields,
into: S{},
reject: true,
},
{
name: "all unknown duplicate fields",
data: allUnknownDuplicateFields,
into: S{},
reject: true,
},
},
},
} {
for _, in := range tc.inputs {
b.Run(fmt.Sprintf("%s/%s", tc.name, in.name), func(b *testing.B) {
dm, err := tc.opts.DecMode()
if err != nil {
b.Fatal(err)
}

dst := reflect.New(reflect.TypeOf(in.into)).Interface()

b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := dm.Unmarshal(in.data, dst); !in.reject && err != nil {
b.Fatalf("unexpected error: %v", err)
} else if in.reject && err == nil {
b.Fatal("expected non-nil error")
}
}
})
}
}
}
58 changes: 52 additions & 6 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package cbor
import (
"bytes"
"errors"
"fmt"
"reflect"
"sort"
"strconv"
Expand Down Expand Up @@ -84,9 +85,25 @@ func newTypeInfo(t reflect.Type) *typeInfo {
}

type decodingStructType struct {
fields fields
err error
toArray bool
fields fields
fieldIndicesByName map[string]int
err error
toArray bool
}

// The stdlib errors.Join was introduced in Go 1.20, and we still support Go 1.17, so instead,
// here's a very basic implementation of an aggregated error.
type multierror []error

func (m multierror) Error() string {
var sb strings.Builder
for i, err := range m {
sb.WriteString(err.Error())
if i < len(m)-1 {
sb.WriteString(", ")
}
}
return sb.String()
}

func getDecodingStructType(t reflect.Type) *decodingStructType {
Expand All @@ -98,12 +115,12 @@ func getDecodingStructType(t reflect.Type) *decodingStructType {

toArray := hasToArrayOption(structOptions)

var err error
var errs []error
for i := 0; i < len(flds); i++ {
if flds[i].keyAsInt {
nameAsInt, numErr := strconv.Atoi(flds[i].name)
if numErr != nil {
err = errors.New("cbor: failed to parse field name \"" + flds[i].name + "\" to int (" + numErr.Error() + ")")
errs = append(errs, errors.New("cbor: failed to parse field name \""+flds[i].name+"\" to int ("+numErr.Error()+")"))
break
}
flds[i].nameAsInt = int64(nameAsInt)
Expand All @@ -112,7 +129,36 @@ func getDecodingStructType(t reflect.Type) *decodingStructType {
flds[i].typInfo = getTypeInfo(flds[i].typ)
}

structType := &decodingStructType{fields: flds, err: err, toArray: toArray}
fieldIndicesByName := make(map[string]int, len(flds))
for i, fld := range flds {
if _, ok := fieldIndicesByName[fld.name]; ok {
errs = append(errs, fmt.Errorf("cbor: two or more fields of %v have the same name %q", t, fld.name))
continue
}
fieldIndicesByName[fld.name] = i
}

var err error
{
var multi multierror
for _, each := range errs {
if each != nil {
multi = append(multi, each)
}
}
if len(multi) == 1 {
err = multi[0]
} else if len(multi) > 1 {
err = multi
}
}

structType := &decodingStructType{
fields: flds,
fieldIndicesByName: fieldIndicesByName,
err: err,
toArray: toArray,
}
decodingStructTypeCache.Store(t, structType)
return structType
}
Expand Down
Loading
Loading