From d3a8ced6403c74abe3f7d971ee14cc6426537fd2 Mon Sep 17 00:00:00 2001
From: Michael Andersen
Date: Fri, 28 Apr 2023 10:16:53 -0700
Subject: [PATCH] Add UnmarshalFirst

Signed-off-by: Michael Andersen
---
 bench_test.go  | 48 ++++++++++++++++++++++++++++++++++++++++++
 decode.go      | 44 ++++++++++++++++++++++++++++++++++++++
 decode_test.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 149 insertions(+)

diff --git a/bench_test.go b/bench_test.go
index 761bff91..5445db23 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -211,6 +211,54 @@ func BenchmarkUnmarshal(b *testing.B) {
 	}
 }
 
+func BenchmarkUnmarshalFirst(b *testing.B) {
+	// Random trailing data
+	trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
+	for _, bm := range decodeBenchmarks {
+		for _, t := range bm.decodeToTypes {
+			name := "CBOR " + bm.name + " to Go " + t.String()
+			if t.Kind() == reflect.Struct {
+				name = "CBOR " + bm.name + " to Go " + t.Kind().String()
+			}
+			data := make([]byte, 0, len(bm.cborData)+len(trailingData))
+			data = append(data, bm.cborData...)
+			data = append(data, trailingData...)
+			b.Run(name, func(b *testing.B) {
+				for i := 0; i < b.N; i++ {
+					vPtr := reflect.New(t).Interface()
+					if _, err := UnmarshalFirst(data, vPtr); err != nil {
+						b.Fatal("UnmarshalFirst:", err)
+					}
+				}
+			})
+		}
+	}
+}
+
+func BenchmarkUnmarshalFirstViaDecoder(b *testing.B) {
+	// Random trailing data
+	trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
+	for _, bm := range decodeBenchmarks {
+		for _, t := range bm.decodeToTypes {
+			name := "CBOR " + bm.name + " to Go " + t.String()
+			if t.Kind() == reflect.Struct {
+				name = "CBOR " + bm.name + " to Go " + t.Kind().String()
+			}
+			data := make([]byte, 0, len(bm.cborData)+len(trailingData))
+			data = append(data, bm.cborData...)
+			data = append(data, trailingData...)
+			b.Run(name, func(b *testing.B) {
+				for i := 0; i < b.N; i++ {
+					vPtr := reflect.New(t).Interface()
+					if err := NewDecoder(bytes.NewReader(data)).Decode(vPtr); err != nil {
+						b.Fatal("UnmarshalDecoder:", err)
+					}
+				}
+			})
+		}
+	}
+}
+
 func BenchmarkDecode(b *testing.B) {
 	for _, bm := range decodeBenchmarks {
 		for _, t := range bm.decodeToTypes {
diff --git a/decode.go b/decode.go
index 5bc9e86a..dbfc0692 100644
--- a/decode.go
+++ b/decode.go
@@ -95,10 +95,25 @@ import (
 //
 // Unmarshal supports CBOR tag 55799 (self-describe CBOR), tag 0 and 1 (time),
 // and tag 2 and 3 (bignum).
+//
+// Unmarshal returns an ExtraneousDataError (without decoding into v)
+// if there are any remaining bytes following the first valid CBOR data item.
+// Use UnmarshalFirst to unmarshal only the first CBOR data item
+// and receive the remaining bytes instead of an ExtraneousDataError.
 func Unmarshal(data []byte, v interface{}) error {
 	return defaultDecMode.Unmarshal(data, v)
 }
 
+// UnmarshalFirst parses the first CBOR data item into the value pointed to by v
+// using default decoding options. Any remaining bytes are returned in rest.
+//
+// If v is nil, not a pointer, or a nil pointer, UnmarshalFirst returns an error.
+//
+// See the documentation for Unmarshal for details.
+func UnmarshalFirst(data []byte, v interface{}) (rest []byte, err error) {
+	return defaultDecMode.UnmarshalFirst(data, v)
+}
+
 // Valid checks whether data is a well-formed encoded CBOR data item and
 // that it complies with default restrictions such as MaxNestedLevels,
 // MaxArrayElements, MaxMapPairs, etc.
@@ -604,6 +619,35 @@ func (dm *decMode) Unmarshal(data []byte, v interface{}) error {
 	return d.value(v)
 }
 
+// UnmarshalFirst parses the first CBOR data item into the value pointed to by v
+// using dm decoding mode. Any remaining bytes are returned in rest, which is a
+// subslice of data (not a copy). On any error, rest is nil.
+//
+// If v is nil, not a pointer, or a nil pointer, UnmarshalFirst returns an error.
+// See the documentation for Unmarshal for details.
+func (dm *decMode) UnmarshalFirst(data []byte, v interface{}) (rest []byte, err error) {
+	d := decoder{data: data, dm: dm}
+
+	// check well-formedness.
+	off := d.off             // Save offset before data validation
+	err = d.wellformed(true) // allow extra data after well-formed data item
+	d.off = off              // Restore offset
+
+	// If it is well-formed, parse the value. This is structured like this to allow
+	// better test coverage
+	if err == nil {
+		err = d.value(v)
+	}
+
+	// If either wellformed or value returned an error, do not return rest bytes
+	if err != nil {
+		return nil, err
+	}
+
+	// Return the rest of the data slice (which might be len 0)
+	return d.data[d.off:], nil
+}
+
 // Valid checks whether data is a well-formed encoded CBOR data item and
 // that it complies with configurable restrictions such as MaxNestedLevels,
 // MaxArrayElements, MaxMapPairs, etc.
diff --git a/decode_test.go b/decode_test.go
index 0e585d3e..cc0e9837 100644
--- a/decode_test.go
+++ b/decode_test.go
@@ -5981,3 +5981,60 @@ func TestUnmarshalToDefaultMapType(t *testing.T) {
 		})
 	}
 }
+
+func TestUnmarshalFirstNoTrailing(t *testing.T) {
+	for _, tc := range unmarshalTests {
+		var v interface{}
+		if rest, err := UnmarshalFirst(tc.cborData, &v); err != nil {
+			t.Errorf("UnmarshalFirst(0x%x) returned error %v", tc.cborData, err)
+		} else {
+			if len(rest) != 0 {
+				t.Errorf("UnmarshalFirst(0x%x) returned rest %x (want [])", tc.cborData, rest)
+			}
+			// Check the value as well, although this is covered by other tests
+			if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
+				if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
+					t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
+				}
+			} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
+				t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", tc.cborData, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
+			}
+		}
+	}
+}
+
+func TestUnmarshalfirstTrailing(t *testing.T) {
+	// Random trailing data
+	trailingData := hexDecode("4a6b0f4718c73f391091ea1c")
+	for _, tc := range unmarshalTests {
+		data := make([]byte, 0, len(tc.cborData)+len(trailingData))
+		data = append(data, tc.cborData...)
+		data = append(data, trailingData...)
+		var v interface{}
+		if rest, err := UnmarshalFirst(data, &v); err != nil {
+			t.Errorf("UnmarshalFirst(0x%x) returned error %v", data, err)
+		} else {
+			if !bytes.Equal(trailingData, rest) {
+				t.Errorf("UnmarshalFirst(0x%x) returned rest %x (want %x)", data, rest, trailingData)
+			}
+			// Check the value as well, although this is covered by other tests
+			if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
+				if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
+					t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
+				}
+			} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
+				t.Errorf("UnmarshalFirst(0x%x) = %v (%T), want %v (%T)", data, v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
+			}
+		}
+	}
+}
+
+func TestUnmarshalFirstInvalidItem(t *testing.T) {
+	// UnmarshalFirst must return a nil rest AND a non-nil error for malformed input
+	invalidCBOR := hexDecode("83FF20030102")
+	var v interface{}
+	rest, err := UnmarshalFirst(invalidCBOR, &v)
+	if rest != nil || err == nil {
+		t.Errorf("UnmarshalFirst(0x%x) = (%x, %v), want (nil, err)", invalidCBOR, rest, err)
+	}
+}