diff --git a/decode.go b/decode.go index bc0e6895..343da9ea 100644 --- a/decode.go +++ b/decode.go @@ -40,17 +40,17 @@ import ( // To unmarshal CBOR into an empty interface value, Unmarshal uses the // following rules: // -// CBOR booleans decode to bool. -// CBOR positive integers decode to uint64. -// CBOR negative integers decode to int64 (big.Int if value overflows). -// CBOR floating points decode to float64. -// CBOR byte strings decode to []byte. -// CBOR text strings decode to string. -// CBOR arrays decode to []interface{}. -// CBOR maps decode to map[interface{}]interface{}. -// CBOR null and undefined values decode to nil. -// CBOR times (tag 0 and 1) decode to time.Time. -// CBOR bignums (tag 2 and 3) decode to big.Int. +// CBOR booleans decode to bool. +// CBOR positive integers decode to uint64. +// CBOR negative integers decode to int64 (big.Int if value overflows). +// CBOR floating points decode to float64. +// CBOR byte strings decode to []byte. +// CBOR text strings decode to string. +// CBOR arrays decode to []interface{}. +// CBOR maps decode to map[interface{}]interface{}. +// CBOR null and undefined values decode to nil. +// CBOR times (tag 0 and 1) decode to time.Time. +// CBOR bignums (tag 2 and 3) decode to big.Int. // // To unmarshal a CBOR array into a slice, Unmarshal allocates a new slice // if the CBOR array is empty or slice capacity is less than CBOR array length. @@ -75,9 +75,9 @@ import ( // To unmarshal a CBOR map into a struct, Unmarshal matches CBOR map keys to the // keys in the following priority: // -// 1. "cbor" key in struct field tag, -// 2. "json" key in struct field tag, -// 3. struct field name. +// 1. "cbor" key in struct field tag, +// 2. "json" key in struct field tag, +// 3. struct field name. // // Unmarshal tries an exact match for field name, then a case-insensitive match. // Map key-value pairs without corresponding struct fields are ignored. See @@ -549,7 +549,16 @@ func (dm *decMode) DecOptions() DecOptions { // See the documentation for Unmarshal for details. func (dm *decMode) Unmarshal(data []byte, v interface{}) error { d := decoder{data: data, dm: dm} - return d.value(v, false) + + // check valid + off := d.off // Save offset before data validation + err := d.valid(false) // don't allow any extra data after valid data item. + d.off = off // Restore offset + if err != nil { + return err + } + + return d.value(v) } // Valid checks whether the CBOR data is complete and well-formed. @@ -570,10 +579,10 @@ type decoder struct { } // value decodes CBOR data item into the value pointed to by v. -// If CBOR data item is invalid, error is returned and offset isn't changed. -// If CBOR data item is valid but fails to be decode into v for other reasons, +// If CBOR data item fails to be decoded into v, // error is returned and offset is moved to the next CBOR data item. -func (d *decoder) value(v interface{}, allowExtraData bool) error { +// Precondition: d.data contains at least one valid CBOR data item. +func (d *decoder) value(v interface{}) error { // v can't be nil, non-pointer, or nil pointer value. if v == nil { return &InvalidUnmarshalError{"cbor: Unmarshal(nil)"} @@ -584,14 +593,6 @@ func (d *decoder) value(v interface{}, allowExtraData bool) error { } else if rv.IsNil() { return &InvalidUnmarshalError{"cbor: Unmarshal(nil " + rv.Type().String() + ")"} } - - off := d.off // Save offset before data validation - err := d.valid(allowExtraData) - d.off = off // Restore offset - if err != nil { - return err - } - rv = rv.Elem() return d.parseToValue(rv, getTypeInfo(rv.Type())) } diff --git a/stream.go b/stream.go index 45a3422c..4e8c1e87 100644 --- a/stream.go +++ b/stream.go @@ -16,7 +16,6 @@ type Decoder struct { buf []byte off int // next read offset in buf bytesRead int - readError error } // NewDecoder returns a new decoder that reads and decodes from r using @@ -27,67 +26,38 @@ func NewDecoder(r io.Reader) *Decoder { // Decode reads CBOR value and decodes it into the value pointed to by v. func (dec *Decoder) Decode(v interface{}) error { - if len(dec.buf) == dec.off { - if n, err := dec.read(); n == 0 { - return err - } - } - for { - dec.d.reset(dec.buf[dec.off:]) - err := dec.d.value(v, true) - // Increment dec.off even if err is not nil because - // dec.d.off points to the next CBOR data item if current - // CBOR data item is valid but failed to be decoded into v. - // This allows next CBOR data item to be decoded in next - // call to this function. - dec.off += dec.d.off - dec.bytesRead += dec.d.off - if err == nil { - return nil - } - if err != io.ErrUnexpectedEOF { - return err - } - // Need to read more data. - if n, err := dec.read(); n == 0 { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return err - } + _, err := dec.readNext() + if err != nil { + // Return validation error or read error. + return err } + + dec.d.reset(dec.buf[dec.off:]) + err = dec.d.value(v) + + // Increment dec.off even if decoding err is not nil because + // dec.d.off points to the next CBOR data item if current + // CBOR data item is valid but failed to be decoded into v. + // This allows next CBOR data item to be decoded in next + // call to this function. + dec.off += dec.d.off + dec.bytesRead += dec.d.off + + return err } // Skip skips to the next CBOR data item (if there is any), // otherwise it returns error such as io.EOF, io.UnexpectedEOF, etc. func (dec *Decoder) Skip() error { - if len(dec.buf) == dec.off { - if n, err := dec.read(); n == 0 { - return err - } - } - for { - dec.d.reset(dec.buf[dec.off:]) - err := dec.d.valid(true) - if err == nil { - // Only increment dec.off if current CBOR data item is valid. - // If current data item is incomplete (io.ErrUnexpectedEOF), - // we want to try again after reading more data. - dec.off += dec.d.off - dec.bytesRead += dec.d.off - return nil - } - if err != io.ErrUnexpectedEOF { - return err - } - // Need to read more data. - if n, err := dec.read(); n == 0 { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return err - } + n, err := dec.readNext() + if err != nil { + // Return validation error or read error. + return err } + + dec.off += n + dec.bytesRead += n + return nil } // NumBytesRead returns the number of bytes read. @@ -95,11 +65,66 @@ func (dec *Decoder) NumBytesRead() int { return dec.bytesRead } -func (dec *Decoder) read() (int, error) { - if dec.readError != nil { - return 0, dec.readError +// readNext() reads next CBOR data item from Reader to buffer. +// It returns the size of next CBOR data item. +// It also returns validation error or read error if any. +func (dec *Decoder) readNext() (int, error) { + var readErr error + var validErr error + + for { + // Process any unread data in dec.buf. + if dec.off < len(dec.buf) { + dec.d.reset(dec.buf[dec.off:]) + off := dec.off // Save offset before data validation + validErr = dec.d.valid(true) + dec.off = off // Restore offset + + if validErr == nil { + return dec.d.off, nil + } + + if validErr != io.ErrUnexpectedEOF { + return 0, validErr + } + + // Process last read error on io.ErrUnexpectedEOF. + if readErr != nil { + if readErr == io.EOF { + // current CBOR data item is incomplete. + return 0, io.ErrUnexpectedEOF + } + return 0, readErr + } + } + + // More data is needed and there was no read error. + var n int + for n == 0 { + n, readErr = dec.read() + if n == 0 && readErr != nil { + // No more data can be read and read error is encountered. + // At this point, validErr is either nil or io.ErrUnexpectedEOF. + if readErr == io.EOF { + if validErr == io.ErrUnexpectedEOF { + // current CBOR data item is incomplete. + return 0, io.ErrUnexpectedEOF + } + } + return 0, readErr + } + } + + // At this point, dec.buf contains new data from last read (n > 0). } +} +// read() reads data from Reader to buffer. +// It returns number of bytes read and any read error encountered. +// Postconditions: +// - dec.buf contains previously unread data and new data. +// - dec.off is 0. +func (dec *Decoder) read() (int, error) { // Grow buf if needed. const minRead = 512 if cap(dec.buf)-len(dec.buf)+dec.off < minRead { @@ -116,7 +141,6 @@ func (dec *Decoder) read() (int, error) { // Read from reader and reslice buf. n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)]) dec.buf = dec.buf[0 : len(dec.buf)+n] - dec.readError = err return n, err } diff --git a/stream_test.go b/stream_test.go index 211d79d8..bbd74ac5 100644 --- a/stream_test.go +++ b/stream_test.go @@ -21,37 +21,51 @@ func TestDecoder(t *testing.T) { buf.Write(tc.cborData) } } - decoder := NewDecoder(&buf) - bytesRead := 0 - for i := 0; i < 5; i++ { - for _, tc := range unmarshalTests { - var v interface{} - if err := decoder.Decode(&v); err != nil { - t.Fatalf("Decode() returned error %v", err) - } - if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { - if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { - t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + + testCases := []struct { + name string + reader io.Reader + }{ + {"bytes.Buffer", &buf}, + {"1 byte reader", newNBytesReader(buf.Bytes(), 1)}, + {"toggled reader", newToggledReader(buf.Bytes(), 1)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + bytesRead := 0 + for i := 0; i < 5; i++ { + for _, tc := range unmarshalTests { + var v interface{} + if err := decoder.Decode(&v); err != nil { + t.Fatalf("Decode() returned error %v", err) + } + if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { + if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } } - } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { - t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + for i := 0; i < 2; i++ { + // no more data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.EOF { + t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + } } - } - } - for i := 0; i < 2; i++ { - // no more data - var v interface{} - err := decoder.Decode(&v) - if v != nil { - t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) - } - if err != io.EOF { - t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) - } + }) } } @@ -64,50 +78,64 @@ func TestDecoderUnmarshalTypeError(t *testing.T) { } } } - decoder := NewDecoder(&buf) - bytesRead := 0 - for i := 0; i < 5; i++ { - for _, tc := range unmarshalTests { - for _, typ := range tc.wrongTypes { - v := reflect.New(typ) - if err := decoder.Decode(v.Interface()); err == nil { - t.Errorf("Decode(0x%x) didn't return an error, want UnmarshalTypeError", tc.cborData) - } else if _, ok := err.(*UnmarshalTypeError); !ok { - t.Errorf("Decode(0x%x) returned wrong error type %T, want UnmarshalTypeError", tc.cborData, err) - } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) - } - var vi interface{} - if err := decoder.Decode(&vi); err != nil { - t.Errorf("Decode() returned error %v", err) - } - if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { - if vt, ok := vi.(time.Time); !ok || !tm.Equal(vt) { - t.Errorf("Decode() = %v (%T), want %v (%T)", vi, vi, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + testCases := []struct { + name string + reader io.Reader + }{ + {"bytes.Buffer", &buf}, + {"1 byte reader", newNBytesReader(buf.Bytes(), 1)}, + {"toggled reader", newToggledReader(buf.Bytes(), 1)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + bytesRead := 0 + for i := 0; i < 5; i++ { + for _, tc := range unmarshalTests { + for _, typ := range tc.wrongTypes { + v := reflect.New(typ) + if err := decoder.Decode(v.Interface()); err == nil { + t.Errorf("Decode(0x%x) didn't return an error, want UnmarshalTypeError", tc.cborData) + } else if _, ok := err.(*UnmarshalTypeError); !ok { + t.Errorf("Decode(0x%x) returned wrong error type %T, want UnmarshalTypeError", tc.cborData, err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + + var vi interface{} + if err := decoder.Decode(&vi); err != nil { + t.Errorf("Decode() returned error %v", err) + } + if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { + if vt, ok := vi.(time.Time); !ok || !tm.Equal(vt) { + t.Errorf("Decode() = %v (%T), want %v (%T)", vi, vi, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + } else if !reflect.DeepEqual(vi, tc.emptyInterfaceValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", vi, vi, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } } - } else if !reflect.DeepEqual(vi, tc.emptyInterfaceValue) { - t.Errorf("Decode() = %v (%T), want %v (%T)", vi, vi, tc.emptyInterfaceValue, tc.emptyInterfaceValue) } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + for i := 0; i < 2; i++ { + // no more data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.EOF { + t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) } } - } - } - for i := 0; i < 2; i++ { - // no more data - var v interface{} - err := decoder.Decode(&v) - if v != nil { - t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) - } - if err != io.EOF { - t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) - } + }) } } @@ -118,36 +146,50 @@ func TestDecoderUnexpectedEOFError(t *testing.T) { } buf.Truncate(buf.Len() - 1) - decoder := NewDecoder(&buf) - bytesRead := 0 - for i := 0; i < len(unmarshalTests)-1; i++ { - tc := unmarshalTests[i] - var v interface{} - if err := decoder.Decode(&v); err != nil { - t.Fatalf("Decode() returned error %v", err) - } - if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { - if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { - t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) - } - } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { - t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) - } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) - } + testCases := []struct { + name string + reader io.Reader + }{ + {"bytes.Buffer", &buf}, + {"1 byte reader", newNBytesReader(buf.Bytes(), 1)}, + {"toggled reader", newToggledReader(buf.Bytes(), 1)}, } - for i := 0; i < 2; i++ { - // truncated data - var v interface{} - err := decoder.Decode(&v) - if v != nil { - t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) - } - if err != io.ErrUnexpectedEOF { - t.Errorf("Decode() returned error %v, want io.UnexpectedEOF (truncated data)", err) - } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + decoder := NewDecoder(tc.reader) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + var v interface{} + if err := decoder.Decode(&v); err != nil { + t.Fatalf("Decode() returned error %v", err) + } + if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { + if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // truncated data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("Decode() returned error %v, want io.UnexpectedEOF (truncated data)", err) + } + } + }) } } @@ -160,36 +202,114 @@ func TestDecoderReadError(t *testing.T) { readerErr := errors.New("reader error") - decoder := NewDecoder(NewErrorReader(buf.Bytes(), readerErr)) - bytesRead := 0 - for i := 0; i < len(unmarshalTests)-1; i++ { - tc := unmarshalTests[i] - var v interface{} - if err := decoder.Decode(&v); err != nil { - t.Fatalf("Decode() returned error %v", err) - } - if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { - if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { - t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + testCases := []struct { + name string + reader io.Reader + }{ + {"byte reader", newNBytesReaderWithError(buf.Bytes(), 512, readerErr)}, + {"1 byte reader", newNBytesReaderWithError(buf.Bytes(), 1, readerErr)}, + {"toggled reader", newToggledReaderWithError(buf.Bytes(), 1, readerErr)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + var v interface{} + if err := decoder.Decode(&v); err != nil { + t.Fatalf("Decode() returned error %v", err) + } + if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { + if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } } - } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { - t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) - } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) - } + for i := 0; i < 2; i++ { + // truncated data because Reader returned error + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != readerErr { + t.Errorf("Decode() returned error %v, want reader error", err) + } + } + }) } - for i := 0; i < 2; i++ { - // truncated data because Reader returned error - var v interface{} - err := decoder.Decode(&v) - if v != nil { - t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) - } - if err != readerErr { - t.Errorf("Decode() returned error %v, want reader error", err) - } +} + +func TestDecoderNoData(t *testing.T) { + readerErr := errors.New("reader error") + + testCases := []struct { + name string + reader io.Reader + wantErr error + }{ + {"byte.Buffer", new(bytes.Buffer), io.EOF}, + {"1 byte reader", newNBytesReaderWithError(nil, 0, readerErr), readerErr}, + {"toggled reader", newToggledReaderWithError(nil, 0, readerErr), readerErr}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + for i := 0; i < 2; i++ { + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil", v, v) + } + if err != tc.wantErr { + t.Errorf("Decode() returned error %v, want error %v", err, tc.wantErr) + } + } + }) + } +} + +func TestDecoderRecoverableReadError(t *testing.T) { + cborData := hexDecode("83010203") // [1,2,3] + wantValue := []interface{}{uint64(1), uint64(2), uint64(3)} + recoverableReaderErr := errors.New("recoverable reader error") + + decoder := NewDecoder(newRecoverableReader(cborData, 1, recoverableReaderErr)) + + var v interface{} + err := decoder.Decode(&v) + if err != recoverableReaderErr { + t.Fatalf("Decode() returned error %v, want error %v", err, recoverableReaderErr) + } + + err = decoder.Decode(&v) + if err != nil { + t.Fatalf("Decode() returned error %v", err) + } + if !reflect.DeepEqual(v, wantValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, wantValue, wantValue) + } + if decoder.NumBytesRead() != len(cborData) { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), len(cborData)) + } + + // no more data + v = interface{}(nil) + err = decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.EOF { + t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) } } @@ -219,25 +339,39 @@ func TestDecoderSkip(t *testing.T) { buf.Write(tc.cborData) } } - decoder := NewDecoder(&buf) - bytesRead := 0 - for i := 0; i < 5; i++ { - for _, tc := range unmarshalTests { - if err := decoder.Skip(); err != nil { - t.Fatalf("Skip() returned error %v", err) + + testCases := []struct { + name string + reader io.Reader + }{ + {"bytes.Buffer", &buf}, + {"1 byte reader", newNBytesReader(buf.Bytes(), 1)}, + {"toggled reader", newToggledReader(buf.Bytes(), 1)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + bytesRead := 0 + for i := 0; i < 5; i++ { + for _, tc := range unmarshalTests { + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + for i := 0; i < 2; i++ { + // no more data + err := decoder.Skip() + if err != io.EOF { + t.Errorf("Skip() returned error %v, want io.EOF (no more data)", err) + } } - } - } - for i := 0; i < 2; i++ { - // no more data - err := decoder.Skip() - if err != io.EOF { - t.Errorf("Skip() returned error %v, want io.EOF (no more data)", err) - } + }) } } @@ -248,26 +382,39 @@ func TestDecoderSkipInvalidDataError(t *testing.T) { } buf.WriteByte(0x3e) - decoder := NewDecoder(&buf) - bytesRead := 0 - for i := 0; i < len(unmarshalTests); i++ { - tc := unmarshalTests[i] - if err := decoder.Skip(); err != nil { - t.Fatalf("Skip() returned error %v", err) - } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) - } + testCases := []struct { + name string + reader io.Reader + }{ + {"bytes.Buffer", &buf}, + {"1 byte reader", newNBytesReader(buf.Bytes(), 1)}, + {"toggled reader", newToggledReader(buf.Bytes(), 1)}, } - for i := 0; i < 2; i++ { - // last data item is invalid - err := decoder.Skip() - if err == nil { - t.Fatalf("Skip() didn't return error") - } else if !strings.Contains(err.Error(), "cbor: invalid additional information") { - t.Errorf("Skip() error %q, want \"cbor: invalid additional information\"", err) - } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + bytesRead := 0 + for i := 0; i < len(unmarshalTests); i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // last data item is invalid + err := decoder.Skip() + if err == nil { + t.Fatalf("Skip() didn't return error") + } else if !strings.Contains(err.Error(), "cbor: invalid additional information") { + t.Errorf("Skip() error %q, want \"cbor: invalid additional information\"", err) + } + } + }) } } @@ -278,24 +425,37 @@ func TestDecoderSkipUnexpectedEOFError(t *testing.T) { } buf.Truncate(buf.Len() - 1) - decoder := NewDecoder(&buf) - bytesRead := 0 - for i := 0; i < len(unmarshalTests)-1; i++ { - tc := unmarshalTests[i] - if err := decoder.Skip(); err != nil { - t.Fatalf("Skip() returned error %v", err) - } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) - } + testCases := []struct { + name string + reader io.Reader + }{ + {"bytes.Buffer", &buf}, + {"1 byte reader", newNBytesReader(buf.Bytes(), 1)}, + {"toggled reader", newToggledReader(buf.Bytes(), 1)}, } - for i := 0; i < 2; i++ { - // last data item is invalid - err := decoder.Skip() - if err != io.ErrUnexpectedEOF { - t.Errorf("Skip() returned error %v, want io.ErrUnexpectedEOF (truncated data)", err) - } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // last data item is invalid + err := decoder.Skip() + if err != io.ErrUnexpectedEOF { + t.Errorf("Skip() returned error %v, want io.ErrUnexpectedEOF (truncated data)", err) + } + } + }) } } @@ -308,24 +468,89 @@ func TestDecoderSkipReadError(t *testing.T) { readerErr := errors.New("reader error") - decoder := NewDecoder(NewErrorReader(buf.Bytes(), readerErr)) - bytesRead := 0 - for i := 0; i < len(unmarshalTests)-1; i++ { - tc := unmarshalTests[i] - if err := decoder.Skip(); err != nil { - t.Fatalf("Skip() returned error %v", err) - } - bytesRead += len(tc.cborData) - if decoder.NumBytesRead() != bytesRead { - t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) - } + testCases := []struct { + name string + reader io.Reader + }{ + {"byte reader", newNBytesReaderWithError(buf.Bytes(), 512, readerErr)}, + {"1 byte reader", newNBytesReaderWithError(buf.Bytes(), 1, readerErr)}, + {"toggled reader", newToggledReaderWithError(buf.Bytes(), 1, readerErr)}, } - for i := 0; i < 2; i++ { - // truncated data because Reader returned error - err := decoder.Skip() - if err != readerErr { - t.Errorf("Skip() returned error %v, want reader error", err) - } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // truncated data because Reader returned error + err := decoder.Skip() + if err != readerErr { + t.Errorf("Skip() returned error %v, want reader error", err) + } + } + }) + } +} + +func TestDecoderSkipNoData(t *testing.T) { + readerErr := errors.New("reader error") + + testCases := []struct { + name string + reader io.Reader + wantErr error + }{ + {"byte.Buffer", new(bytes.Buffer), io.EOF}, + {"1 byte reader", newNBytesReaderWithError(nil, 0, readerErr), readerErr}, + {"toggled reader", newToggledReaderWithError(nil, 0, readerErr), readerErr}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decoder := NewDecoder(tc.reader) + for i := 0; i < 2; i++ { + err := decoder.Skip() + if err != tc.wantErr { + t.Errorf("Decode() returned error %v, want error %v", err, tc.wantErr) + } + } + }) + } +} + +func TestDecoderSkipRecoverableReadError(t *testing.T) { + cborData := hexDecode("83010203") // [1,2,3] + recoverableReaderErr := errors.New("recoverable reader error") + + decoder := NewDecoder(newRecoverableReader(cborData, 1, recoverableReaderErr)) + + err := decoder.Skip() + if err != recoverableReaderErr { + t.Fatalf("Skip() returned error %v, want error %v", err, recoverableReaderErr) + } + + err = decoder.Skip() + if err != nil { + t.Fatalf("Skip() returned error %v", err) + } + if decoder.NumBytesRead() != len(cborData) { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), len(cborData)) + } + + // no more data + err = decoder.Skip() + if err != io.EOF { + t.Errorf("Skip() returned error %v, want io.EOF (no more data)", err) } } @@ -663,24 +888,107 @@ func TestNilRawMessageUnmarshalCBORError(t *testing.T) { } } -type ErrorReader struct { - data []byte - off int - err error +// nBytesReader reads at most maxBytesPerRead into b. It also returns error at the last read. +type nBytesReader struct { + data []byte + maxBytesPerRead int + off int + err error } -func NewErrorReader(data []byte, err error) *ErrorReader { - return &ErrorReader{data: data, err: err} +func newNBytesReader(data []byte, maxBytesPerRead int) *nBytesReader { + return &nBytesReader{ + data: append([]byte{}, data...), + maxBytesPerRead: maxBytesPerRead, + err: io.EOF, + } +} + +func newNBytesReaderWithError(data []byte, maxBytesPerRead int, err error) *nBytesReader { + return &nBytesReader{ + data: append([]byte{}, data...), + maxBytesPerRead: maxBytesPerRead, + err: err, + } } -func (r *ErrorReader) Read(b []byte) (int, error) { +func (r *nBytesReader) Read(b []byte) (int, error) { var n int if r.off < len(r.data) { - n = copy(b, r.data[r.off:]) + numOfBytesToRead := len(r.data) - r.off + if numOfBytesToRead > r.maxBytesPerRead { + numOfBytesToRead = r.maxBytesPerRead + } + n = copy(b, r.data[r.off:r.off+numOfBytesToRead]) r.off += n } - if n < len(b) { + if r.off == len(r.data) { return n, r.err } return n, nil } + +// toggledReader returns (0, nil) for every other read to mimic non-blocking read for stream reader. +type toggledReader struct { + nBytesReader + toggle bool +} + +func newToggledReader(data []byte, maxBytesPerRead int) *toggledReader { + return &toggledReader{ + nBytesReader: nBytesReader{ + data: append([]byte{}, data...), + maxBytesPerRead: maxBytesPerRead, + err: io.EOF, + }, + toggle: true, // first read returns (0, nil) + } +} + +func newToggledReaderWithError(data []byte, maxBytesPerRead int, err error) *toggledReader { + return &toggledReader{ + nBytesReader: nBytesReader{ + data: append([]byte{}, data...), + maxBytesPerRead: maxBytesPerRead, + err: err, + }, + toggle: true, // first read returns (0, nil) + } +} + +func (r *toggledReader) Read(b []byte) (int, error) { + defer func() { + r.toggle = !r.toggle + }() + if r.toggle { + return 0, nil + } + return r.nBytesReader.Read(b) +} + +// recoverableReader returns a recoverable error at first read operation. +type recoverableReader struct { + nBytesReader + recoverableErr error + first bool +} + +func newRecoverableReader(data []byte, maxBytesPerRead int, err error) *recoverableReader { + return &recoverableReader{ + nBytesReader: nBytesReader{ + data: append([]byte{}, data...), + maxBytesPerRead: maxBytesPerRead, + err: io.EOF, + }, + recoverableErr: err, + first: true, + } +} + +func (r *recoverableReader) Read(b []byte) (int, error) { + if r.first { + r.first = false + return 0, r.recoverableErr + } + return r.nBytesReader.Read(b) +}