Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Decoder.Decode() return io.UnexpectedEOF instead of io.EOF if CBOR data item is truncated #379

Merged
merged 1 commit into from
Dec 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,10 @@ type decoder struct {
dm *decMode
}

// value decodes CBOR data item into the value pointed to by v.
// If CBOR data item is invalid, error is returned and offset isn't changed.
// If CBOR data item is valid but fails to be decode into v for other reasons,
// error is returned and offset is moved to the next CBOR data item.
func (d *decoder) value(v interface{}) error {
// v can't be nil, non-pointer, or nil pointer value.
if v == nil {
Expand Down
34 changes: 24 additions & 10 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Decoder struct {
buf []byte
off int // next read offset in buf
bytesRead int
readError error
}

// NewDecoder returns a new decoder that reads and decodes from r using
Expand All @@ -31,22 +32,30 @@ func (dec *Decoder) Decode(v interface{}) error {
return err
}
}

dec.d.reset(dec.buf[dec.off:])
err := dec.d.value(v)
dec.off += dec.d.off
dec.bytesRead += dec.d.off
if err != nil {
for {
dec.d.reset(dec.buf[dec.off:])
err := dec.d.value(v)
// Increment dec.off even if err is not nil because
// dec.d.off points to the next CBOR data item if current
// CBOR data item is valid but failed to be decoded into v.
// This allows next CBOR data item to be decoded in next
// call to this function.
dec.off += dec.d.off
dec.bytesRead += dec.d.off
if err == nil {
return nil
}
if err != io.ErrUnexpectedEOF {
return err
}
// Need to read more data.
if n, e := dec.read(); n == 0 {
return e
if n, err := dec.read(); n == 0 {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
return dec.Decode(v)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea to replace recursive call with for-loop. 👍

}
return nil
}

// NumBytesRead returns the number of bytes read.
Expand All @@ -55,6 +64,10 @@ func (dec *Decoder) NumBytesRead() int {
}

func (dec *Decoder) read() (int, error) {
if dec.readError != nil {
return 0, dec.readError
}

// Grow buf if needed.
const minRead = 512
if cap(dec.buf)-len(dec.buf)+dec.off < minRead {
Expand All @@ -71,6 +84,7 @@ func (dec *Decoder) read() (int, error) {
// Read from reader and reslice buf.
n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)])
dec.buf = dec.buf[0 : len(dec.buf)+n]
dec.readError = err
return n, err
}

Expand Down
139 changes: 124 additions & 15 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cbor

import (
"bytes"
"errors"
"fmt"
"io"
"reflect"
Expand Down Expand Up @@ -40,14 +41,16 @@ func TestDecoder(t *testing.T) {
}
}
}
// no more data
var v interface{}
err := decoder.Decode(&v)
if v != nil {
t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v)
}
if err != io.EOF {
t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err)
for i := 0; i < 2; i++ {
// no more data
var v interface{}
err := decoder.Decode(&v)
if v != nil {
t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v)
}
if err != io.EOF {
t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err)
}
}
}

Expand Down Expand Up @@ -94,14 +97,98 @@ func TestDecoderUnmarshalTypeError(t *testing.T) {
}
}
}
// no more data
var v interface{}
err := decoder.Decode(&v)
if v != nil {
t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v)
for i := 0; i < 2; i++ {
// no more data
var v interface{}
err := decoder.Decode(&v)
if v != nil {
t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v)
}
if err != io.EOF {
t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err)
}
}
}

func TestDecoderUnexpectedEOFError(t *testing.T) {
var buf bytes.Buffer
for _, tc := range unmarshalTests {
buf.Write(tc.cborData)
}
buf.Truncate(buf.Len() - 1)

decoder := NewDecoder(&buf)
bytesRead := 0
for i := 0; i < len(unmarshalTests)-1; i++ {
tc := unmarshalTests[i]
var v interface{}
if err := decoder.Decode(&v); err != nil {
t.Fatalf("Decode() returned error %v", err)
}
if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
bytesRead += len(tc.cborData)
if decoder.NumBytesRead() != bytesRead {
t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead)
}
}
for i := 0; i < 2; i++ {
// truncated data
var v interface{}
err := decoder.Decode(&v)
if v != nil {
t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v)
}
if err != io.ErrUnexpectedEOF {
t.Errorf("Decode() returned error %v, want io.UnexpectedEOF (truncated data)", err)
}
}
}

func TestDecoderReadError(t *testing.T) {
var buf bytes.Buffer
for _, tc := range unmarshalTests {
buf.Write(tc.cborData)
}
buf.Truncate(buf.Len() - 1)

readerErr := errors.New("reader error")

decoder := NewDecoder(NewErrorReader(buf.Bytes(), readerErr))
bytesRead := 0
for i := 0; i < len(unmarshalTests)-1; i++ {
tc := unmarshalTests[i]
var v interface{}
if err := decoder.Decode(&v); err != nil {
t.Fatalf("Decode() returned error %v", err)
}
if tm, ok := tc.emptyInterfaceValue.(time.Time); ok {
if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) {
t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
} else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) {
t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue)
}
bytesRead += len(tc.cborData)
if decoder.NumBytesRead() != bytesRead {
t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead)
}
}
if err != io.EOF {
t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err)
for i := 0; i < 2; i++ {
// truncated data because Reader returned error
var v interface{}
err := decoder.Decode(&v)
if v != nil {
t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v)
}
if err != readerErr {
t.Errorf("Decode() returned error %v, want reader error", err)
}
}
}

Expand Down Expand Up @@ -438,3 +525,25 @@ func TestNilRawMessageUnmarshalCBORError(t *testing.T) {
t.Errorf("UnmarshalCBOR() returned error %q, want %q", err.Error(), wantErrorMsg)
}
}

type ErrorReader struct {
data []byte
off int
err error
}

func NewErrorReader(data []byte, err error) *ErrorReader {
return &ErrorReader{data: data, err: err}
}

func (r *ErrorReader) Read(b []byte) (int, error) {
var n int
if r.off < len(r.data) {
n = copy(b, r.data[r.off:])
r.off += n
}
if n < len(b) {
return n, r.err
}
return n, nil
}