From 0fb145e4ece11a4bac5660d1e71c57a83ad0575f Mon Sep 17 00:00:00 2001 From: Faye Amacker <33205765+fxamacker@users.noreply.github.com> Date: Sun, 1 Jan 2023 14:46:44 -0600 Subject: [PATCH] Add Decoder.Skip() to skip CBOR data item When decoding CBOR Sequences (RFC 8742), it can be useful to skip to the next CBOR data item without decoding. This change adds Decoder.Skip() which will skip to next CBOR data item (if there is any) otherwise it will return error such as io.EOF, io.UnexpectedEOF, etc. --- stream.go | 32 +++++++++++++ stream_test.go | 124 +++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 153 insertions(+), 3 deletions(-) diff --git a/stream.go b/stream.go index 9337940e..c559f52d 100644 --- a/stream.go +++ b/stream.go @@ -58,6 +58,38 @@ func (dec *Decoder) Decode(v interface{}) error { } } +// Skip skips to the next CBOR data item (if there is any), +// otherwise it returns error such as io.EOF, io.UnexpectedEOF, etc. +func (dec *Decoder) Skip() error { + if len(dec.buf) == dec.off { + if n, err := dec.read(); n == 0 { + return err + } + } + for { + dec.d.reset(dec.buf[dec.off:]) + err := dec.d.valid(true) + if err == nil { + // Only increment dec.off if current CBOR data item is valid. + // If current data item is incomplete (io.ErrUnexpectedEOF), + // we want to try again after reading more data. + dec.off += dec.d.off + dec.bytesRead += dec.d.off + return nil + } + if err != io.ErrUnexpectedEOF { + return err + } + // Need to read more data. + if n, err := dec.read(); n == 0 { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return err + } + } +} + // NumBytesRead returns the number of bytes read. func (dec *Decoder) NumBytesRead() int { return dec.bytesRead diff --git a/stream_test.go b/stream_test.go index 6dfe3393..211d79d8 100644 --- a/stream_test.go +++ b/stream_test.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "reflect" + "strings" "testing" "time" ) @@ -193,7 +194,7 @@ func TestDecoderReadError(t *testing.T) { } func TestDecoderInvalidData(t *testing.T) { - data := []byte{0x01, 0x83, 0x01, 0x02} + data := []byte{0x01, 0x3e} decoder := NewDecoder(bytes.NewReader(data)) var v1 interface{} @@ -206,8 +207,125 @@ func TestDecoderInvalidData(t *testing.T) { err = decoder.Decode(&v2) if err == nil { t.Errorf("Decode() didn't return error when decoding invalid data item") - } else if err != io.ErrUnexpectedEOF { - t.Errorf("Decode() error %q, want %q", err, io.ErrUnexpectedEOF) + } else if !strings.Contains(err.Error(), "cbor: invalid additional information") { + t.Errorf("Decode() error %q, want \"cbor: invalid additional information\"", err) + } +} + +func TestDecoderSkip(t *testing.T) { + var buf bytes.Buffer + for i := 0; i < 5; i++ { + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + } + decoder := NewDecoder(&buf) + bytesRead := 0 + for i := 0; i < 5; i++ { + for _, tc := range unmarshalTests { + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + } + for i := 0; i < 2; i++ { + // no more data + err := decoder.Skip() + if err != io.EOF { + t.Errorf("Skip() returned error %v, want io.EOF (no more data)", err) + } + } +} + +func TestDecoderSkipInvalidDataError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.WriteByte(0x3e) + + decoder := NewDecoder(&buf) + bytesRead := 0 + for i := 0; i < len(unmarshalTests); i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // last data item is invalid + err := decoder.Skip() + if err == nil { + t.Fatalf("Skip() didn't return error") + } else if !strings.Contains(err.Error(), "cbor: invalid additional information") { + t.Errorf("Skip() error %q, want \"cbor: invalid additional information\"", err) + } + } +} + +func TestDecoderSkipUnexpectedEOFError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.Truncate(buf.Len() - 1) + + decoder := NewDecoder(&buf) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // last data item is invalid + err := decoder.Skip() + if err != io.ErrUnexpectedEOF { + t.Errorf("Skip() returned error %v, want io.ErrUnexpectedEOF (truncated data)", err) + } + } +} + +func TestDecoderSkipReadError(t *testing.T) { + var buf bytes.Buffer + for _, tc := range unmarshalTests { + buf.Write(tc.cborData) + } + buf.Truncate(buf.Len() - 1) + + readerErr := errors.New("reader error") + + decoder := NewDecoder(NewErrorReader(buf.Bytes(), readerErr)) + bytesRead := 0 + for i := 0; i < len(unmarshalTests)-1; i++ { + tc := unmarshalTests[i] + if err := decoder.Skip(); err != nil { + t.Fatalf("Skip() returned error %v", err) + } + bytesRead += len(tc.cborData) + if decoder.NumBytesRead() != bytesRead { + t.Errorf("NumBytesRead() = %v, want %v", decoder.NumBytesRead(), bytesRead) + } + } + for i := 0; i < 2; i++ { + // truncated data because Reader returned error + err := decoder.Skip() + if err != readerErr { + t.Errorf("Skip() returned error %v, want reader error", err) + } } }