From 22488b2cc00fa05097e32ba22951fc05607b0a1c Mon Sep 17 00:00:00 2001 From: Marco Munizaga Date: Wed, 9 Mar 2022 22:18:38 -0800 Subject: [PATCH] Use defer and remove ReaderWithEOFContext --- cbor_cid.go | 1 - gen.go | 22 +++++++-- peeker.go | 5 -- testing/cbor_gen.go | 66 +++++++++++++++++++++----- testing/cbor_map_gen.go | 55 ++++++++++++++++++---- utils.go | 102 +++++++++------------------------------- utils_test.go | 22 --------- 7 files changed, 139 insertions(+), 134 deletions(-) diff --git a/cbor_cid.go b/cbor_cid.go index fee2a5d..57d6ec2 100644 --- a/cbor_cid.go +++ b/cbor_cid.go @@ -13,7 +13,6 @@ func (c CborCid) MarshalCBOR(w io.Writer) error { } func (c *CborCid) UnmarshalCBOR(r io.Reader) error { - r = NewReaderWithEOFContext(r) oc, err := ReadCid(r) if err != nil { return err diff --git a/gen.go b/gen.go index df2eca9..1ed373d 100644 --- a/gen.go +++ b/gen.go @@ -1032,17 +1032,24 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { func emitCborUnmarshalStructTuple(w io.Writer, gti *GenTypeInfo) error { err := doTemplate(w, gti, ` -func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { *t = {{.Name}}{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := {{ ReadHeader "br" }} if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -1191,17 +1198,24 @@ func emitCborMarshalStructMap(w io.Writer, gti *GenTypeInfo) error { func emitCborUnmarshalStructMap(w io.Writer, gti *GenTypeInfo) error { err := doTemplate(w, gti, ` -func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { *t = {{.Name}}{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := {{ ReadHeader "br" }} if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } diff --git a/peeker.go b/peeker.go index a9225d8..3600a84 100644 --- a/peeker.go +++ b/peeker.go @@ -12,11 +12,6 @@ type BytePeeker interface { } func GetPeeker(r io.Reader) BytePeeker { - if r, ok := r.(*ReaderWithEOFContext); ok { - if r, ok := r.R.(BytePeeker); ok { - return r - } - } if r, ok := r.(BytePeeker); ok { return r } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index 125e220..7614710 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -47,17 +47,24 @@ func (t *SignedArray) MarshalCBOR(w io.Writer) error { return nil } -func (t *SignedArray) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *SignedArray) UnmarshalCBOR(r io.Reader) (err error) { *t = SignedArray{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -171,17 +178,24 @@ func (t *SimpleTypeOne) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *SimpleTypeOne) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleTypeOne{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -413,17 +427,24 @@ func (t *SimpleTypeTwo) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *SimpleTypeTwo) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleTypeTwo{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -734,17 +755,24 @@ func (t *DeferredContainer) MarshalCBOR(w io.Writer) error { return nil } -func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *DeferredContainer) UnmarshalCBOR(r io.Reader) (err error) { *t = DeferredContainer{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -854,17 +882,24 @@ func (t *FixedArrays) MarshalCBOR(w io.Writer) error { return nil } -func (t *FixedArrays) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *FixedArrays) UnmarshalCBOR(r io.Reader) (err error) { *t = FixedArrays{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } @@ -1000,17 +1035,24 @@ func (t *ThingWithSomeTime) MarshalCBOR(w io.Writer) error { return nil } -func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) (err error) { *t = ThingWithSomeTime{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajArray { return fmt.Errorf("cbor input should be of type array") } diff --git a/testing/cbor_map_gen.go b/testing/cbor_map_gen.go index 6b56e9a..3a08067 100644 --- a/testing/cbor_map_gen.go +++ b/testing/cbor_map_gen.go @@ -189,17 +189,24 @@ func (t *SimpleTypeTree) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *SimpleTypeTree) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleTypeTree{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } @@ -445,17 +452,24 @@ func (t *NeedScratchForMap) MarshalCBOR(w io.Writer) error { return nil } -func (t *NeedScratchForMap) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *NeedScratchForMap) UnmarshalCBOR(r io.Reader) (err error) { *t = NeedScratchForMap{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } @@ -692,17 +706,24 @@ func (t *SimpleStructV1) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleStructV1) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *SimpleStructV1) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleStructV1{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } @@ -1247,17 +1268,24 @@ func (t *SimpleStructV2) MarshalCBOR(w io.Writer) error { return nil } -func (t *SimpleStructV2) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *SimpleStructV2) UnmarshalCBOR(r io.Reader) (err error) { *t = SimpleStructV2{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } @@ -1654,17 +1682,24 @@ func (t *RenamedFields) MarshalCBOR(w io.Writer) error { return nil } -func (t *RenamedFields) UnmarshalCBOR(r io.Reader) error { - r = cbg.NewReaderWithEOFContext(r) +func (t *RenamedFields) UnmarshalCBOR(r io.Reader) (err error) { *t = RenamedFields{} br := cbg.GetPeeker(r) scratch := make([]byte, 8) + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := cbg.CborReadHeaderBuf(br, scratch) if err != nil { return err } + hasReadOnce = true + if maj != cbg.MajMap { return fmt.Errorf("cbor input should be of type map") } diff --git a/utils.go b/utils.go index 8efe66f..0591b9f 100644 --- a/utils.go +++ b/utils.go @@ -27,13 +27,6 @@ func discard(br io.Reader, n int) error { } switch r := br.(type) { - case *ReaderWithEOFContext: - err := discard(r.R, n) - if err == io.EOF && r.hasReadOnce { - err = io.ErrUnexpectedEOF - } - r.hasReadOnce = true - return err case *bytes.Buffer: buf := r.Next(n) if len(buf) == 0 { @@ -67,14 +60,16 @@ func discard(br io.Reader, n int) error { } } -func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { +func ScanForLinks(br io.Reader, cb func(cid.Cid)) (err error) { hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() scratch := make([]byte, maxCidLength) for remaining := uint64(1); remaining > 0; remaining-- { maj, extra, err := CborReadHeaderBuf(br, scratch) - if err == io.EOF && hasReadOnce { - return io.ErrUnexpectedEOF - } if err != nil { return err } @@ -84,18 +79,12 @@ func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { case MajUnsignedInt, MajNegativeInt, MajOther: case MajByteString, MajTextString: err := discard(br, int(extra)) - if err == io.EOF && hasReadOnce { - return io.ErrUnexpectedEOF - } if err != nil { return err } case MajTag: if extra == 42 { maj, extra, err = CborReadHeaderBuf(br, scratch) - if err == io.EOF && hasReadOnce { - return io.ErrUnexpectedEOF - } if err != nil { return err } @@ -109,9 +98,6 @@ func ScanForLinks(br io.Reader, cb func(cid.Cid)) error { } if _, err := io.ReadAtLeast(br, scratch[:extra], int(extra)); err != nil { - if err == io.EOF && hasReadOnce { - return io.ErrUnexpectedEOF - } return err } @@ -174,7 +160,6 @@ func (d *Deferred) MarshalCBOR(w io.Writer) error { func (d *Deferred) UnmarshalCBOR(br io.Reader) error { // Reuse any existing buffers. - br = NewReaderWithEOFContext(br) reusedBuf := d.Raw[:0] d.Raw = nil buf := bytes.NewBuffer(reusedBuf) @@ -240,13 +225,6 @@ func readByte(r io.Reader) (byte, error) { // try to cast to a concrete type, it's much faster than casting to an // interface. switch r := r.(type) { - case *ReaderWithEOFContext: - b, err := readByte(r.R) - if err == io.EOF && r.hasReadOnce { - err = io.ErrUnexpectedEOF - } - r.hasReadOnce = true - return b, err case *bytes.Buffer: return r.ReadByte() case *bytes.Reader: @@ -263,11 +241,18 @@ func readByte(r io.Reader) (byte, error) { return buf[0], err } -func CborReadHeader(br io.Reader) (byte, uint64, error) { +func CborReadHeader(br io.Reader) (_b byte, _ui uint64, err error) { + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() first, err := readByte(br) if err != nil { return 0, 0, err } + hasReadOnce = true maj := (first & 0xe0) >> 5 low := first & 0x1f @@ -323,13 +308,6 @@ func readByteBuf(r io.Reader, scratch []byte) (byte, error) { // Reading a single byte from these buffers is much faster than copying // into a slice. switch r := r.(type) { - case *ReaderWithEOFContext: - b, err := readByteBuf(r.R, scratch) - if err == io.EOF && r.hasReadOnce { - err = io.ErrUnexpectedEOF - } - r.hasReadOnce = true - return b, err case *bytes.Buffer: return r.ReadByte() case *bytes.Reader: @@ -493,11 +471,18 @@ func CborEncodeMajorType(t byte, l uint64) []byte { } } -func ReadTaggedByteArray(br io.Reader, exptag uint64, maxlen uint64) ([]byte, error) { +func ReadTaggedByteArray(br io.Reader, exptag uint64, maxlen uint64) (bs []byte, err error) { + hasReadOnce := false + defer func() { + if err == io.EOF && hasReadOnce { + err = io.ErrUnexpectedEOF + } + }() maj, extra, err := CborReadHeader(br) if err != nil { return nil, err } + hasReadOnce = true if maj != MajTag { return nil, fmt.Errorf("expected cbor type 'tag' in input") @@ -681,7 +666,6 @@ func (cb CborBool) MarshalCBOR(w io.Writer) error { } func (cb *CborBool) UnmarshalCBOR(r io.Reader) error { - r = NewReaderWithEOFContext(r) t, val, err := CborReadHeader(r) if err != nil { return err @@ -719,7 +703,6 @@ func (ci CborInt) MarshalCBOR(w io.Writer) error { } func (ci *CborInt) UnmarshalCBOR(r io.Reader) error { - r = NewReaderWithEOFContext(r) maj, extra, err := CborReadHeader(r) if err != nil { return err @@ -756,7 +739,6 @@ func (ct CborTime) MarshalCBOR(w io.Writer) error { } func (ct *CborTime) UnmarshalCBOR(r io.Reader) error { - r = NewReaderWithEOFContext(r) var cbi CborInt if err := cbi.UnmarshalCBOR(r); err != nil { return err @@ -784,43 +766,3 @@ func (ct *CborTime) UnmarshalJSON(b []byte) error { *(*time.Time)(ct) = t return nil } - -// ReaderWithEOFContext keeps track of whether it was able to read at least one -// byte from the underlying reader. It uses this context to either return EOF or -// ErrUnexpectedEOF using the following rule: -// - if we were not able to read a single byte because of EOF, it returns EOF. -// - if we were able to read at least a single byte, but runs into an EOF -// later, then it returns ErrUnexpectedEOF. -// -// This reader is useful when a function reads from a reader multiple times and -// the function should only return EOF if it was not able to read a single byte -// because of EOF. Otherwise it should return ErrUnexpectedEOF. For example, -// when we unmarshal a CBOR blob into an object it should only return EOF if the -// input blob was empty, otherwise it should return an ErrUnexpectedEOF since it -// started decoding the blob but failed. -type ReaderWithEOFContext struct { - R io.Reader - hasReadOnce bool -} - -func (r *ReaderWithEOFContext) Read(p []byte) (n int, err error) { - if !r.hasReadOnce { - r.hasReadOnce = true - return r.R.Read(p) - } - - n, err = r.R.Read(p) - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return n, err -} - -//go:inline -func NewReaderWithEOFContext(r io.Reader) *ReaderWithEOFContext { - if rc, ok := r.(*ReaderWithEOFContext); ok { - return rc - } else { - return &ReaderWithEOFContext{R: r} - } -} diff --git a/utils_test.go b/utils_test.go index a23eafa..75e19e3 100644 --- a/utils_test.go +++ b/utils_test.go @@ -70,26 +70,6 @@ func TestScanForLinksShouldReturnEOFWhenNothingRead(t *testing.T) { t.Log(cids) } -func TestReaderWithEOFContext(t *testing.T) { - emptyReader := &ReaderWithEOFContext{R: strings.NewReader("")} - buf := make([]byte, 1) - _, err := emptyReader.Read(buf) - if err != io.EOF { - t.Fatal(err) - } - - oneByteReader := &ReaderWithEOFContext{R: strings.NewReader("1")} - _, err = io.ReadFull(oneByteReader, buf) - if err != nil { - t.Fatal(err) - } - - _, err = io.ReadFull(oneByteReader, buf) - if err != io.ErrUnexpectedEOF { - t.Fatal(err) - } -} - func TestDeferredMaxLengthSingle(t *testing.T) { var header bytes.Buffer if err := WriteMajorTypeHeader(&header, MajByteString, ByteArrayMaxLen+1); err != nil { @@ -117,8 +97,6 @@ func TestReadEOFSemantics(t *testing.T) { } newTestCases := func() []testCase { return []testCase{ - {name: "Reader with EOF context that returns EOF and n bytes read", reader: &ReaderWithEOFContext{R: &testReader1Byte{b: 0x01}}, shouldFail: false}, - {name: "Reader with EOF context around Empty Byte Reader", reader: &ReaderWithEOFContext{R: bytes.NewReader([]byte{})}, shouldFail: true}, {name: "Reader that returns EOF and n bytes read", reader: &testReader1Byte{b: 0x01}, shouldFail: false}, {name: "Peeker with Reader that returns EOF and n bytes read", reader: GetPeeker(&testReader1Byte{b: 0x01}), shouldFail: false}, {name: "Peeker with Exhausted Reader", reader: GetPeeker(&testReader1Byte{b: 0x01, emptied: true}), shouldFail: true},