diff --git a/decode.go b/decode.go index c444451e..0db569ad 100644 --- a/decode.go +++ b/decode.go @@ -569,6 +569,10 @@ type decoder struct { dm *decMode } +// value decodes CBOR data item into the value pointed to by v. +// If CBOR data item is invalid, error is returned and offset isn't changed. +// If CBOR data item is valid but fails to be decode into v for other reasons, +// error is returned and offset is moved to the next CBOR data item. func (d *decoder) value(v interface{}) error { // v can't be nil, non-pointer, or nil pointer value. if v == nil { diff --git a/stream.go b/stream.go index 1d4ec706..dd5a377c 100644 --- a/stream.go +++ b/stream.go @@ -16,6 +16,7 @@ type Decoder struct { buf []byte off int // next read offset in buf bytesRead int + readError error } // NewDecoder returns a new decoder that reads and decodes from r using @@ -31,22 +32,30 @@ func (dec *Decoder) Decode(v interface{}) error { return err } } - - dec.d.reset(dec.buf[dec.off:]) - err := dec.d.value(v) - dec.off += dec.d.off - dec.bytesRead += dec.d.off - if err != nil { + for { + dec.d.reset(dec.buf[dec.off:]) + err := dec.d.value(v) + // Increment dec.off even if err is not nil because + // dec.d.off points to the next CBOR data item if current + // CBOR data item is valid but failed to be decoded into v. + // This allows next CBOR data item to be decoded in next + // call to this function. + dec.off += dec.d.off + dec.bytesRead += dec.d.off + if err == nil { + return nil + } if err != io.ErrUnexpectedEOF { return err } // Need to read more data. - if n, e := dec.read(); n == 0 { - return e + if n, err := dec.read(); n == 0 { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err } - return dec.Decode(v) } - return nil } // NumBytesRead returns the number of bytes read. @@ -55,6 +64,10 @@ func (dec *Decoder) NumBytesRead() int { } func (dec *Decoder) read() (int, error) { + if dec.readError != nil { + return 0, dec.readError + } + // Grow buf if needed. const minRead = 512 if cap(dec.buf)-len(dec.buf)+dec.off < minRead { @@ -71,6 +84,7 @@ func (dec *Decoder) read() (int, error) { // Read from reader and reslice buf. n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)]) dec.buf = dec.buf[0 : len(dec.buf)+n] + dec.readError = err return n, err } diff --git a/stream_test.go b/stream_test.go index 8d4282e7..7165039e 100644 --- a/stream_test.go +++ b/stream_test.go @@ -5,6 +5,7 @@ package cbor import ( "bytes" + "errors" "fmt" "io" "reflect" @@ -40,14 +41,16 @@ func TestDecoder(t *testing.T) { } } } - // no more data - var v interface{} - err := decoder.Decode(&v) - if v != nil { - t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) - } - if err != io.EOF { - t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + for i := 0; i < 2; i++ { + // no more data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.EOF { + t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + } } } @@ -94,14 +97,98 @@ func TestDecoderUnmarshalTypeError(t *testing.T) { } } } - // no more data - var v interface{} - err := decoder.Decode(&v) - if v != nil { - t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + for i := 0; i < 2; i++ { + // no more data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.EOF { + t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + } + } +} + +func TestDecoderUnexpectedEOFError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.Truncate(buf.Len() - 1) + + decoder := NewDecoder(&buf) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + var v interface{} + if err := decoder.Decode(&v); err != nil { + t.Fatalf("Decode() returned error %v", err) + } + if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { + if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // truncated data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("Decode() returned error %v, want io.UnexpectedEOF (truncated data)", err) + } + } +} + +func TestDecoderReadError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.Truncate(buf.Len() - 1) + + readerErr := errors.New("reader error") + + decoder := NewDecoder(NewErrorReader(buf.Bytes(), readerErr)) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + var v interface{} + if err := decoder.Decode(&v); err != nil { + t.Fatalf("Decode() returned error %v", err) + } + if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { + if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } } - if err != io.EOF { - t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + for i := 0; i < 2; i++ { + // truncated data because Reader returned error + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != readerErr { + t.Errorf("Decode() returned error %v, want reader error", err) + } } } @@ -438,3 +525,25 @@ func TestNilRawMessageUnmarshalCBORError(t *testing.T) { t.Errorf("UnmarshalCBOR() returned error %q, want %q", err.Error(), wantErrorMsg) } } + +type ErrorReader struct { + data []byte + off int + err error +} + +func NewErrorReader(data []byte, err error) *ErrorReader { + return &ErrorReader{data: data, err: err} +} + +func (r *ErrorReader) Read(b []byte) (int, error) { + var n int + if r.off < len(r.data) { + n = copy(b, r.data[r.off:]) + r.off += n + } + if n < len(b) { + return n, r.err + } + return n, nil +}