Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Decoder.Skip() to skip CBOR data item in CBOR Sequences (RFC 8742) #381

Merged
merged 1 commit into from
Jan 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,38 @@ func (dec *Decoder) Decode(v interface{}) error {
}
}

// Skip skips to the next CBOR data item (if there is any),
// otherwise it returns error such as io.EOF, io.UnexpectedEOF, etc.
func (dec *Decoder) Skip() error {
if len(dec.buf) == dec.off {
if n, err := dec.read(); n == 0 {
return err
}
}
for {
dec.d.reset(dec.buf[dec.off:])
err := dec.d.valid(true)
if err == nil {
// Only increment dec.off if current CBOR data item is valid.
// If current data item is incomplete (io.ErrUnexpectedEOF),
// we want to try again after reading more data.
dec.off += dec.d.off
dec.bytesRead += dec.d.off
return nil
}
if err != io.ErrUnexpectedEOF {
return err
}
// Need to read more data.
if n, err := dec.read(); n == 0 {
if err == io.EOF {
err = io.ErrUnexpectedEOF
}
return err
}
}
}

// NumBytesRead returns the number of bytes read.
func (dec *Decoder) NumBytesRead() int {
return dec.bytesRead
Expand Down
124 changes: 121 additions & 3 deletions stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io"
"reflect"
"strings"
"testing"
"time"
)
Expand Down Expand Up @@ -193,7 +194,7 @@ func TestDecoderReadError(t *testing.T) {
}

func TestDecoderInvalidData(t *testing.T) {
data := []byte{0x01, 0x83, 0x01, 0x02}
data := []byte{0x01, 0x3e}
decoder := NewDecoder(bytes.NewReader(data))

var v1 interface{}
Expand All @@ -206,8 +207,125 @@ func TestDecoderInvalidData(t *testing.T) {
err = decoder.Decode(&v2)
if err == nil {
t.Errorf("Decode() didn't return error when decoding invalid data item")
} else if err != io.ErrUnexpectedEOF {
t.Errorf("Decode() error %q, want %q", err, io.ErrUnexpectedEOF)
} else if !strings.Contains(err.Error(), "cbor: invalid additional information") {
t.Errorf("Decode() error %q, want \"cbor: invalid additional information\"", err)
}
}

func TestDecoderSkip(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Nice work keeping coverage above 98%.

var buf bytes.Buffer
for i := 0; i < 5; i++ {
for _, tc := range unmarshalTests {
buf.Write(tc.cborData)
}
}
decoder := NewDecoder(&buf)
bytesRead := 0
for i := 0; i < 5; i++ {
for _, tc := range unmarshalTests {
if err := decoder.Skip(); err != nil {
t.Fatalf("Skip() returned error %v", err)
}
bytesRead += len(tc.cborData)
if decoder.NumBytesRead() != bytesRead {
t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead)
}
}
}
for i := 0; i < 2; i++ {
// no more data
err := decoder.Skip()
if err != io.EOF {
t.Errorf("Skip() returned error %v, want io.EOF (no more data)", err)
}
}
}

func TestDecoderSkipInvalidDataError(t *testing.T) {
var buf bytes.Buffer
for _, tc := range unmarshalTests {
buf.Write(tc.cborData)
}
buf.WriteByte(0x3e)

decoder := NewDecoder(&buf)
bytesRead := 0
for i := 0; i < len(unmarshalTests); i++ {
tc := unmarshalTests[i]
if err := decoder.Skip(); err != nil {
t.Fatalf("Skip() returned error %v", err)
}
bytesRead += len(tc.cborData)
if decoder.NumBytesRead() != bytesRead {
t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead)
}
}
for i := 0; i < 2; i++ {
// last data item is invalid
err := decoder.Skip()
if err == nil {
t.Fatalf("Skip() didn't return error")
} else if !strings.Contains(err.Error(), "cbor: invalid additional information") {
t.Errorf("Skip() error %q, want \"cbor: invalid additional information\"", err)
}
}
}

func TestDecoderSkipUnexpectedEOFError(t *testing.T) {
var buf bytes.Buffer
for _, tc := range unmarshalTests {
buf.Write(tc.cborData)
}
buf.Truncate(buf.Len() - 1)

decoder := NewDecoder(&buf)
bytesRead := 0
for i := 0; i < len(unmarshalTests)-1; i++ {
tc := unmarshalTests[i]
if err := decoder.Skip(); err != nil {
t.Fatalf("Skip() returned error %v", err)
}
bytesRead += len(tc.cborData)
if decoder.NumBytesRead() != bytesRead {
t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead)
}
}
for i := 0; i < 2; i++ {
// last data item is invalid
err := decoder.Skip()
if err != io.ErrUnexpectedEOF {
t.Errorf("Skip() returned error %v, want io.ErrUnexpectedEOF (truncated data)", err)
}
}
}

func TestDecoderSkipReadError(t *testing.T) {
var buf bytes.Buffer
for _, tc := range unmarshalTests {
buf.Write(tc.cborData)
}
buf.Truncate(buf.Len() - 1)

readerErr := errors.New("reader error")

decoder := NewDecoder(NewErrorReader(buf.Bytes(), readerErr))
bytesRead := 0
for i := 0; i < len(unmarshalTests)-1; i++ {
tc := unmarshalTests[i]
if err := decoder.Skip(); err != nil {
t.Fatalf("Skip() returned error %v", err)
}
bytesRead += len(tc.cborData)
if decoder.NumBytesRead() != bytesRead {
t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead)
}
}
for i := 0; i < 2; i++ {
// truncated data because Reader returned error
err := decoder.Skip()
if err != readerErr {
t.Errorf("Skip() returned error %v, want reader error", err)
}
}
}

Expand Down