From 60589f05c1f88e8227dd9ba089f263b717f0b4ba Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Sat, 31 Dec 2022 15:18:11 -0600 Subject: [PATCH] Fix Decoder.Decode() error type Previously, if CBOR data item was truncated and the underlying reader reached EOF, Decoder.Decode() returned io.EOF. This bug fix makes it return io.UnexpectedEOF because the CBOR data item was truncated when EOF was encountered. --- decode.go | 4 ++ stream.go | 34 ++++++++---- stream_test.go | 139 +++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 152 insertions(+), 25 deletions(-) diff --git a/decode.go b/decode.go index c444451e..0db569ad 100644 --- a/decode.go +++ b/decode.go @@ -569,6 +569,10 @@ type decoder struct { dm *decMode } +// value decodes CBOR data item into the value pointed to by v. +// If CBOR data item is invalid, error is returned and offset isn't changed. +// If CBOR data item is valid but fails to be decode into v for other reasons, +// error is returned and offset is moved to the next CBOR data item. func (d *decoder) value(v interface{}) error { // v can't be nil, non-pointer, or nil pointer value. if v == nil { diff --git a/stream.go b/stream.go index 1d4ec706..dd5a377c 100644 --- a/stream.go +++ b/stream.go @@ -16,6 +16,7 @@ type Decoder struct { buf []byte off int // next read offset in buf bytesRead int + readError error } // NewDecoder returns a new decoder that reads and decodes from r using @@ -31,22 +32,30 @@ func (dec *Decoder) Decode(v interface{}) error { return err } } - - dec.d.reset(dec.buf[dec.off:]) - err := dec.d.value(v) - dec.off += dec.d.off - dec.bytesRead += dec.d.off - if err != nil { + for { + dec.d.reset(dec.buf[dec.off:]) + err := dec.d.value(v) + // Increment dec.off even if err is not nil because + // dec.d.off points to the next CBOR data item if current + // CBOR data item is valid but failed to be decoded into v. + // This allows next CBOR data item to be decoded in next + // call to this function. + dec.off += dec.d.off + dec.bytesRead += dec.d.off + if err == nil { + return nil + } if err != io.ErrUnexpectedEOF { return err } // Need to read more data. - if n, e := dec.read(); n == 0 { - return e + if n, err := dec.read(); n == 0 { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err } - return dec.Decode(v) } - return nil } // NumBytesRead returns the number of bytes read. @@ -55,6 +64,10 @@ func (dec *Decoder) NumBytesRead() int { } func (dec *Decoder) read() (int, error) { + if dec.readError != nil { + return 0, dec.readError + } + // Grow buf if needed. const minRead = 512 if cap(dec.buf)-len(dec.buf)+dec.off < minRead { @@ -71,6 +84,7 @@ func (dec *Decoder) read() (int, error) { // Read from reader and reslice buf. n, err := dec.r.Read(dec.buf[len(dec.buf):cap(dec.buf)]) dec.buf = dec.buf[0 : len(dec.buf)+n] + dec.readError = err return n, err } diff --git a/stream_test.go b/stream_test.go index 8d4282e7..7165039e 100644 --- a/stream_test.go +++ b/stream_test.go @@ -5,6 +5,7 @@ package cbor import ( "bytes" + "errors" "fmt" "io" "reflect" @@ -40,14 +41,16 @@ func TestDecoder(t *testing.T) { } } } - // no more data - var v interface{} - err := decoder.Decode(&v) - if v != nil { - t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) - } - if err != io.EOF { - t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + for i := 0; i < 2; i++ { + // no more data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.EOF { + t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + } } } @@ -94,14 +97,98 @@ func TestDecoderUnmarshalTypeError(t *testing.T) { } } } - // no more data - var v interface{} - err := decoder.Decode(&v) - if v != nil { - t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + for i := 0; i < 2; i++ { + // no more data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.EOF { + t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + } + } +} + +func TestDecoderUnexpectedEOFError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.Truncate(buf.Len() - 1) + + decoder := NewDecoder(&buf) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + var v interface{} + if err := decoder.Decode(&v); err != nil { + t.Fatalf("Decode() returned error %v", err) + } + if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { + if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // truncated data + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != io.ErrUnexpectedEOF { + t.Errorf("Decode() returned error %v, want io.UnexpectedEOF (truncated data)", err) + } + } +} + +func TestDecoderReadError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.Truncate(buf.Len() - 1) + + readerErr := errors.New("reader error") + + decoder := NewDecoder(NewErrorReader(buf.Bytes(), readerErr)) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + var v interface{} + if err := decoder.Decode(&v); err != nil { + t.Fatalf("Decode() returned error %v", err) + } + if tm, ok := tc.emptyInterfaceValue.(time.Time); ok { + if vt, ok := v.(time.Time); !ok || !tm.Equal(vt) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + } else if !reflect.DeepEqual(v, tc.emptyInterfaceValue) { + t.Errorf("Decode() = %v (%T), want %v (%T)", v, v, tc.emptyInterfaceValue, tc.emptyInterfaceValue) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } } - if err != io.EOF { - t.Errorf("Decode() returned error %v, want io.EOF (no more data)", err) + for i := 0; i < 2; i++ { + // truncated data because Reader returned error + var v interface{} + err := decoder.Decode(&v) + if v != nil { + t.Errorf("Decode() = %v (%T), want nil (no more data)", v, v) + } + if err != readerErr { + t.Errorf("Decode() returned error %v, want reader error", err) + } } } @@ -438,3 +525,25 @@ func TestNilRawMessageUnmarshalCBORError(t *testing.T) { t.Errorf("UnmarshalCBOR() returned error %q, want %q", err.Error(), wantErrorMsg) } } + +type ErrorReader struct { + data []byte + off int + err error +} + +func NewErrorReader(data []byte, err error) *ErrorReader { + return &ErrorReader{data: data, err: err} +} + +func (r *ErrorReader) Read(b []byte) (int, error) { + var n int + if r.off < len(r.data) { + n = copy(b, r.data[r.off:]) + r.off += n + } + if n < len(b) { + return n, r.err + } + return n, nil +}