Skip to content

Commit

Permalink
zstd: free Decoder resources when Reset is called with a nil io.Reader (
Browse files Browse the repository at this point in the history
#305)

* zstd: free Decoder resources when Reset is called with a nil io.Reader
* Return a sentinel error when trying to read from a nil reader.
* Return the same error if the Decoder was created with a nil reader.

Fixes #296.
  • Loading branch information
mostynb committed Dec 30, 2020
1 parent bb5ba3d commit fa5ea64
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
20 changes: 15 additions & 5 deletions zstd/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) {
d.current.output = make(chan decodeOutput, d.o.concurrent)
d.current.flushed = true

if r == nil {
d.current.err = ErrDecoderNilInput
}

// Transfer option dicts.
d.dicts = make(map[uint32]dict, len(d.o.dicts))
for _, dc := range d.o.dicts {
Expand All @@ -111,7 +115,7 @@ func NewReader(r io.Reader, opts ...DOption) (*Decoder, error) {
// When the stream is done, io.EOF will be returned.
func (d *Decoder) Read(p []byte) (int, error) {
if d.stream == nil {
return 0, errors.New("no input has been initialized")
return 0, ErrDecoderNilInput
}
var n int
for {
Expand Down Expand Up @@ -152,12 +156,20 @@ func (d *Decoder) Read(p []byte) (int, error) {

// Reset will reset the decoder the supplied stream after the current has finished processing.
// Note that this functionality cannot be used after Close has been called.
// Reset can be called with a nil reader to release references to the previous reader.
// After being called with a nil reader, no other operations than Reset or DecodeAll or Close
// should be used.
func (d *Decoder) Reset(r io.Reader) error {
if d.current.err == ErrDecoderClosed {
return d.current.err
}

d.drainOutput()

if r == nil {
return errors.New("nil Reader sent as input")
d.current.err = ErrDecoderNilInput
d.current.flushed = true
return nil
}

if d.stream == nil {
Expand All @@ -166,8 +178,6 @@ func (d *Decoder) Reset(r io.Reader) error {
go d.startStreamDecoder(d.stream)
}

d.drainOutput()

// If bytes buffer and < 1MB, do sync decoding anyway.
if bb, ok := r.(*bytes.Buffer); ok && bb.Len() < 1<<20 {
if debug {
Expand Down Expand Up @@ -249,7 +259,7 @@ func (d *Decoder) drainOutput() {
// Any error encountered during the write is also returned.
func (d *Decoder) WriteTo(w io.Writer) (int64, error) {
if d.stream == nil {
return 0, errors.New("no input has been initialized")
return 0, ErrDecoderNilInput
}
var n int64
for {
Expand Down
42 changes: 42 additions & 0 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,48 @@ func TestPredefTables(t *testing.T) {
}
}

func TestResetNil(t *testing.T) {
dec, err := NewReader(nil)
if err != nil {
t.Fatal(err)
}
defer dec.Close()

_, err = ioutil.ReadAll(dec)
if err != ErrDecoderNilInput {
t.Fatalf("Expected ErrDecoderNilInput when decoding from a nil reader, got %v", err)
}

emptyZstdBlob := []byte{40, 181, 47, 253, 32, 0, 1, 0, 0}

dec.Reset(bytes.NewBuffer(emptyZstdBlob))

result, err := ioutil.ReadAll(dec)
if err != nil && err != io.EOF {
t.Fatal(err)
}
if len(result) != 0 {
t.Fatalf("Expected to read 0 bytes, actually read %d", len(result))
}

dec.Reset(nil)

_, err = ioutil.ReadAll(dec)
if err != ErrDecoderNilInput {
t.Fatalf("Expected ErrDecoderNilInput when decoding from a nil reader, got %v", err)
}

dec.Reset(bytes.NewBuffer(emptyZstdBlob))

result, err = ioutil.ReadAll(dec)
if err != nil && err != io.EOF {
t.Fatal(err)
}
if len(result) != 0 {
t.Fatalf("Expected to read 0 bytes, actually read %d", len(result))
}
}

func timeout(after time.Duration) (cancel func()) {
c := time.After(after)
cc := make(chan struct{})
Expand Down
4 changes: 4 additions & 0 deletions zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ var (
// ErrDecoderClosed will be returned if the Decoder was used after
// Close has been called.
ErrDecoderClosed = errors.New("decoder used after Close")

// ErrDecoderNilInput is returned when a nil Reader was provided
// and an operation other than Reset/DecodeAll/Close was attempted.
ErrDecoderNilInput = errors.New("nil input provided as reader")
)

func println(a ...interface{}) {
Expand Down

0 comments on commit fa5ea64

Please sign in to comment.