diff --git a/zstd/decoder.go b/zstd/decoder.go index cdda0de58b..62fd373240 100644 --- a/zstd/decoder.go +++ b/zstd/decoder.go @@ -85,6 +85,10 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) { d.current.output = make(chan decodeOutput, d.o.concurrent) d.current.flushed = true + if r == nil { + d.current.err = ErrDecoderNilInput + } + // Transfer option dicts. d.dicts = make(map[uint32]dict, len(d.o.dicts)) for _, dc := range d.o.dicts { @@ -111,7 +115,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) { // When the stream is done, io.EOF will be returned. func (d *Decoder) Read(p []byte) (int, error) { if d.stream == nil { - return 0, errors.New("no input has been initialized") + return 0, ErrDecoderNilInput } var n int for { @@ -152,12 +156,20 @@ func (d *Decoder) Read(p []byte) (int, error) { // Reset will reset the decoder the supplied stream after the current has finished processing. // Note that this functionality cannot be used after Close has been called. +// Reset can be called with a nil reader to release references to the previous reader. +// After being called with a nil reader, no other operations than Reset or DecodeAll or Close +// should be used. func (d *Decoder) Reset(r io.Reader) error { if d.current.err == ErrDecoderClosed { return d.current.err } + + d.drainOutput() + if r == nil { - return errors.New("nil Reader sent as input") + d.current.err = ErrDecoderNilInput + d.current.flushed = true + return nil } if d.stream == nil { @@ -166,8 +178,6 @@ func (d *Decoder) Reset(r io.Reader) error { go d.startStreamDecoder(d.stream) } - d.drainOutput() - // If bytes buffer and < 1MB, do sync decoding anyway. if bb, ok := r.(*bytes.Buffer); ok && bb.Len() < 1<<20 { if debug { @@ -249,7 +259,7 @@ func (d *Decoder) drainOutput() { // Any error encountered during the write is also returned. func (d *Decoder) WriteTo(w io.Writer) (int64, error) { if d.stream == nil { - return 0, errors.New("no input has been initialized") + return 0, ErrDecoderNilInput } var n int64 for { diff --git a/zstd/decoder_test.go b/zstd/decoder_test.go index 23c8688d18..223016643a 100644 --- a/zstd/decoder_test.go +++ b/zstd/decoder_test.go @@ -1434,6 +1434,48 @@ func TestPredefTables(t *testing.T) { } } +func TestResetNil(t *testing.T) { + dec, err := NewReader(nil) + if err != nil { + t.Fatal(err) + } + defer dec.Close() + + _, err = ioutil.ReadAll(dec) + if err != ErrDecoderNilInput { + t.Fatalf("Expected ErrDecoderNilInput when decoding from a nil reader, got %v", err) + } + + emptyZstdBlob := []byte{40, 181, 47, 253, 32, 0, 1, 0, 0} + + dec.Reset(bytes.NewBuffer(emptyZstdBlob)) + + result, err := ioutil.ReadAll(dec) + if err != nil && err != io.EOF { + t.Fatal(err) + } + if len(result) != 0 { + t.Fatalf("Expected to read 0 bytes, actually read %d", len(result)) + } + + dec.Reset(nil) + + _, err = ioutil.ReadAll(dec) + if err != ErrDecoderNilInput { + t.Fatalf("Expected ErrDecoderNilInput when decoding from a nil reader, got %v", err) + } + + dec.Reset(bytes.NewBuffer(emptyZstdBlob)) + + result, err = ioutil.ReadAll(dec) + if err != nil && err != io.EOF { + t.Fatal(err) + } + if len(result) != 0 { + t.Fatalf("Expected to read 0 bytes, actually read %d", len(result)) + } +} + func timeout(after time.Duration) (cancel func()) { c := time.After(after) cc := make(chan struct{}) diff --git a/zstd/zstd.go b/zstd/zstd.go index 0807719c8b..0c761dd626 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -73,6 +73,10 @@ var ( // ErrDecoderClosed will be returned if the Decoder was used after // Close has been called. ErrDecoderClosed = errors.New("decoder used after Close") + + // ErrDecoderNilInput is returned when a nil Reader was provided + // and an operation other than Reset/DecodeAll/Close was attempted. + ErrDecoderNilInput = errors.New("nil input provided as reader") ) func println(a ...interface{}) {