diff --git a/stream.go b/stream.go index 9337940e..c559f52d 100644 --- a/stream.go +++ b/stream.go @@ -58,6 +58,38 @@ func (dec *Decoder) Decode(v interface{}) error { } } +// Skip skips to the next CBOR data item (if there is any), +// otherwise it returns error such as io.EOF, io.UnexpectedEOF, etc. +func (dec *Decoder) Skip() error { + if len(dec.buf) == dec.off { + if n, err := dec.read(); n == 0 { + return err + } + } + for { + dec.d.reset(dec.buf[dec.off:]) + err := dec.d.valid(true) + if err == nil { + // Only increment dec.off if current CBOR data item is valid. + // If current data item is incomplete (io.ErrUnexpectedEOF), + // we want to try again after reading more data. + dec.off += dec.d.off + dec.bytesRead += dec.d.off + return nil + } + if err != io.ErrUnexpectedEOF { + return err + } + // Need to read more data. + if n, err := dec.read(); n == 0 { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err + } + } +} + // NumBytesRead returns the number of bytes read. func (dec *Decoder) NumBytesRead() int { return dec.bytesRead diff --git a/stream_test.go b/stream_test.go index 6dfe3393..211d79d8 100644 --- a/stream_test.go +++ b/stream_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "reflect" + "strings" "testing" "time" ) @@ -193,7 +194,7 @@ func TestDecoderReadError(t *testing.T) { } func TestDecoderInvalidData(t *testing.T) { - data := []byte{0x01, 0x83, 0x01, 0x02} + data := []byte{0x01, 0x3e} decoder := NewDecoder(bytes.NewReader(data)) var v1 interface{} @@ -206,8 +207,125 @@ func TestDecoderInvalidData(t *testing.T) { err = decoder.Decode(&v2) if err == nil { t.Errorf("Decode() didn't return error when decoding invalid data item") - } else if err != io.ErrUnexpectedEOF { - t.Errorf("Decode() error %q, want %q", err, io.ErrUnexpectedEOF) + } else if !strings.Contains(err.Error(), "cbor: invalid additional information") { + t.Errorf("Decode() error %q, want \"cbor: invalid additional information\"", err) + } +} + +func TestDecoderSkip(t *testing.T) { + var buf bytes.Buffer + for i := 0; i < 5; i++ { + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + } + decoder := NewDecoder(&buf) + bytesRead := 0 + for i := 0; i < 5; i++ { + for _, tc := range unmarshalTests { + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + } + for i := 0; i < 2; i++ { + // no more data + err := decoder.Skip() + if err != io.EOF { + t.Errorf("Skip() returned error %v, want io.EOF (no more data)", err) + } + } +} + +func TestDecoderSkipInvalidDataError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.WriteByte(0x3e) + + decoder := NewDecoder(&buf) + bytesRead := 0 + for i := 0; i < len(unmarshalTests); i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // last data item is invalid + err := decoder.Skip() + if err == nil { + t.Fatalf("Skip() didn't return error") + } else if !strings.Contains(err.Error(), "cbor: invalid additional information") { + t.Errorf("Skip() error %q, want \"cbor: invalid additional information\"", err) + } + } +} + +func TestDecoderSkipUnexpectedEOFError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.Truncate(buf.Len() - 1) + + decoder := NewDecoder(&buf) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // last data item is invalid + err := decoder.Skip() + if err != io.ErrUnexpectedEOF { + t.Errorf("Skip() returned error %v, want io.ErrUnexpectedEOF (truncated data)", err) + } + } +} + +func TestDecoderSkipReadError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.Truncate(buf.Len() - 1) + + readerErr := errors.New("reader error") + + decoder := NewDecoder(NewErrorReader(buf.Bytes(), readerErr)) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // truncated data because Reader returned error + err := decoder.Skip() + if err != readerErr { + t.Errorf("Skip() returned error %v, want reader error", err) + } } }