From bf5e39bc5ed0b316270f4f8aa492e48ca06c11b7 Mon Sep 17 00:00:00 2001 From: Dmitriy Matrenichev Date: Tue, 9 Aug 2022 00:39:15 +0300 Subject: [PATCH] chore: support (u)int(8|16) fields ans slices, fix map issues, This commit does several things: - Adds support for (u)int(8|16) fields and slices. Any such type (except for []uint8 and [n]uint8) will be encoded as protobuf fixed32 - Refactor unmarshal.go - Disallow encoding of maps with complex keys - Disallow encoding of maps with repeated values - Check pointer to slices in slices elements - Support empty structs - Arrays of complex types cannot be decoded for now - Other small fixes Signed-off-by: Dmitriy Matrenichev --- .dockerignore | 5 +- Dockerfile | 5 +- array_test.go | 316 ---------------------- field_test.go | 4 +- helpers_test.go | 17 ++ marshal.go | 110 ++++++-- marshal_test.go | 56 ++-- messages/fuzz_test.go | 12 - messages/messages_test.go | 32 +++ person_test.go | 4 +- predefined_types.go | 2 +- protobuf_test.go | 117 +++++--- scanner.go | 334 +++++++++++++++++++++++ slice_test.go | 454 +++++++++++++++++++++++++++++++ type_cache.go | 17 +- unmarshal.go | 555 +++++++++++++++++--------------------- unmarshal_fastpath.go | 2 +- 17 files changed, 1307 insertions(+), 735 deletions(-) delete mode 100644 array_test.go create mode 100644 scanner.go create mode 100644 slice_test.go diff --git a/.dockerignore b/.dockerignore index 9e5d70a..2580753 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,10 +1,9 @@ # THIS FILE WAS AUTOMATICALLY GENERATED, PLEASE DO NOT EDIT. # -# Generated on 2022-08-04T13:22:07Z by kres latest. +# Generated on 2022-08-09T01:20:13Z by kres latest. ** !messages -!array_test.go !benchmarks_test.go !example_test.go !field_test.go @@ -17,6 +16,8 @@ !person_test.go !predefined_types.go !protobuf_test.go +!scanner.go +!slice_test.go !type_cache.go !unmarshal.go !unmarshal_fastpath.go diff --git a/Dockerfile b/Dockerfile index 8906318..52fd414 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ # THIS FILE WAS AUTOMATICALLY GENERATED, PLEASE DO NOT EDIT. # -# Generated on 2022-08-04T13:22:07Z by kres latest. +# Generated on 2022-08-09T01:20:13Z by kres latest. ARG TOOLCHAIN @@ -49,7 +49,6 @@ COPY ./go.sum . RUN --mount=type=cache,target=/go/pkg go mod download RUN --mount=type=cache,target=/go/pkg go mod verify COPY ./messages ./messages -COPY ./array_test.go ./array_test.go COPY ./benchmarks_test.go ./benchmarks_test.go COPY ./example_test.go ./example_test.go COPY ./field_test.go ./field_test.go @@ -62,6 +61,8 @@ COPY ./marshal_test.go ./marshal_test.go COPY ./person_test.go ./person_test.go COPY ./predefined_types.go ./predefined_types.go COPY ./protobuf_test.go ./protobuf_test.go +COPY ./scanner.go ./scanner.go +COPY ./slice_test.go ./slice_test.go COPY ./type_cache.go ./type_cache.go COPY ./unmarshal.go ./unmarshal.go COPY ./unmarshal_fastpath.go ./unmarshal_fastpath.go diff --git a/array_test.go b/array_test.go deleted file mode 100644 index 4cdbadb..0000000 --- a/array_test.go +++ /dev/null @@ -1,316 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package protoenc_test - -import ( - "encoding/hex" - "strings" - "testing" - "time" - - "github.com/brianvoe/gofakeit/v6" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/siderolabs/protoenc" -) - -func TestArrayEncodingDecoding(t *testing.T) { - t.Parallel() - - //nolint:govet - type localType struct { - A float32 `protobuf:"1"` - B struct { - C []int `protobuf:"1"` - } `protobuf:"2"` - C string `protobuf:"3"` - D map[protoenc.FixedS32]protoenc.FixedS64 `protobuf:"4"` - E map[protoenc.FixedU64]protoenc.FixedU32 `protobuf:"5"` - F map[float64]*struct { - G int64 `protobuf:"1"` - } `protobuf:"6"` - } - - type localMap struct { //nolint:unused - A map[int]int `protobuf:"1"` - } - - tests := map[string]struct { - fn func(t *testing.T) - }{ - "bool": {testArray[bool]}, - "string": {testArray[string]}, - "int": {testArray[int]}, - "int32": {testArray[int32]}, - "int64": {testArray[int64]}, - "uint": {testArray[uint]}, - "uint32": {testArray[uint32]}, - "uint64": {testArray[uint64]}, - "float32": {testArray[float32]}, - "float64": {testArray[float64]}, - "array[int64]": {testArray[array[int64]]}, - "FixedS32": {testArray[protoenc.FixedS32]}, - "FixedS64": {testArray[protoenc.FixedS64]}, - "FixedU32": {testArray[protoenc.FixedU32]}, - "FixedU64": {testArray[protoenc.FixedU64]}, - "time.Time": {testArray[time.Time]}, - "Struct": {testArray[localType]}, - "array[localMap]": {testArray[array[localMap]]}, - "time.Duration": {testArray[time.Duration]}, - } - - for name, test := range tests { - t.Run(name, test.fn) - } -} - -func testArray[T any](t *testing.T) { - t.Parallel() - - // This is our best-effort attempt to generate a random array of values. - for i := 0; i < 100; i++ { - original := array[T]{} - faker := gofakeit.New(Seed + int64(i)) - - require.NoError(t, faker.Struct(&original)) - - // This is needed because faker cannot fill time.Time in slices. - if timeSlice, ok := any(&original.Arr).(*[]time.Time); ok { - for i := range *timeSlice { - (*timeSlice)[i] = faker.Date() - } - } - - buf := must(protoenc.Marshal(&original))(t) - target := array[T]{} - - err := protoenc.Unmarshal(buf, &target) - if err != nil { - t.Log(original) - - require.FailNow(t, "", "%d iteration: %v", i, err) - } - - if !assert.Equal(t, original, target) { - t.Log(hex.Dump(buf)) - t.FailNow() - } - } -} - -func TestArrayEncodingForm(t *testing.T) { - t.Parallel() - - encodedInts := hexToBytes(t, "0a 03 01 02 03") - encodedFixedU32s := hexToBytes(t, "0a 0c 01 00 00 00 02 00 00 00 03 00 00 00") - encodedFloat32s := hexToBytes(t, "0a 0c 00 00 80 3f 00 00 00 40 00 00 40 40 ") - - // here 0a (which is field=1, type=2 encoded) begin to repeat because of the protbuf specification - encodedStrings := hexToBytes(t, "[0a 02 [61 62]] [0a 02 [62 63]] [0a 03 [63 64 65]]") - encodedWrappedInts := hexToBytes(t, "[0a 01 [01]] [0a 01 [02]] [0a 01 [03]]") - - // here 08 is also begin to repeat because we use inner structure - encodedLocalType := hexToBytes(t, "[0a 02 [08 01]] [0a 02 [08 02]] [0a 02 [08 03]]") - - encodedTime := hexToBytes(t, "[0a 02 [10 01]] [0a 02 [10 02]] [0a 02 [10 03]]") - - encodedDuration := hexToBytes(t, "[0a 02 [08 10]] [0a 02 [08 11]] [0a 02 [08 12]]") - - type localType struct { - A int `protobuf:"1"` - } - - tests := []struct { //nolint:govet - name string - fn func(t *testing.T) - }{ - { - "ints should be encoded in 'packed' form", - testArrayEncodingForm([]int{1, 2, 3}, encodedInts), - }, - { - "uints should be encoded in 'packed' form", - testArrayEncodingForm([]uint{1, 2, 3}, encodedInts), - }, - { - "float32s should be encoded in 'packed' form", - testArrayEncodingForm([]float32{1, 2, 3}, encodedFloat32s), - }, - { - "FixedU32s should be encoded in 'packed' form", - testArrayEncodingForm([]protoenc.FixedU32{1, 2, 3}, encodedFixedU32s), - }, - { - "strings should be encoded in normal form", - testArrayEncodingForm([]string{"ab", "bc", "cde"}, encodedStrings), - }, - { - "wrapped values should be encoded in normal form", - testArrayEncodingForm([]intWrapper{{1}, {2}, {3}}, encodedWrappedInts), - }, - { - "wrapped values with no marshallers should be encoded in normal form", - testArrayEncodingForm([]localType{{1}, {2}, {3}}, encodedLocalType), - }, - { - "time values should be encoded in normal form", - testArrayEncodingForm([]time.Time{time.Unix(0, 1), time.Unix(0, 2), time.Unix(0, 3)}, encodedTime), - }, - { - "duration values should be encoded in normal form", - testArrayEncodingForm([]time.Duration{16 * time.Second, 17 * time.Second, 18 * time.Second}, encodedDuration), - }, - { - "nil pointers will be skipped, and it should be encoded in normal form", - testArrayEncodingForm([]*string{nil, ptr("ab"), ptr("bc"), nil, ptr("cde")}, encodedStrings), - }, - } - - for _, test := range tests { - test := test - t.Run(test.name, test.fn) - } -} - -func testArrayEncodingForm[T any](slc []T, expected []byte) func(t *testing.T) { - return func(t *testing.T) { - t.Parallel() - - original := array[T]{Arr: slc} - buf := must(protoenc.Marshal(&original))(t) - - assert.Equal(t, expected, buf) - } -} - -// hexToBytes converts a hex string to a byte slice, removing any whitespace. -func hexToBytes(t *testing.T, s string) []byte { - t.Helper() - - s = strings.ReplaceAll(s, "|", "") - s = strings.ReplaceAll(s, "[", "") - s = strings.ReplaceAll(s, "]", "") - s = strings.ReplaceAll(s, " ", "") - - b, err := hex.DecodeString(s) - require.NoError(t, err) - - return b -} - -// newArray returns a new array with the given elements. -func newArray[T any](elements ...T) *array[T] { - return &array[T]{elements} -} - -type array[T any] struct { - Arr []T `protobuf:"1"` -} - -func TestDisallowedTypes(t *testing.T) { - t.Parallel() - - type localMapOfSlices struct { //nolint:unused - A map[int][]int `protobuf:"1"` - } - - type myBytes byte //nolint:unused - - tests := map[string]struct { - fn func(t *testing.T) - }{ - "array[map[string]string]": { - fn: testDisallowedTypes[array[map[string]string]], - }, - "array[localMapOfSlices]": { - fn: testDisallowedTypes[array[localMapOfSlices]], - }, - "array[myBytes]": { - fn: testDisallowedTypes[array[myBytes]], - }, - "map[string]string": { - fn: testDisallowedTypes[map[string]string], - }, - "array[Value[int]]": { - fn: func(t *testing.T) { - t.Parallel() - - arr := newArray[Value[int]](&ValueWrapper[int]{1}, &ValueWrapper[int]{1}) - buf, err := protoenc.Marshal(arr) - require.NoError(t, err) - - var target array[Value[int]] - err = protoenc.Unmarshal(buf, &target) - require.Error(t, err) - assert.Contains(t, err.Error(), "nil interface fields are not supported") - }, - }, - } - - for name, test := range tests { - t.Run(name, test.fn) - } -} - -func testDisallowedTypes[T any](t *testing.T) { - t.Parallel() - - faker := gofakeit.New(Seed) - - var original T - - require.NoError(t, faker.Struct(&original)) - - _, err := protoenc.Marshal(&original) - require.Error(t, err) - assert.Regexp(t, "(unsupported type)|(map only support)|(takes a struct)", err.Error()) -} - -func TestDuration(t *testing.T) { - t.Parallel() - - expected := newArray(time.Second*11, time.Second*12, time.Second*13) - buf := must(protoenc.Marshal(expected))(t) - - t.Log(hex.Dump(buf)) - - var actual array[time.Duration] - - require.NoError(t, protoenc.Unmarshal(buf, &actual)) - assert.Equal(t, expected.Arr, actual.Arr) -} - -func TestTime(t *testing.T) { - t.Parallel() - - expected := newArray(time.Unix(11, 0).UTC(), time.Unix(12, 0).UTC(), time.Unix(13, 0).UTC()) - buf := must(protoenc.Marshal(expected))(t) - - t.Log(hex.Dump(buf)) - - var actual array[time.Time] - - require.NoError(t, protoenc.Unmarshal(buf, &actual)) - assert.Equal(t, expected.Arr, actual.Arr) -} - -func TestSliceToArray(t *testing.T) { - t.Parallel() - - expected := newArray(1, 2, 3, 4, 5, 6, 7, 8, 9, 100500) - buf := must(protoenc.Marshal(expected))(t) - - t.Log(hex.Dump(buf)) - - type structWithArray struct { - Arr [10]int `protobuf:"1"` - } - - var actual structWithArray - - require.NoError(t, protoenc.Unmarshal(buf, &actual)) - assert.Equal(t, expected.Arr, actual.Arr[:]) -} diff --git a/field_test.go b/field_test.go index ce43b86..f1f0d2f 100644 --- a/field_test.go +++ b/field_test.go @@ -14,7 +14,7 @@ import ( "github.com/siderolabs/protoenc" ) -func TestEncodeNested(t *testing.T) { +func TestStructFields_EmbedStruct(t *testing.T) { t.Parallel() s := &StructWithEmbed{ @@ -65,7 +65,7 @@ type EmbedStruct struct { C int32 `protobuf:"11"` } -func TestDuplicateIDNotAllowed(t *testing.T) { +func TestStructFields_DuplicateIDNotAllowed(t *testing.T) { t.Parallel() v := reflect.TypeOf(&StructWithDuplicates{}) diff --git a/helpers_test.go b/helpers_test.go index 354fee7..e8c298e 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -5,6 +5,8 @@ package protoenc_test import ( + "encoding/hex" + "strings" "testing" "github.com/stretchr/testify/require" @@ -30,3 +32,18 @@ func panicOnErr[T any](t T, err error) T { func ptr[T any](t T) *T { return &t } + +// hexToBytes converts a hex string to a byte slice, removing any whitespace. +func hexToBytes(t *testing.T, s string) []byte { + t.Helper() + + s = strings.ReplaceAll(s, "|", "") + s = strings.ReplaceAll(s, "[", "") + s = strings.ReplaceAll(s, "]", "") + s = strings.ReplaceAll(s, " ", "") + + b, err := hex.DecodeString(s) + require.NoError(t, err) + + return b +} diff --git a/marshal.go b/marshal.go index b70e87f..924fe4a 100644 --- a/marshal.go +++ b/marshal.go @@ -48,7 +48,7 @@ func Marshal(ptr interface{}) (result []byte, err error) { } val := reflect.ValueOf(ptr) - if val.Kind() != reflect.Ptr { + if val.Kind() != reflect.Pointer { return nil, errors.New("encode takes a pointer to struct") } @@ -62,6 +62,10 @@ type marshaller struct { } func (m *marshaller) Bytes() []byte { + if len(m.buf) == 0 { + return nil + } + return m.buf } @@ -132,7 +136,7 @@ func fieldByIndex(structVal reflect.Value, data FieldData) reflect.Value { index := data.FieldIndex[:i+1] result = structVal.FieldByIndex(index) - if len(data.FieldIndex) > 1 && result.Kind() == reflect.Ptr && result.IsNil() { + if len(data.FieldIndex) > 1 && result.Kind() == reflect.Pointer && result.IsNil() { // Embedded field is nil, return empty reflect.Value. Avo return reflect.Value{} } @@ -141,18 +145,25 @@ func fieldByIndex(structVal reflect.Value, data FieldData) reflect.Value { return result } -//nolint:cyclop +//nolint:cyclop,gocyclo func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) { if m.tryEncodePredefined(num, val) { return } - // Note that protobufs don't support 8- or 16-bit ints. switch val.Kind() { //nolint:exhaustive case reflect.Bool: putTag(m, num, protowire.VarintType) putBool(m, val.Bool()) + case reflect.Int8, reflect.Int16: + putTag(m, num, protowire.Fixed32Type) + putInt32(m, int32(val.Int())) + + case reflect.Uint8, reflect.Uint16: + putTag(m, num, protowire.Fixed32Type) + putInt32(m, int32(val.Uint())) + case reflect.Int, reflect.Int32, reflect.Int64: putTag(m, num, protowire.VarintType) putUVarint(m, val.Int()) @@ -174,8 +185,6 @@ func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) { putString(m, val.String()) case reflect.Struct: - putTag(m, num, protowire.BytesType) - var b []byte bmarshaler, ok := asBinaryMarshaler(val) @@ -192,6 +201,7 @@ func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) { b = inner.Bytes() } + putTag(m, num, protowire.BytesType) putBytes(m, b) case reflect.Slice, reflect.Array: if val.Len() == 0 { @@ -202,7 +212,7 @@ func (m *marshaller) encodeValue(num protowire.Number, val reflect.Value) { return - case reflect.Ptr: + case reflect.Pointer: if val.IsNil() { return } @@ -324,14 +334,16 @@ func (m *marshaller) encodeSlice(key protowire.Number, val reflect.Value) { sliceLen := val.Len() result := marshaller{} - switch val.Type() { - case typeBytes: - // Special case for []byte. + typ := val.Type() + if typ.Elem() == typeByte { + // Special case for byte arrays and slices. putTag(m, key, protowire.BytesType) putBytes(m, val.Bytes()) return + } + switch typ { case typeDurations: // Special case for []time.Duration. slice := val.Interface().([]time.Duration) //nolint:errcheck,forcetypeassert @@ -380,6 +392,7 @@ func (m *marshaller) encodeSlice(key protowire.Number, val reflect.Value) { putBytes(m, result.Bytes()) } +//nolint:gocyclo,cyclop func (m *marshaller) sliceReflect(key protowire.Number, val reflect.Value) { if !isSliceOrArray(val) { panic("passed value is not slice or array") @@ -390,6 +403,16 @@ func (m *marshaller) sliceReflect(key protowire.Number, val reflect.Value) { result := marshaller{} switch elem.Kind() { //nolint:exhaustive + case reflect.Int8, reflect.Int16: + for i := 0; i < sliceLen; i++ { + putInt32(&result, int32(val.Index(i).Int())) + } + + case reflect.Uint8, reflect.Uint16: + for i := 0; i < sliceLen; i++ { + putInt32(&result, uint32(val.Index(i).Uint())) + } + case reflect.Bool: for i := 0; i < sliceLen; i++ { putBool(&result, val.Index(i).Bool()) @@ -415,8 +438,16 @@ func (m *marshaller) sliceReflect(key protowire.Number, val reflect.Value) { putInt64(&result, math.Float64bits(val.Index(i).Float())) } - case reflect.Uint8: - panic(fmt.Errorf("unsupported type %s", val.Type().String())) + case reflect.Pointer: + if !isSlicePtrElemSupported(elem) { + panic(fmt.Errorf("unsupported type: '%s'", val.String())) + } + + for i := 0; i < sliceLen; i++ { + m.encodeValue(key, val.Index(i)) + } + + return case reflect.Map: panic(fmt.Errorf("unsupported type %s", val.Type().String())) @@ -425,7 +456,7 @@ func (m *marshaller) sliceReflect(key protowire.Number, val reflect.Value) { if elem.Kind() == reflect.Slice || elem.Kind() == reflect.Array { subSlice := elem.Elem() if subSlice.Kind() != reflect.Uint8 { - panic("error no support for 2-dimensional array except for [][]byte") + panic("unsupported type: error no support for 2-dimensional array except for [][]byte") } } @@ -440,19 +471,58 @@ func (m *marshaller) sliceReflect(key protowire.Number, val reflect.Value) { putBytes(m, result.buf) } +func isSlicePtrElemSupported(elem reflect.Type) bool { + elem = deref(elem) + + switch elem.Kind() { //nolint:exhaustive + case reflect.Int8, reflect.Int16, reflect.Uint8, reflect.Uint16, reflect.Bool, + reflect.Int, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return false + + case reflect.Slice, reflect.Array: + if elem.Elem().Kind() == reflect.Uint8 { + return true + } + + return false + + default: + return true + } +} + func (m *marshaller) encodeMap(key protowire.Number, mpval reflect.Value) { + first := true + for _, mkey := range mpval.MapKeys() { mval := mpval.MapIndex(mkey) - switch kind := mval.Kind(); kind { //nolint:exhaustive - case reflect.Ptr: - if mval.IsNil() { - panic("error: map has nil element") + if first { + // map key can only be a primitive type or a string + switch mkey.Kind() { //nolint:exhaustive + case reflect.Struct, reflect.Array, reflect.Interface, reflect.Pointer: + panic(errors.New("unsupported type: map key cannot be struct, array, interface or pointer")) } - case reflect.Slice, reflect.Array: - if mval.Type().Elem().Kind() != reflect.Uint8 { - panic("error: map only support []byte or string as repeated value") + + unwrapVal := deref(mval.Type()) + + switch unwrapVal.Kind() { //nolint:exhaustive + case reflect.Slice, reflect.Array: + if mval.Type().Elem() == typeByte { + break + } + + fallthrough + case reflect.Interface: + panic(errors.New("unsupported type: map value cannot be non byte slice, array or interface")) } + + first = false + } + + if kind := mval.Kind(); kind == reflect.Pointer && mval.IsNil() { + panic("error: map has nil element") } inner := marshaller{} diff --git a/marshal_test.go b/marshal_test.go index e3b223d..9161480 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -93,29 +93,29 @@ func (i *intWrapper) UnmarshalBinary(data []byte) error { func TestNoBinaryMarshaler(t *testing.T) { t.Parallel() - encoded := WrapperNoMarshal[string]{&ValueWrapper[string]{V: "test-string"}} + encoded := WrapperNoMarshal[string]{&Value[string]{V: "test-string"}} buf := must(protoenc.Marshal(&encoded))(t) - decoded := WrapperNoMarshal[string]{&ValueWrapper[string]{V: ""}} + decoded := WrapperNoMarshal[string]{&Value[string]{V: ""}} require.NoError(t, protoenc.Unmarshal(buf, &decoded)) require.Equal(t, encoded.Field.Val(), decoded.Field.Val()) } type WrapperNoMarshal[T any] struct { - Field Value[T] `protobuf:"1"` + Field Valuer[T] `protobuf:"1"` } -type Value[T any] interface { +type Valuer[T any] interface { Val() T } -type ValueWrapper[T any] struct { +type Value[T any] struct { V T `protobuf:"1"` } -func (vw *ValueWrapper[T]) Val() T { - return vw.V +func (v *Value[T]) Val() T { + return v.V } func Test2dSlice(t *testing.T) { @@ -350,22 +350,32 @@ func TestMarshal(t *testing.T) { assert.Equal(t, b, testB) } -type EmbeddedStruct struct { - Value int `protobuf:"1"` - Value2 uint32 `protobuf:"2"` -} +func TestMarshalEmpty(t *testing.T) { + type Empty struct{} + + buf := must(protoenc.Marshal(&Empty{}))(t) + require.Len(t, buf, 0) -type AnotherEmbeddedStruct struct { - Value1 int `protobuf:"3"` - Value2 uint32 `protobuf:"4"` + buf = must(protoenc.Marshal(&OneFieldStruct[Empty]{}))(t) + require.Equal(t, []byte{0x0a, 0x00}, buf) } func TestEmbedding(t *testing.T) { + type EmbeddedStruct struct { + Value int `protobuf:"1"` + Value2 uint32 `protobuf:"2"` + } + + type AnotherEmbeddedStruct struct { + Value1 int `protobuf:"3"` + Value2 uint32 `protobuf:"4"` + } + structs := map[string]struct { fn func(t *testing.T) }{ "should embed struct": { - fn: makeEmbedTest(struct { + fn: testEncodeDecode(struct { EmbeddedStruct }{ EmbeddedStruct: EmbeddedStruct{ @@ -375,7 +385,7 @@ func TestEmbedding(t *testing.T) { }), }, "should embed struct pointer": { - fn: makeEmbedTest(struct { + fn: testEncodeDecode(struct { *EmbeddedStruct }{ EmbeddedStruct: &EmbeddedStruct{ @@ -385,7 +395,7 @@ func TestEmbedding(t *testing.T) { }), }, "should embed nil pointer struct and not nil pointer struct": { - fn: makeEmbedTest(struct { + fn: testEncodeDecode(struct { *EmbeddedStruct *AnotherEmbeddedStruct }{ @@ -397,7 +407,7 @@ func TestEmbedding(t *testing.T) { }), }, "should embed struct with marshaller": { - fn: makeEmbedTest(struct { + fn: testEncodeDecode(struct { Sequence[string] }{ Sequence: Sequence[string]{ @@ -406,21 +416,21 @@ func TestEmbedding(t *testing.T) { }), }, "should not embed nil struct pointer": { - fn: makeIncorrectEmbedTest(struct { + fn: testIncorrectEncode(struct { *EmbeddedStruct }{ EmbeddedStruct: nil, }), }, "should not embed simple type": { - fn: makeIncorrectEmbedTest(struct { + fn: testIncorrectEncode(struct { int }{ 0x11, }), }, "should not embed pointer to simple type": { - fn: makeIncorrectEmbedTest(struct { + fn: testIncorrectEncode(struct { *int }{ int: new(int), @@ -433,7 +443,7 @@ func TestEmbedding(t *testing.T) { } } -func makeEmbedTest[V any](v V) func(t *testing.T) { +func testEncodeDecode[V any](v V) func(t *testing.T) { return func(t *testing.T) { t.Helper() encoded := must(protoenc.Marshal(&v))(t) @@ -447,7 +457,7 @@ func makeEmbedTest[V any](v V) func(t *testing.T) { } } -func makeIncorrectEmbedTest[V any](v V) func(t *testing.T) { +func testIncorrectEncode[V any](v V) func(t *testing.T) { return func(t *testing.T) { t.Helper() diff --git a/messages/fuzz_test.go b/messages/fuzz_test.go index d1928fb..af0625d 100644 --- a/messages/fuzz_test.go +++ b/messages/fuzz_test.go @@ -53,15 +53,3 @@ func hexToBytes(f *testing.F, s string) []byte { return b } - -func TestName(t *testing.T) { - ourBasicMessage := BasicMessage{ - Int64: 0, - UInt64: 0, - Fixed64: protoenc.FixedU64(0), - SomeString: "", - SomeBytes: nil, - } - encoded1 := must(protoenc.Marshal(&ourBasicMessage))(t) - t.Logf("\n%s", hex.Dump(encoded1)) -} diff --git a/messages/messages_test.go b/messages/messages_test.go index e339448..fc2e9f2 100644 --- a/messages/messages_test.go +++ b/messages/messages_test.go @@ -25,7 +25,11 @@ type BasicMessage struct { } func TestBasicMessage(t *testing.T) { + t.Parallel() + t.Run("check that the outputs of both messages are the same", func(t *testing.T) { + t.Parallel() + ourBasicMessage := BasicMessage{ Int64: 1, UInt64: 2, @@ -43,6 +47,8 @@ func TestBasicMessage(t *testing.T) { }) t.Run("check that the outputs of both zero messages are the same", func(t *testing.T) { + t.Parallel() + ourBasicMessage := BasicMessage{} encoded1 := must(protoenc.Marshal(&ourBasicMessage))(t) basicMessage := protoUnmarshal[messages.BasicMessage](t, encoded1) @@ -53,6 +59,8 @@ func TestBasicMessage(t *testing.T) { }) t.Run("check that the outputs of both somewhat empty messages are the same", func(t *testing.T) { + t.Parallel() + ourBasicMessage := BasicMessage{SomeString: "some string"} encoded1 := must(protoenc.Marshal(&ourBasicMessage))(t) basicMessage := protoUnmarshal[messages.BasicMessage](t, encoded1) @@ -72,7 +80,11 @@ type MessageRepeatedFields struct { } func TestMessageRepeatedFields(t *testing.T) { + t.Parallel() + t.Run("check that the outputs of both messages are the same", func(t *testing.T) { + t.Parallel() + originalMsg := MessageRepeatedFields{ Int64: []int64{1, 2, 3}, UInt64: []uint64{4, 5, 6}, @@ -90,6 +102,8 @@ func TestMessageRepeatedFields(t *testing.T) { }) t.Run("check that the outputs of both zero messages are the same", func(t *testing.T) { + t.Parallel() + ourBasicMessage := MessageRepeatedFields{} encoded1 := must(protoenc.Marshal(&ourBasicMessage))(t) basicMessage := protoUnmarshal[messages.MessageRepeatedFields](t, encoded1) @@ -100,6 +114,8 @@ func TestMessageRepeatedFields(t *testing.T) { }) t.Run("check that the outputs of both somewhat empty messages are the same", func(t *testing.T) { + t.Parallel() + ourBasicMessage := MessageRepeatedFields{SomeString: []string{"some string"}} encoded1 := must(protoenc.Marshal(&ourBasicMessage))(t) basicMessage := protoUnmarshal[messages.MessageRepeatedFields](t, encoded1) @@ -115,7 +131,11 @@ type BasicMessageRep struct { } func TestBasicMessageRep(t *testing.T) { + t.Parallel() + t.Run("check that the outputs of both messages are the same", func(t *testing.T) { + t.Parallel() + originalMsg := BasicMessageRep{ BasicMessage: []BasicMessage{ { @@ -143,6 +163,8 @@ func TestBasicMessageRep(t *testing.T) { }) t.Run("check that the outputs of both zero messages are the same", func(t *testing.T) { + t.Parallel() + ourBasicMessage := BasicMessageRep{} encoded1 := must(protoenc.Marshal(&ourBasicMessage))(t) basicMessage := protoUnmarshal[messages.BasicMessageRep](t, encoded1) @@ -153,6 +175,8 @@ func TestBasicMessageRep(t *testing.T) { }) t.Run("check that the outputs of both somewhat empty messages are the same", func(t *testing.T) { + t.Parallel() + ourBasicMessage := BasicMessageRep{ BasicMessage: []BasicMessage{ { @@ -176,7 +200,11 @@ type MessageComplexFields struct { } func TestMessageComplexFields(t *testing.T) { + t.Parallel() + t.Run("check that the outputs of both messages are the same", func(t *testing.T) { + t.Parallel() + originalMsg := MessageComplexFields{ MapToMsg: map[string]BasicMessage{ "key": { @@ -240,6 +268,8 @@ func TestMessageComplexFields(t *testing.T) { }) t.Run("check that the outputs of both zero messages are the same", func(t *testing.T) { + t.Parallel() + ourBasicMessage := MessageComplexFields{} encoded1 := must(protoenc.Marshal(&ourBasicMessage))(t) basicMessage := protoUnmarshal[messages.MessageComplexFields](t, encoded1) @@ -250,6 +280,8 @@ func TestMessageComplexFields(t *testing.T) { }) t.Run("check that the outputs of both somewhat empty messages are the same", func(t *testing.T) { + t.Parallel() + originalMsg := MessageComplexFields{ MapToMsg: map[string]BasicMessage{ "key": { diff --git a/person_test.go b/person_test.go index 9844051..e55dd0c 100644 --- a/person_test.go +++ b/person_test.go @@ -64,8 +64,8 @@ func Example_protobuf() { // Decode it person2 := Person{} - if err := protoenc.Unmarshal(buf, &person2); err != nil { - panic("Decode failed") + if err = protoenc.Unmarshal(buf, &person2); err != nil { + panic(err) } if !reflect.DeepEqual(person, person2) { diff --git a/predefined_types.go b/predefined_types.go index ea9f478..c9e6c2e 100644 --- a/predefined_types.go +++ b/predefined_types.go @@ -35,7 +35,7 @@ var ( typeFixedS32s = reflect.SliceOf(typeFixedS32) typeFixedS64s = reflect.SliceOf(typeFixedS64) typeDurations = reflect.SliceOf(typeDuration) - typeBytes = reflect.SliceOf(typeOf[byte]()) + typeByte = typeOf[byte]() ) func typeOf[T any]() reflect.Type { diff --git a/protobuf_test.go b/protobuf_test.go index 7254805..f39b243 100644 --- a/protobuf_test.go +++ b/protobuf_test.go @@ -222,54 +222,85 @@ func TestTimeTypesEncodeDecode(t *testing.T) { assert.Equal(t, in.Duration, out.Duration) } -type wrongTestMsg struct { - /* encoding of testMsg is equivalent to the encoding to the following in - a .proto file: - message cipherText { - int32 a = 1; - int32 b = 2; - } - - message MapFieldEntry { - uint32 key = 1; - cipherText value = 2; - } - - message testMsg { - repeated MapFieldEntry map_field = 1; - } - for details see: - https://developers.google.com/protocol-buffers/docs/proto#backwards-compatibility */ - M map[uint32][]cipherText -} - -type rightTestMsg struct { - M map[uint32]*cipherText `protobuf:"1"` -} -type cipherText struct { - A int32 `protobuf:"1"` - B int32 `protobuf:"2"` -} - func TestMapSliceStruct(t *testing.T) { - cv := []cipherText{{}, {}} - msg := &wrongTestMsg{ - M: map[uint32][]cipherText{1: cv}, + t.Parallel() + + type structData struct { + A int32 `protobuf:"1"` + B int32 `protobuf:"2"` } - _, err := protoenc.Marshal(msg) - assert.Error(t, err) + t.Run("test map with slices", func(t *testing.T) { + t.Parallel() - msg2 := &rightTestMsg{ - M: map[uint32]*cipherText{1: {4, 5}}, - } + type wrongMap struct { + M map[uint32][]structData `protobuf:"1"` + } - buff, err := protoenc.Marshal(msg2) - assert.NoError(t, err) + cv := []structData{{}, {}} + msg := &wrongMap{ + M: map[uint32][]structData{1: cv}, + } - dec := &rightTestMsg{} - err = protoenc.Unmarshal(buff, dec) - assert.NoError(t, err) + _, err := protoenc.Marshal(msg) + assert.Error(t, err) + }) + + t.Run("test wrong map ptr to slice", func(t *testing.T) { + t.Parallel() + + type wrongMap struct { + M map[uint32]*[]structData `protobuf:"1"` + } + + ptrCv := &[]structData{{}, {}} + msg := &wrongMap{ + M: map[uint32]*[]structData{1: ptrCv}, + } + + _, err := protoenc.Marshal(msg) + assert.Error(t, err) + }) + + t.Run("test right map with ptr to struct", func(t *testing.T) { + t.Parallel() + + type rightMap struct { + M map[uint32]*structData `protobuf:"1"` + } + + msg := &rightMap{ + M: map[uint32]*structData{1: {4, 5}}, + } + + buff, err := protoenc.Marshal(msg) + assert.NoError(t, err) + + dec := &rightMap{} + err = protoenc.Unmarshal(buff, dec) + assert.NoError(t, err) + + assert.True(t, reflect.DeepEqual(dec, msg)) + }) + + t.Run("test right map with empty struct", func(t *testing.T) { + t.Parallel() + + type rightMap struct { + M map[uint32]struct{} `protobuf:"1"` + } + + msg := &rightMap{ + M: map[uint32]struct{}{1: {}}, + } + + buff, err := protoenc.Marshal(msg) + assert.NoError(t, err) + + dec := &rightMap{} + err = protoenc.Unmarshal(buff, dec) + assert.NoError(t, err) - assert.True(t, reflect.DeepEqual(dec, msg2)) + assert.True(t, reflect.DeepEqual(dec, msg)) + }) } diff --git a/scanner.go b/scanner.go new file mode 100644 index 0000000..dff5ee0 --- /dev/null +++ b/scanner.go @@ -0,0 +1,334 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package protoenc + +import ( + "errors" + "fmt" + "math" + "reflect" + + "google.golang.org/protobuf/encoding/protowire" +) + +func makeScanner(buf []byte) *scanner { + return &scanner{ + buf: buf, + } +} + +//nolint:govet +type scanner struct { + buf []byte + + lastFieldNum protowire.Number + lastType protowire.Type + + lastValueType valueType + lastComplex complexValue + lastPrimitive primitiveValue + + lastErr error +} + +type valueType int8 + +const ( + valueTypeInvalid valueType = iota + valueTypeComplex + valueTypePrimitive +) + +func (s *scanner) Scan() bool { + if len(s.buf) == 0 || s.lastErr != nil { + return false + } + + s.lastErr = s.scan() + + return s.lastErr == nil +} + +func (s *scanner) Err() error { + return s.lastErr +} + +func (s *scanner) FieldNum() protowire.Number { + return s.lastFieldNum +} + +func (s *scanner) Primitive() (primitiveValue, bool) { + if s.lastValueType != valueTypePrimitive { + return primitiveValue{}, false + } + + return s.lastPrimitive, true +} + +func (s *scanner) Complex() (complexValue, bool) { + if s.lastValueType != valueTypeComplex { + return complexValue{}, false + } + + return s.lastComplex, true +} + +func (s *scanner) scan() error { + s.lastValueType = valueTypeInvalid + s.lastPrimitive = primitiveValue{} + s.lastComplex = complexValue{} + s.lastFieldNum = 0 + s.lastType = 0 + + fieldnum, wiretype, n := protowire.ConsumeTag(s.buf) + if n <= 0 { + return errors.New("bad protobuf field key") + } + + s.lastFieldNum = fieldnum + s.lastType = wiretype + s.buf = s.buf[n:] + + ds := makeDataScanner(s.lastType, s.buf) + + if !ds.Scan() { + if err := ds.Err(); err != nil { + return err + } + + return errors.New("protobuf data scanner failed") + } + + s.buf = ds.buf + + if val, ok := ds.ComplexValue(); ok { + s.lastComplex = val + s.lastValueType = valueTypeComplex + } else if val, ok := ds.PrimitiveValue(); ok { + s.lastPrimitive = val + s.lastValueType = valueTypePrimitive + } else { + return errors.New("bad value type") + } + + return nil +} + +type primitiveValue struct { + value uint64 + wireType protowire.Type +} + +func (v *primitiveValue) WireType() protowire.Type { + return v.wireType +} + +func (v *primitiveValue) Bool() (bool, error) { + if v.wireType != protowire.VarintType { + return false, fmt.Errorf("bad wiretype for bool: %v", v.wireType) + } + + return protowire.DecodeBool(v.value), nil +} + +func (v *primitiveValue) Int() (int64, error) { + switch v.wireType { //nolint:exhaustive + case protowire.VarintType: + return int64(v.value), nil + case protowire.Fixed32Type: + return int64(int32(v.value)), nil + case protowire.Fixed64Type: + return int64(v.value), nil + default: + return -1, fmt.Errorf("bad wiretype for int: %v", v.wireType) + } +} + +func (v *primitiveValue) Uint() (uint64, error) { + switch v.wireType { //nolint:exhaustive + case protowire.VarintType: + return v.value, nil + case protowire.Fixed32Type: + return uint64(uint32(v.value)), nil + case protowire.Fixed64Type: + return v.value, nil + default: + return 0, fmt.Errorf("bad wiretype for uint: %v", v.wireType) + } +} + +func (v *primitiveValue) Float32() (float32, error) { + if v.wireType != protowire.Fixed32Type { + return 0, fmt.Errorf("bad wiretype for float32: %v", v.wireType) + } + + return math.Float32frombits(uint32(v.value)), nil +} + +func (v *primitiveValue) Float64() (float64, error) { + if v.wireType != protowire.Fixed64Type { + return 0, fmt.Errorf("bad wiretype for float64: %v", v.wireType) + } + + return math.Float64frombits(v.value), nil +} + +type complexValue struct { + value []byte + wireType protowire.Type +} + +func (c *complexValue) Bytes() ([]byte, error) { + if c.wireType != protowire.BytesType { + return nil, fmt.Errorf("bad wiretype for bytes: %v", c.wireType) + } + + return c.value, nil +} + +func makeDataScanner(wiretype protowire.Type, buf []byte) dataScanner { + var lastErr error + if len(buf) == 0 { + lastErr = errors.New("buffer for data scanner cannot be empty") + } + + return dataScanner{ + wiretype: wiretype, + buf: buf, + lastErr: lastErr, + } +} + +type dataScanner struct { + lastErr error + lastBuf []byte + buf []byte + lastVal uint64 + wiretype protowire.Type +} + +func (s *dataScanner) Scan() bool { + if s.wiretype < 0 { + s.lastErr = fmt.Errorf("negative wiretype: %v", s.wiretype) + } + + if len(s.buf) == 0 || s.lastErr != nil { + return false + } + + s.lastBuf = nil + s.lastVal = 0 + + switch s.wiretype { //nolint:exhaustive + case protowire.VarintType: + val, n := protowire.ConsumeVarint(s.buf) + if n <= 0 { + s.lastErr = errors.New("bad protobuf varint value") + + return false + } + + s.buf = s.buf[n:] + s.lastVal = val + + case protowire.Fixed32Type: + val, n := protowire.ConsumeFixed32(s.buf) + if n <= 0 { + s.lastErr = errors.New("bad protobuf 32-bit value") + + return false + } + + s.buf = s.buf[n:] + s.lastVal = uint64(val) + + case protowire.Fixed64Type: + val, n := protowire.ConsumeFixed64(s.buf) + if n <= 0 { + s.lastErr = errors.New("bad protobuf 64-bit value") + + return false + } + + s.buf = s.buf[n:] + s.lastVal = val + + case protowire.BytesType: + val, n := protowire.ConsumeBytes(s.buf) + if n <= 0 { + s.lastErr = errors.New("bad protobuf length-delimited value") + } + + s.buf = s.buf[n:] + s.lastBuf = val[:len(val):len(val)] + + default: + s.lastErr = errors.New("unknown protobuf wire-type") + } + + return true +} + +func (s *dataScanner) Err() error { + return s.lastErr +} + +func (s *dataScanner) ComplexValue() (complexValue, bool) { + return complexValue{ + wireType: s.wiretype, + value: s.lastBuf, + }, s.lastBuf != nil +} + +func (s *dataScanner) PrimitiveValue() (primitiveValue, bool) { + return primitiveValue{ + wireType: s.wiretype, + value: s.lastVal, + }, s.lastBuf == nil +} + +func (s *dataScanner) Wiretype() protowire.Type { + if s.wiretype < 0 { + panic(fmt.Errorf("invalid wiretype: %v", s.wiretype)) + } + + return s.wiretype +} + +func getDataScannerFor(eltype reflect.Type, buf []byte) (dataScanner, bool, error) { + switch eltype.Kind() { //nolint:exhaustive + case reflect.Uint8, reflect.Uint16, reflect.Int8, reflect.Int16: + return makeDataScanner(protowire.Fixed32Type, buf), true, nil + + case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint32, reflect.Uint64, reflect.Uint: + if (eltype.Kind() == reflect.Int || eltype.Kind() == reflect.Uint) && eltype.Size() < 8 { + return dataScanner{}, false, errors.New("detected a 32bit machine, please either use (u)int64 or (u)int32") + } + + switch eltype { + case typeFixedS32: + return makeDataScanner(protowire.Fixed32Type, buf), true, nil + case typeFixedS64: + return makeDataScanner(protowire.Fixed64Type, buf), true, nil + case typeFixedU32: + return makeDataScanner(protowire.Fixed32Type, buf), true, nil + case typeFixedU64: + return makeDataScanner(protowire.Fixed64Type, buf), true, nil + case typeDuration: + return dataScanner{}, false, nil + default: + return makeDataScanner(protowire.VarintType, buf), true, nil + } + + case reflect.Float32: + return makeDataScanner(protowire.Fixed32Type, buf), true, nil + + case reflect.Float64: + return makeDataScanner(protowire.Fixed64Type, buf), true, nil + default: + return dataScanner{}, false, nil + } +} diff --git a/slice_test.go b/slice_test.go new file mode 100644 index 0000000..db9fa79 --- /dev/null +++ b/slice_test.go @@ -0,0 +1,454 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package protoenc_test + +import ( + "encoding/hex" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/siderolabs/protoenc" +) + +func TestSliceEncodingDecoding(t *testing.T) { + t.Parallel() + + //nolint:govet + type localType struct { + A float32 `protobuf:"1"` + B struct { + C []int `protobuf:"1"` + } `protobuf:"2"` + C string `protobuf:"3"` + D map[protoenc.FixedS32]protoenc.FixedS64 `protobuf:"4"` + E map[protoenc.FixedU64]protoenc.FixedU32 `protobuf:"5"` + F map[float64]*struct { + G int64 `protobuf:"1"` + } `protobuf:"6"` + } + + type localMap struct { //nolint:unused + A map[int]int `protobuf:"1"` + } + + tests := map[string]struct { + fn func(t *testing.T) + }{ + "bool": {testSlice[bool]}, + "string": {testSlice[string]}, + "int": {testSlice[int]}, + "int32": {testSlice[int32]}, + "int64": {testSlice[int64]}, + "uint": {testSlice[uint]}, + "uint32": {testSlice[uint32]}, + "uint64": {testSlice[uint64]}, + "float32": {testSlice[float32]}, + "float64": {testSlice[float64]}, + "sliceWrapper[int64]": {testSlice[sliceWrapper[int64]]}, + "sliceWrapper[[]uint8]": {testSlice[sliceWrapper[[]uint8]]}, + "sliceWrapper[*[]uint8]": {testSlice[sliceWrapper[*[]uint8]]}, + "FixedS32": {testSlice[protoenc.FixedS32]}, + "FixedS64": {testSlice[protoenc.FixedS64]}, + "FixedU32": {testSlice[protoenc.FixedU32]}, + "FixedU64": {testSlice[protoenc.FixedU64]}, + "time.Time": {testSlice[time.Time]}, + "Struct": {testSlice[localType]}, + "sliceWrapper[localMap]": {testSlice[sliceWrapper[localMap]]}, + "time.Duration": {testSlice[time.Duration]}, + } + + for name, test := range tests { + t.Run(name, test.fn) + } +} + +func testSlice[T any](t *testing.T) { + t.Parallel() + + // This is our best-effort attempt to generate a random slice of values. + for i := 0; i < 100; i++ { + original := sliceWrapper[T]{} + faker := gofakeit.New(Seed + int64(i)) + + require.NoError(t, faker.Struct(&original)) + + // This is needed because faker cannot fill time.Time in slices. + if timeSlice, ok := any(&original.Arr).(*[]time.Time); ok { + for i := range *timeSlice { + (*timeSlice)[i] = faker.Date() + } + } + + buf := must(protoenc.Marshal(&original))(t) + target := sliceWrapper[T]{} + + err := protoenc.Unmarshal(buf, &target) + if err != nil { + t.Log(original) + + require.FailNow(t, "", "%d iteration: %v", i, err) + } + + if !assert.Equal(t, original, target) { + t.Log(hex.Dump(buf)) + t.FailNow() + } + } +} + +func TestSliceEncodingResult(t *testing.T) { + t.Parallel() + + encodedInts := hexToBytes(t, "0a 03 01 02 03") + encodedFixedU32s := hexToBytes(t, "0a 0c 01 00 00 00 02 00 00 00 03 00 00 00") + encodedFloat32s := hexToBytes(t, "0a 0c 00 00 80 3f 00 00 00 40 00 00 40 40 ") + + // here 0a (which is field=1, type=2 encoded) begin to repeat because of the protbuf specification + encodedStrings := hexToBytes(t, "[0a 02 [61 62]] [0a 02 [62 63]] [0a 03 [63 64 65]]") + encodedWrappedInts := hexToBytes(t, "[0a 01 [01]] [0a 01 [02]] [0a 01 [03]]") + + // here 08 is also begin to repeat because we use inner structure + encodedLocalType := hexToBytes(t, "[0a 02 [08 01]] [0a 02 [08 02]] [0a 02 [08 03]]") + + encodedTime := hexToBytes(t, "[0a 02 [10 01]] [0a 02 [10 02]] [0a 02 [10 03]]") + + encodedDuration := hexToBytes(t, "[0a 02 [08 10]] [0a 02 [08 11]] [0a 02 [08 12]]") + + emptyStructs := hexToBytes(t, "[0a 00] [0a 00] [0a 00]") + + type localType struct { + A int `protobuf:"1"` + } + + type localEmptyType struct{} + + tests := []struct { //nolint:govet + name string + fn func(t *testing.T) + }{ + { + "ints should be encoded in 'packed' form", + testSliceEncodingResult([]int{1, 2, 3}, encodedInts), + }, + { + "uints should be encoded in 'packed' form", + testSliceEncodingResult([]uint{1, 2, 3}, encodedInts), + }, + { + "float32s should be encoded in 'packed' form", + testSliceEncodingResult([]float32{1, 2, 3}, encodedFloat32s), + }, + { + "FixedU32s should be encoded in 'packed' form", + testSliceEncodingResult([]protoenc.FixedU32{1, 2, 3}, encodedFixedU32s), + }, + { + "strings should be encoded in normal form", + testSliceEncodingResult([]string{"ab", "bc", "cde"}, encodedStrings), + }, + { + "wrapped values should be encoded in normal form", + testSliceEncodingResult([]intWrapper{{1}, {2}, {3}}, encodedWrappedInts), + }, + { + "wrapped values with no marshallers should be encoded in normal form", + testSliceEncodingResult([]localType{{1}, {2}, {3}}, encodedLocalType), + }, + { + "time values should be encoded in normal form", + testSliceEncodingResult([]time.Time{time.Unix(0, 1), time.Unix(0, 2), time.Unix(0, 3)}, encodedTime), + }, + { + "duration values should be encoded in normal form", + testSliceEncodingResult([]time.Duration{16 * time.Second, 17 * time.Second, 18 * time.Second}, encodedDuration), + }, + { + "nil pointers will be skipped, and it should be encoded in normal form", + testSliceEncodingResult([]*string{nil, ptr("ab"), ptr("bc"), nil, ptr("cde")}, encodedStrings), + }, + { + "empty structs should return only tags", + testSliceEncodingResult([]localEmptyType{{}, {}, {}}, emptyStructs), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, test.fn) + } +} + +func testSliceEncodingResult[T any](slc []T, expected []byte) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + + original := sliceWrapper[T]{Arr: slc} + buf := must(protoenc.Marshal(&original))(t) + + assert.Equal(t, expected, buf) + } +} + +func TestSmallIntegers(t *testing.T) { + t.Parallel() + + encodedBytes := hexToBytes(t, "0a 03 01 FF 03") + encodedFixed := hexToBytes(t, "0a 0c [01 00 00 00] [ff 00 00 00] [03 00 00 00]") + encodedFixedNegative := hexToBytes(t, "0a 0c [01 00 00 00] [ff ff ff ff] [03 00 00 00]") + encodedUint16s := hexToBytes(t, "0a 0c [01 00 00 00] [ff ff 00 00] [03 00 00 00]") + + type customByte byte + + type customType struct { + Int16 int16 `protobuf:"1"` + Uint16 uint16 `protobuf:"3"` + Int8 int8 `protobuf:"2"` + Uint8 uint8 `protobuf:"4"` + CustomByte customByte `protobuf:"5"` + } + + encodedCustomType := hexToBytes(t, "0a 19 [0d [ff ff ff ff]] [1d [ff ff 00 00]] [15 [ff ff ff ff]] [25 [ff 00 00 00]] [2d [ff 00 00 00]]") + + tests := []struct { //nolint:govet + name string + fn func(t *testing.T) + }{ + { + "array of bytes should be encoded in 'bytes' form", + testEncodeDecodeWrapped([...]byte{1, 0xFF, 3}, encodedBytes), + }, + { + "array of custom byte types should be encoded in 'fixed32' form", + testEncodeDecodeWrapped([...]customByte{1, 0xFF, 3}, encodedFixed), + }, + { + "slice of custom byte type should be encoded in 'fixed32' form", + testEncodeDecodeWrapped([]customByte{1, 0xFF, 3}, encodedFixed), + }, + { + "slice of int8 should be encoded in 'fixed32' form", + testEncodeDecodeWrapped([]int8{1, -1, 3}, encodedFixedNegative), + }, + { + "slice of int16 type should be encoded in 'fixed32' form", + testEncodeDecodeWrapped([]int16{1, -1, 3}, encodedFixedNegative), + }, + { + "slice of uint16 type should be encoded in 'fixed32' form", + testEncodeDecodeWrapped([]uint16{1, 0xFFFF, 3}, encodedUint16s), + }, + { + "customType should be encoded in 'fixed32' form", + testEncodeDecodeWrapped(customType{ + Int16: -1, + Uint16: 0xFFFF, + Int8: -1, + Uint8: 0xFF, + CustomByte: 0xFF, + }, encodedCustomType), + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, test.fn) + } +} + +func testEncodeDecodeWrapped[T any](slc T, expected []byte) func(t *testing.T) { + return func(t *testing.T) { + t.Parallel() + + original := OneFieldStruct[T]{Field: slc} + buf := must(protoenc.Marshal(&original))(t) + + require.Equal(t, expected, buf) + + var decoded OneFieldStruct[T] + + require.NoError(t, protoenc.Unmarshal(buf, &decoded)) + require.Equal(t, original, decoded) + } +} + +// newSliceWrapper returns a new wrapper type around slice field with the given elements. +func newSliceWrapper[T any](elements ...T) *sliceWrapper[T] { + return &sliceWrapper[T]{elements} +} + +type sliceWrapper[T any] struct { + Arr []T `protobuf:"1"` +} + +func TestDisallowedTypes(t *testing.T) { + t.Parallel() + + type localMapOfSlices struct { //nolint:unused + A map[int][]int `protobuf:"1"` + } + + type complexKey struct { //nolint:unused + A int `protobuf:"1"` + } + + type localMapWithComplexKey struct { //nolint:unused + A map[complexKey]int `protobuf:"1"` + } + + type localMapWithPtrKey struct { //nolint:unused + A map[*int]int `protobuf:"1"` + } + + tests := map[string]struct { + fn func(t *testing.T) + }{ + "sliceWrapper[map[string]string]": { + fn: testDisallowedTypes[sliceWrapper[map[string]string]], + }, + "sliceWrapper[localMapOfSlices]": { + fn: testDisallowedTypes[sliceWrapper[localMapOfSlices]], + }, + "sliceWrapper[localMapWithComplexKey]": { + fn: testDisallowedTypes[sliceWrapper[localMapWithComplexKey]], + }, + "sliceWrapper[localMapWithPtrKey]": { + fn: testDisallowedTypes[sliceWrapper[localMapWithPtrKey]], + }, + "map[string]string": { + fn: testDisallowedTypes[map[string]string], + }, + "sliceWrapper[*int]": { + fn: testDisallowedTypes[sliceWrapper[*int]], + }, + "sliceWrapper[*int8]": { + fn: testDisallowedTypes[sliceWrapper[*int8]], + }, + "sliceWrapper[*[]int8]": { + fn: testDisallowedTypes[sliceWrapper[*[]int8]], + }, + "sliceWrapper[*[]int16]": { + fn: testDisallowedTypes[sliceWrapper[*[]int16]], + }, + "sliceWrapper[*[]int]": { + fn: testDisallowedTypes[sliceWrapper[*[]int]], + }, + "sliceWrapper[[][]int]": { + fn: testDisallowedTypes[sliceWrapper[[][]int]], + }, + "sliceWrapper[*[][]int]": { + fn: testDisallowedTypes[sliceWrapper[*[][]int]], + }, + "sliceWrapper[Valuer[int]]": { + fn: func(t *testing.T) { + t.Parallel() + + arr := newSliceWrapper[Valuer[int]](&Value[int]{1}, &Value[int]{1}) + buf, err := protoenc.Marshal(arr) + require.NoError(t, err) + + var target sliceWrapper[Valuer[int]] + err = protoenc.Unmarshal(buf, &target) + require.Error(t, err) + assert.Contains(t, err.Error(), "nil interface fields are not supported") + }, + }, + } + + for name, test := range tests { + t.Run(name, test.fn) + } +} + +func testDisallowedTypes[T any](t *testing.T) { + t.Parallel() + + faker := gofakeit.New(Seed) + + var original T + + require.NoError(t, faker.Struct(&original)) + + _, err := protoenc.Marshal(&original) + require.Error(t, err) + assert.Regexp(t, "(unsupported type)|(takes a struct)", err.Error()) +} + +func TestDuration(t *testing.T) { + t.Parallel() + + expected := newSliceWrapper(time.Second*11, time.Second*12, time.Second*13) + buf := must(protoenc.Marshal(expected))(t) + + t.Log(hex.Dump(buf)) + + var actual sliceWrapper[time.Duration] + + require.NoError(t, protoenc.Unmarshal(buf, &actual)) + assert.Equal(t, expected.Arr, actual.Arr) +} + +func TestTime(t *testing.T) { + t.Parallel() + + expected := newSliceWrapper(time.Unix(11, 0).UTC(), time.Unix(12, 0).UTC(), time.Unix(13, 0).UTC()) + buf := must(protoenc.Marshal(expected))(t) + + t.Log(hex.Dump(buf)) + + var actual sliceWrapper[time.Time] + + require.NoError(t, protoenc.Unmarshal(buf, &actual)) + assert.Equal(t, expected.Arr, actual.Arr) +} + +func TestSliceToArray(t *testing.T) { + t.Parallel() + + expected := newSliceWrapper(1, 2, 3, 4, 5, 6, 7, 8, 9, 100500) + buf := must(protoenc.Marshal(expected))(t) + + t.Log(hex.Dump(buf)) + + type structWithArray struct { + Arr [10]int `protobuf:"1"` + } + + var actual structWithArray + + require.NoError(t, protoenc.Unmarshal(buf, &actual)) + assert.Equal(t, expected.Arr, actual.Arr[:]) +} + +func TestSlicesOfEmpty(t *testing.T) { + type Empty struct{} + + type NotEmpty struct { + Field int `protobuf:"1"` + } + + t.Run("not empty to empty", func(t *testing.T) { + wrapper := newSliceWrapper(NotEmpty{Field: 1}, NotEmpty{Field: 1}, NotEmpty{Field: 1}) + buf := must(protoenc.Marshal(wrapper))(t) + + result := sliceWrapper[Empty]{} + + require.NoError(t, protoenc.Unmarshal(buf, &result)) + require.Len(t, result.Arr, 3) + }) + + t.Run("empty to not empty", func(t *testing.T) { + wrapper := newSliceWrapper(Empty{}, Empty{}, Empty{}) + buf := must(protoenc.Marshal(wrapper))(t) + + result := sliceWrapper[NotEmpty]{} + + require.NoError(t, protoenc.Unmarshal(buf, &result)) + require.Len(t, result.Arr, 3) + }) +} diff --git a/type_cache.go b/type_cache.go index 16025ec..3b67ab2 100644 --- a/type_cache.go +++ b/type_cache.go @@ -59,16 +59,17 @@ func structFields(typ reflect.Type) ([]FieldData, error) { for i := 0; i < typ.NumField(); i++ { typField := typ.Field(i) + // Report error sooner than later + if typField.Anonymous && deref(typField.Type).Kind() != reflect.Struct { + return nil, fmt.Errorf("%s.%s.%s is not a struct type", typ.PkgPath(), typ.Name(), typField.Name) + } + // Skipping private types if !typField.IsExported() { continue } if typField.Anonymous { - if deref(typField.Type).Kind() != reflect.Struct { - return nil, fmt.Errorf("%s.%s.%s is not a struct type", typ.PkgPath(), typ.Name(), typField.Name) - } - fields, err := structFields(typField.Type) if err != nil { return nil, err @@ -100,15 +101,11 @@ func structFields(typ reflect.Type) ([]FieldData, error) { }) } - if len(result) == 0 { - return nil, fmt.Errorf("%s.%s has no exported fields", typ.PkgPath(), typ.Name()) - } - return result, nil } func deref(typ reflect.Type) reflect.Type { - for typ.Kind() == reflect.Ptr { + for typ.Kind() == reflect.Pointer { typ = typ.Elem() } @@ -225,7 +222,7 @@ func RegisterEncoderDecoder[T any, Enc func(T) ([]byte, error), Dec func([]byte) } func indirect(typ reflect.Type) reflect.Type { - if typ.Kind() == reflect.Ptr { + if typ.Kind() == reflect.Pointer { return typ.Elem() } diff --git a/unmarshal.go b/unmarshal.go index 32a0f97..521cf4c 100644 --- a/unmarshal.go +++ b/unmarshal.go @@ -8,7 +8,6 @@ import ( "encoding" "errors" "fmt" - "math" "reflect" "google.golang.org/protobuf/encoding/protowire" @@ -17,9 +16,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) -// Decoder is the main struct used to decode a protobuf blob. -type unmarshaller struct{} - // Unmarshal a protobuf value into a Go value. // The caller must pass a pointer to the struct to decode into. func Unmarshal(buf []byte, ptr interface{}) error { @@ -49,16 +45,14 @@ func unmarshal(buf []byte, structPtr interface{}) (returnErr error) { } val := reflect.ValueOf(structPtr) - if val.Kind() != reflect.Ptr { + if val.Kind() != reflect.Pointer { return errors.New("decode has been given a non pointer type") } - de := unmarshaller{} - - return de.unmarshalStruct(buf, val.Elem()) + return unmarshalStruct(val.Elem(), buf) } -func (u *unmarshaller) unmarshalStruct(buf []byte, structVal reflect.Value) error { +func unmarshalStruct(structVal reflect.Value, buf []byte) error { if structVal.Kind() != reflect.Struct { return errors.New("not a struct") } @@ -79,37 +73,56 @@ func (u *unmarshaller) unmarshalStruct(buf []byte, structVal reflect.Value) erro return err } - for len(buf) > 0 { - fieldnum, wiretype, n := protowire.ConsumeTag(buf) - if n <= 0 { - return errors.New("bad protobuf field key") - } - - buf = buf[n:] + rdr := makeScanner(buf) + for rdr.Scan() { var field reflect.Value - fieldIndex := findField(structFields, fieldnum) + fieldIndex := findField(structFields, rdr.FieldNum()) if fieldIndex != -1 { field = initStructField(structVal, structFields[fieldIndex]) } - // Decode the field's value - rem, err := u.decodeValue(wiretype, buf, field) - if err != nil { + if err = putValue(field, rdr); err != nil { if fieldIndex != -1 { - return fmt.Errorf("error while unmarshalling field %+v: %w", structFields[fieldIndex].Field, err) + return fmt.Errorf("error while unmarshalling field '%s' of struct '%s.%s': %w", + structFields[fieldIndex].Field.Name, + structVal.Type().PkgPath(), + structVal.Type().Name(), + err) } return err } + } - buf = rem + if err := rdr.Err(); err != nil { + return err } return nil } +func putValue(dst reflect.Value, rdr *scanner) error { + if val, ok := rdr.Primitive(); ok { + err := unmarshalPrimitive(dst, val) + if err != nil { + return fmt.Errorf("error while unmarshalling primitive '%v': %w", val, err) + } + + return nil + } else if val, ok := rdr.Complex(); ok { + err := unmarshalBytes(dst, val) + if err != nil { + return fmt.Errorf("error while unmarshalling complex '%v': %w", val, err) + } + + return nil + } + + panic("unexpected value") +} + func zeroStructFields(val reflect.Value) { for i := 0; i < val.NumField(); i++ { field := val.Field(i) @@ -137,7 +150,7 @@ func initStructField(structField reflect.Value, fieldData FieldData) reflect.Val path := index[:i+1] result = structField.FieldByIndex(path) - if result.Kind() == reflect.Ptr && result.IsNil() { + if result.Kind() == reflect.Pointer && result.IsNil() { result.Set(reflect.New(result.Type().Elem())) } } @@ -163,174 +176,224 @@ func findField(fields []FieldData, fieldnum protowire.Number) int { return idx } -// Pull a value from the buffer and put it into a reflective Value. -func (u *unmarshaller) decodeValue(wiretype protowire.Type, buf []byte, dst reflect.Value) ([]byte, error) { - var ( - // Break out the value from the buffer based on the wire type - decodedValue uint64 - n int - decodedBytes []byte - ) +func tryDecodeFunc(vb []byte, dst reflect.Value) (bool, error) { + dec, ok := decoders.Get(dst.Type()) + if !ok { + return false, nil + } - switch wiretype { //nolint:exhaustive - case protowire.VarintType: - decodedValue, n = protowire.ConsumeVarint(buf) - if n <= 0 { - return nil, errors.New("bad protobuf varint value") - } + if err := dec(vb, dst); err != nil { + return false, err + } + + return true, nil +} - buf = buf[n:] +func mapEntry(dstEntry reflect.Value, buf []byte) error { + entryKey := reflect.New(dstEntry.Type().Key()).Elem() + entryVal := reflect.New(dstEntry.Type().Elem()).Elem() - case protowire.Fixed32Type: - var res uint32 - res, n = protowire.ConsumeFixed32(buf) + s := makeScanner(buf) - if n <= 0 { - return nil, errors.New("bad protobuf 32-bit value") + // scan key + if !s.Scan() { + if s.Err() != nil { + return s.Err() } - decodedValue = uint64(res) - buf = buf[n:] + return errors.New("map key is missing") + } + + // map key can only be a primitive type or a string + switch entryKey.Kind() { //nolint:exhaustive + case reflect.Struct, reflect.Array, reflect.Interface, reflect.Pointer: + return errors.New("map key cannot be struct, array, interface or pointer") + } - case protowire.Fixed64Type: - var res uint64 - res, n = protowire.ConsumeFixed64(buf) + if err := putValue(entryKey, s); err != nil { + return fmt.Errorf("failed to unmarshal map key type:'%s': %w", entryKey.Type().String(), err) + } - if n <= 0 { - return nil, errors.New("bad protobuf 64-bit value") + // map value cannot be slice or array if ([]uint8 and [n]uint8 are exceptions) + switch entryVal.Kind() { //nolint:exhaustive + case reflect.Slice, reflect.Array: + if entryVal.Type().Elem() == typeByte { + break } - decodedValue = res - buf = buf[n:] + fallthrough + case reflect.Interface: + return errors.New("map value cannot be non byte slice, array or interface") + } - case protowire.BytesType: - decodedBytes, n = protowire.ConsumeBytes(buf) - if n <= 0 { - return nil, errors.New("bad protobuf length-delimited value") + // scan value + if s.Scan() { + if err := putValue(entryVal, s); err != nil { + return fmt.Errorf("failed to unmarshal map value type:'%s': %w", entryKey.Type().String(), err) } + } - decodedBytes = decodedBytes[:len(decodedBytes):len(decodedBytes)] - buf = buf[n:] + if s.Err() != nil { + return fmt.Errorf("map scanning failed: %w", s.Err()) + } - default: - return nil, errors.New("unknown protobuf wire-type") + // scan more and fail if there is more + if s.Scan() { + return errors.New("map entry cannot have several values") } - if err := u.putInto(dst, wiretype, decodedValue, decodedBytes); err != nil { - return nil, err + if !entryKey.IsValid() || !entryVal.IsValid() { + return errors.New("proto: bad map data: missing key/val") } - return buf, nil + dstEntry.SetMapIndex(entryKey, entryVal) + + return nil } -//nolint:gocognit,gocyclo,cyclop -func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, decodedValue uint64, decodedBytes []byte) error { +func unmarshalPrimitive(dst reflect.Value, value primitiveValue) error { // Value is not settable (invalid reflect.Value, private) if !dst.CanSet() { return nil } - // Check predefined pb types - switch dst.Type() { - case typeTime: - if wiretype != protowire.BytesType { - return fmt.Errorf("bad wiretype for time.Time: %v", wiretype) + switch dst.Kind() { //nolint:exhaustive + case reflect.Pointer: + if dst.IsNil() { + err := instantiate(dst) + if err != nil { + return err + } } - var result timestamppb.Timestamp + return unmarshalPrimitive(dst.Elem(), value) - err := proto.Unmarshal(decodedBytes, &result) + case reflect.Bool: + val, err := value.Bool() if err != nil { return err } - dst.Set(reflect.ValueOf(result.AsTime())) + dst.SetBool(val) return nil - case typeDuration: - if wiretype != protowire.BytesType { - return fmt.Errorf("bad wiretype for time.Duration: %v", wiretype) - } - var result durationpb.Duration + case reflect.Int, reflect.Int32, reflect.Int64, + reflect.Int8, reflect.Int16: // Those two are a special case + if dst.Kind() == reflect.Int && dst.Type().Size() < 8 { + return errors.New("detected a 32bit machine, please use either int64 or int32") + } - err := proto.Unmarshal(decodedBytes, &result) + val, err := value.Int() if err != nil { return err } - dst.Set(reflect.ValueOf(result.AsDuration())) + dst.SetInt(val) return nil - } - switch dst.Kind() { //nolint:exhaustive - case reflect.Bool: - if wiretype != protowire.VarintType { - return fmt.Errorf("bad wiretype for bool: %v", wiretype) + case reflect.Uint, reflect.Uint32, reflect.Uint64, + reflect.Uint8, // This is a special case for uint8 kind, []uint8 values will be decoded as protobuf 'bytes' + reflect.Uint16: + if dst.Kind() == reflect.Uint && dst.Type().Size() < 8 { + return errors.New("detected a 32bit machine, please use either uint64 or uint32") } - if decodedValue > 1 { - return errors.New("invalid bool value") + val, err := value.Uint() + if err != nil { + return err } - dst.SetBool(protowire.DecodeBool(decodedValue)) + dst.SetUint(val) - case reflect.Int, reflect.Int32, reflect.Int64: - // Signed integers may be encoded either zigzag-varint or fixed - // Note that protobufs don't support 8- or 16-bit ints. - if dst.Kind() == reflect.Int && dst.Type().Size() < 8 { - return errors.New("detected a 32bit machine, please use either int64 or int32") - } + return nil - sv, err := decodeSignedInt(wiretype, decodedValue) + case reflect.Float32: + val, err := value.Float32() if err != nil { - fmt.Println("Error Reflect.Int for decodedValue=", decodedValue, "wiretype=", wiretype, "for Value=", dst.Type().Name()) - return err } - dst.SetInt(sv) + dst.SetFloat(float64(val)) - case reflect.Uint, reflect.Uint32, reflect.Uint64: - // Varint-encoded 32-bit and 64-bit unsigned integers. - if dst.Kind() == reflect.Uint && dst.Type().Size() < 8 { - return errors.New("detected a 32bit machine, please use either uint64 or uint32") - } + return nil - switch wiretype { //nolint:exhaustive - case protowire.VarintType: - dst.SetUint(decodedValue) - case protowire.Fixed32Type: - dst.SetUint(uint64(uint32(decodedValue))) - case protowire.Fixed64Type: - dst.SetUint(decodedValue) - default: - return errors.New("bad wiretype for uint") + case reflect.Float64: + val, err := value.Float64() + if err != nil { + return err } - case reflect.Float32: - if wiretype != protowire.Fixed32Type { - return errors.New("bad wiretype for float32") + dst.SetFloat(val) + + return nil + + default: + return fmt.Errorf("unsupported primitive kind " + dst.Kind().String()) + } +} + +// Instantiate an arbitrary type, handling dynamic interface types. +// Returns a Ptr value. +func instantiate(dst reflect.Value) error { + dstType := dst.Type() + + if dstType.Kind() == reflect.Interface { + return fmt.Errorf("cannot instantiate interface type %s", dstType.Name()) + } + + dst.Set(reflect.New(dstType.Elem())) + + return nil +} + +//nolint:cyclop +func unmarshalBytes(dst reflect.Value, value complexValue) (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("failed to unmarshal bytes: %w", err) } + }() + + // Value is not settable (invalid reflect.Value, private) + if !dst.CanSet() { + return nil + } - dst.SetFloat(float64(math.Float32frombits(uint32(decodedValue)))) + bytes, err := value.Bytes() + if err != nil { + return fmt.Errorf("bad wiretype for complex types: %w", err) + } - case reflect.Float64: - if wiretype != protowire.Fixed64Type { - return errors.New("bad wiretype for float64") + // Check predefined pb types + switch dst.Type() { + case typeTime: + var result timestamppb.Timestamp + + err = proto.Unmarshal(bytes, &result) + if err != nil { + return err } - dst.SetFloat(math.Float64frombits(decodedValue)) + dst.Set(reflect.ValueOf(result.AsTime())) - case reflect.String: - if wiretype != protowire.BytesType { - return errors.New("bad wiretype for string") + return nil + case typeDuration: + var result durationpb.Duration + + err = proto.Unmarshal(bytes, &result) + if err != nil { + return err } - dst.SetString(string(decodedBytes)) + dst.Set(reflect.ValueOf(result.AsDuration())) - case reflect.Ptr: + return nil + } + + switch dst.Kind() { //nolint:exhaustive + case reflect.Pointer: if dst.IsNil() { err := instantiate(dst) if err != nil { @@ -338,123 +401,107 @@ func (u *unmarshaller) putInto(dst reflect.Value, wiretype protowire.Type, decod } } - return u.putInto(dst.Elem(), wiretype, decodedValue, decodedBytes) + return unmarshalBytes(dst.Elem(), value) + + case reflect.String: + dst.SetString(string(bytes)) + + return nil case reflect.Struct: if enc, ok := dst.Addr().Interface().(encoding.BinaryUnmarshaler); ok { - return enc.UnmarshalBinary(decodedBytes) - } - - if wiretype != protowire.BytesType { - return errors.New("bad wiretype for embedded message") + return enc.UnmarshalBinary(bytes) } - return u.unmarshalStruct(decodedBytes, dst) + return unmarshalStruct(dst, bytes) case reflect.Slice, reflect.Array: - // Repeated field or byte-slice - if wiretype != protowire.BytesType { - return errors.New("bad wiretype for repeated field") - } + return slice(dst, value) - return u.slice(dst, decodedBytes) case reflect.Map: - if wiretype != protowire.BytesType { - return errors.New("bad wiretype for repeated field") - } - if dst.IsNil() { dst.Set(reflect.MakeMap(dst.Type())) } - return u.mapEntry(dst, decodedBytes) + return mapEntry(dst, bytes) + case reflect.Interface: // TODO: find a way to handle nil interfaces if dst.IsNil() { return errors.New("nil interface fields are not supported") } - // If the object support self-decoding, use that. if enc, ok := dst.Interface().(encoding.BinaryUnmarshaler); ok { - if wiretype != protowire.BytesType { - return errors.New("bad wiretype for bytes") - } - - return enc.UnmarshalBinary(decodedBytes) + return enc.UnmarshalBinary(bytes) } - // Decode into the object the interface points to. - return Unmarshal(decodedBytes, dst.Interface()) + return Unmarshal(bytes, dst.Interface()) default: - panic("unsupported value kind " + dst.Kind().String()) + return fmt.Errorf("unsupported value kind " + dst.Kind().String()) } - - return nil } -func tryDecodeFunc(vb []byte, dst reflect.Value) (bool, error) { - dec, ok := decoders.Get(dst.Type()) - if !ok { - return false, nil - } - - if err := dec(vb, dst); err != nil { - return false, err - } - - return true, nil -} - -func decodeSignedInt(wiretype protowire.Type, v uint64) (int64, error) { - switch wiretype { //nolint:exhaustive - case protowire.VarintType: - return int64(v), nil - case protowire.Fixed32Type: - return int64(int32(v)), nil - case protowire.Fixed64Type: - return int64(v), nil - default: - return -1, errors.New("bad wiretype for sint") +func unmarshalByteSeqeunce(dst reflect.Value, val complexValue) error { + unmarshalBytes, err := val.Bytes() + if err != nil { + return err } -} -// Instantiate an arbitrary type, handling dynamic interface types. -// Returns a Ptr value. -func instantiate(dst reflect.Value) error { - dstType := dst.Type() + if dst.Kind() == reflect.Array { + if dst.Len() != len(unmarshalBytes) { + return errors.New("array length and buffer length differ") + } - if dstType.Kind() == reflect.Interface { - return fmt.Errorf("cannot instantiate interface type %s", dstType.Name()) + for i := 0; i < dst.Len(); i++ { + // no SetByte method in reflect so has to pass down by uint64 + dst.Index(i).SetUint(uint64(unmarshalBytes[i])) + } + } else { + dst.SetBytes(unmarshalBytes) } - dst.Set(reflect.New(dstType.Elem())) - return nil } -func (u *unmarshaller) slice(dst reflect.Value, decodedBytes []byte) error { - // Find the element type, and create a temporary instance of it. +func slice(dst reflect.Value, val complexValue) error { elemType := dst.Type().Elem() - ok, err := tryDecodeUnpackedByteSlice(dst, elemType, decodedBytes) - if err != nil { - return err + if elemType.Kind() == reflect.Pointer { + if !isSlicePtrElemSupported(elemType) { + return fmt.Errorf("unsupported type: '%s'", dst.String()) + } } - if ok { + // we only decode bytes as []byte or [n]byte field + if elemType == typeByte { + err := unmarshalByteSeqeunce(dst, val) + if err != nil { + return err + } + return nil } - wiretype, err := getWiretypeFor(elemType) + bytes, err := val.Bytes() if err != nil { return err } - if wiretype < 0 { // Other unpacked repeated types - // Just unpack and append one value from decodedBytes. + ds, ok, err := getDataScannerFor(elemType, bytes) + if err != nil { + return err + } + + if !ok { // Other unpacked repeated types + // Just unpack and append one value from buf. + if dst.Kind() == reflect.Array { + return fmt.Errorf("arrays of complex types are not supported: '%s'", dst.String()) + } + elem := reflect.New(elemType).Elem() - if err = u.putInto(elem, protowire.BytesType, 0, decodedBytes); err != nil { + + if err = unmarshalBytes(elem, val); err != nil { return err } @@ -463,13 +510,13 @@ func (u *unmarshaller) slice(dst reflect.Value, decodedBytes []byte) error { return nil } - ok, err = tryDecodePredefinedSlice(wiretype, decodedBytes, dst) - if err != nil { - return err - } + ok, err = tryUnmarshalPredefinedSliceTypes(ds.Wiretype(), bytes, dst) - if ok { + switch { + case ok: return nil + case err != nil: + return err } sw := sequenceWrapper{ @@ -479,115 +526,21 @@ func (u *unmarshaller) slice(dst reflect.Value, decodedBytes []byte) error { defer sw.FixLen() // Decode packed values from the buffer and append them to the dst. - for len(decodedBytes) > 0 { + for ds.Scan() { nextElem := sw.NextElem() - rem, err := u.decodeValue(wiretype, decodedBytes, nextElem) - if err != nil { - return err - } - - decodedBytes = rem - } - - return nil -} - -func getWiretypeFor(elemType reflect.Type) (protowire.Type, error) { - switch elemType.Kind() { //nolint:exhaustive - case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Int, - reflect.Uint32, reflect.Uint64, reflect.Uint: - if (elemType.Kind() == reflect.Int || elemType.Kind() == reflect.Uint) && elemType.Size() < 8 { - return -1, errors.New("detected a 32bit machine, please either use (u)int64 or (u)int32") - } - - switch elemType { - case typeFixedS32: - return protowire.Fixed32Type, nil - case typeFixedS64: - return protowire.Fixed64Type, nil - case typeFixedU32: - return protowire.Fixed32Type, nil - case typeFixedU64: - return protowire.Fixed64Type, nil - case typeDuration: - return -1, nil - default: - return protowire.VarintType, nil - } - - case reflect.Float32: - return protowire.Fixed32Type, nil - - case reflect.Float64: - return protowire.Fixed64Type, nil - default: - return -1, nil - } -} - -func tryDecodeUnpackedByteSlice(dst reflect.Value, elemType reflect.Type, decodedBytes []byte) (bool, error) { - if elemType.Kind() != reflect.Uint8 { - return false, nil - } - - if dst.Kind() == reflect.Array { - if dst.Len() != len(decodedBytes) { - return false, errors.New("array length and buffer length differ") - } - - for i := 0; i < dst.Len(); i++ { - // no SetByte method in reflect so has to pass down by uint64 - dst.Index(i).SetUint(uint64(decodedBytes[i])) - } - } else { - dst.SetBytes(decodedBytes) - } - - return true, nil -} - -func (u *unmarshaller) mapEntry(dstEntry reflect.Value, decodedBytes []byte) error { - entryKey := reflect.New(dstEntry.Type().Key()).Elem() - entryVal := reflect.New(dstEntry.Type().Elem()).Elem() - - _, wiretype, n := protowire.ConsumeTag(decodedBytes) - if n <= 0 { - return errors.New("bad protobuf field key") - } - - buf := decodedBytes[n:] - - var err error - buf, err = u.decodeValue(wiretype, buf, entryKey) - - if err != nil { - return err - } - - for len(buf) > 0 { // for repeated values (slices etc) - _, wiretype, n := protowire.ConsumeTag(buf) - if n <= 0 { - return errors.New("bad protobuf field key") + value, ok := ds.PrimitiveValue() + if !ok { + return errors.New("incorrect value in packed slice") } - buf = buf[n:] - buf, err = u.decodeValue(wiretype, buf, entryVal) - + err := unmarshalPrimitive(nextElem, value) if err != nil { - return err + return fmt.Errorf("failed to unmarshal slice type '%s': %w", dst.Type(), err) } } - if !entryKey.IsValid() || !entryVal.IsValid() { - // We did not decode the key or the value in the map entry. - // Either way, it's an invalid map entry. - return errors.New("proto: bad map data: missing key/val") - } - - dstEntry.SetMapIndex(entryKey, entryVal) - - return nil + return ds.Err() } type sequenceWrapper struct { diff --git a/unmarshal_fastpath.go b/unmarshal_fastpath.go index 6d8c1c5..1791bbd 100644 --- a/unmarshal_fastpath.go +++ b/unmarshal_fastpath.go @@ -28,7 +28,7 @@ var predefiniedDecoders = map[reflect.Type]func(buf []byte, dst reflect.Value) ( typeOf[[]float64](): decodeFloat64, } -func tryDecodePredefinedSlice(wiretype protowire.Type, buf []byte, dst reflect.Value) (bool, error) { +func tryUnmarshalPredefinedSliceTypes(wiretype protowire.Type, buf []byte, dst reflect.Value) (bool, error) { switch wiretype { //nolint:exhaustive case protowire.VarintType, protowire.Fixed32Type, protowire.Fixed64Type: fn, ok := predefiniedDecoders[dst.Type()]