Skip to content

Commit

Permalink
zstd: Add configurable Decoder window size (#394)
Browse files Browse the repository at this point in the history
Also reduce default memory allocs.

Fixes #390
Replaces #392
  • Loading branch information
klauspost authored Jun 8, 2021
1 parent f83864f commit f118b5f
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 25 deletions.
4 changes: 2 additions & 2 deletions zstd/blockdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ func (b *blockDec) reset(br byteBuffer, windowSize uint64) error {

// Read block data.
if cap(b.dataStorage) < cSize {
if b.lowMem {
if b.lowMem || cSize > maxCompressedBlockSize {
b.dataStorage = make([]byte, 0, cSize)
} else {
b.dataStorage = make([]byte, 0, maxBlockSize)
b.dataStorage = make([]byte, 0, maxCompressedBlockSize)
}
}
if cap(b.dst) <= maxSize {
Expand Down
25 changes: 22 additions & 3 deletions zstd/decoder_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ type decoderOptions struct {
lowMem bool
concurrent int
maxDecodedSize uint64
maxWindowSize uint64
dicts []dict
}

func (o *decoderOptions) setDefault() {
*o = decoderOptions{
// use less ram: true for now, but may change.
lowMem: true,
concurrent: runtime.GOMAXPROCS(0),
lowMem: true,
concurrent: runtime.GOMAXPROCS(0),
maxWindowSize: MaxWindowSize,
}
o.maxDecodedSize = 1 << 63
}
Expand Down Expand Up @@ -52,7 +54,6 @@ func WithDecoderConcurrency(n int) DOption {
// WithDecoderMaxMemory allows to set a maximum decoded size for in-memory
// non-streaming operations or maximum window size for streaming operations.
// This can be used to control memory usage of potentially hostile content.
// For streaming operations, the maximum window size is capped at 1<<30 bytes.
// Maximum and default is 1 << 63 bytes.
func WithDecoderMaxMemory(n uint64) DOption {
return func(o *decoderOptions) error {
Expand Down Expand Up @@ -81,3 +82,21 @@ func WithDecoderDicts(dicts ...[]byte) DOption {
return nil
}
}

// WithDecoderMaxWindow allows to set a maximum window size for decodes.
// This allows rejecting packets that will cause big memory usage.
// The Decoder will likely allocate more memory based on the WithDecoderLowmem setting.
// If WithDecoderMaxMemory is set to a lower value, that will be used.
// Default is 512MB, Maximum is ~3.75 TB as per zstandard spec.
func WithDecoderMaxWindow(size uint64) DOption {
return func(o *decoderOptions) error {
if size < MinWindowSize {
return errors.New("WithMaxWindowSize must be at least 1KB, 1024 bytes")
}
if size > (1<<41)+7*(1<<38) {
return errors.New("WithMaxWindowSize must be less than (1<<41) + 7*(1<<38) ~ 3.75TB")
}
o.maxWindowSize = size
return nil
}
}
122 changes: 117 additions & 5 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestNewDecoder(t *testing.T) {
func TestNewDecoderMemory(t *testing.T) {
defer timeout(60 * time.Second)()
var testdata bytes.Buffer
enc, err := NewWriter(&testdata, WithWindowSize(64<<10), WithSingleSegment(false))
enc, err := NewWriter(&testdata, WithWindowSize(32<<10), WithSingleSegment(false))
if err != nil {
t.Fatal(err)
}
Expand All @@ -200,6 +200,9 @@ func TestNewDecoderMemory(t *testing.T) {
n = 200
}

// 16K buffer
var tmp [16 << 10]byte

var before, after runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&before)
Expand All @@ -214,8 +217,6 @@ func TestNewDecoderMemory(t *testing.T) {
}
}

// 32K buffer
var tmp [128 << 10]byte
for i := range decs {
_, err := io.ReadFull(decs[i], tmp[:])
if err != nil {
Expand All @@ -226,17 +227,128 @@ func TestNewDecoderMemory(t *testing.T) {
runtime.GC()
runtime.ReadMemStats(&after)
size := (after.HeapInuse - before.HeapInuse) / uint64(n) / 1024

const expect = 124
t.Log(size, "KiB per decoder")
// This is not exact science, but fail if we suddenly get more than 2x what we expect.
if size > 221*2 && !testing.Short() {
t.Errorf("expected < 221KB per decoder, got %d", size)
if size > expect*2 && !testing.Short() {
t.Errorf("expected < %dKB per decoder, got %d", expect, size)
}

for _, dec := range decs {
dec.Close()
}
}

func TestNewDecoderMemoryHighMem(t *testing.T) {
defer timeout(60 * time.Second)()
var testdata bytes.Buffer
enc, err := NewWriter(&testdata, WithWindowSize(32<<10), WithSingleSegment(false))
if err != nil {
t.Fatal(err)
}
// Write 256KB
for i := 0; i < 256; i++ {
tmp := strings.Repeat(string([]byte{byte(i)}), 1024)
_, err := enc.Write([]byte(tmp))
if err != nil {
t.Fatal(err)
}
}
err = enc.Close()
if err != nil {
t.Fatal(err)
}

var n = 50
if testing.Short() {
n = 10
}

// 16K buffer
var tmp [16 << 10]byte

var before, after runtime.MemStats
runtime.GC()
runtime.ReadMemStats(&before)

var decs = make([]*Decoder, n)
for i := range decs {
// Wrap in NopCloser to avoid shortcut.
input := ioutil.NopCloser(bytes.NewBuffer(testdata.Bytes()))
decs[i], err = NewReader(input, WithDecoderConcurrency(1), WithDecoderLowmem(false))
if err != nil {
t.Fatal(err)
}
}

for i := range decs {
_, err := io.ReadFull(decs[i], tmp[:])
if err != nil {
t.Fatal(err)
}
}

runtime.GC()
runtime.ReadMemStats(&after)
size := (after.HeapInuse - before.HeapInuse) / uint64(n) / 1024

const expect = 3915
t.Log(size, "KiB per decoder")
// This is not exact science, but fail if we suddenly get more than 2x what we expect.
if size > expect*2 && !testing.Short() {
t.Errorf("expected < %dKB per decoder, got %d", expect, size)
}

for _, dec := range decs {
dec.Close()
}
}

func TestNewDecoderFrameSize(t *testing.T) {
defer timeout(60 * time.Second)()
var testdata bytes.Buffer
enc, err := NewWriter(&testdata, WithWindowSize(64<<10))
if err != nil {
t.Fatal(err)
}
// Write 256KB
for i := 0; i < 256; i++ {
tmp := strings.Repeat(string([]byte{byte(i)}), 1024)
_, err := enc.Write([]byte(tmp))
if err != nil {
t.Fatal(err)
}
}
err = enc.Close()
if err != nil {
t.Fatal(err)
}
// Must fail
dec, err := NewReader(bytes.NewReader(testdata.Bytes()), WithDecoderMaxWindow(32<<10))
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(ioutil.Discard, dec)
if err == nil {
dec.Close()
t.Fatal("Wanted error, got none")
}
dec.Close()

// Must succeed.
dec, err = NewReader(bytes.NewReader(testdata.Bytes()), WithDecoderMaxWindow(64<<10))
if err != nil {
t.Fatal(err)
}
_, err = io.Copy(ioutil.Discard, dec)
if err != nil {
dec.Close()
t.Fatalf("Wanted no error, got %+v", err)
}
dec.Close()
}

func TestNewDecoderGood(t *testing.T) {
defer timeout(30 * time.Second)()
testDecoderFile(t, "testdata/good.zip")
Expand Down
32 changes: 17 additions & 15 deletions zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ type frameDec struct {

WindowSize uint64

// maxWindowSize is the maximum windows size to support.
// should never be bigger than max-int.
maxWindowSize uint64

// In order queue of blocks being decoded.
decoding chan *blockDec

Expand All @@ -50,8 +46,11 @@ type frameDec struct {
}

const (
// The minimum Window_Size is 1 KB.
// MinWindowSize is the minimum Window Size, which is 1 KB.
MinWindowSize = 1 << 10

// MaxWindowSize is the maximum encoder window size
// and the default decoder maximum window size.
MaxWindowSize = 1 << 29
)

Expand All @@ -61,12 +60,11 @@ var (
)

func newFrameDec(o decoderOptions) *frameDec {
d := frameDec{
o: o,
maxWindowSize: MaxWindowSize,
if o.maxWindowSize > o.maxDecodedSize {
o.maxWindowSize = o.maxDecodedSize
}
if d.maxWindowSize > o.maxDecodedSize {
d.maxWindowSize = o.maxDecodedSize
d := frameDec{
o: o,
}
return &d
}
Expand Down Expand Up @@ -251,13 +249,17 @@ func (d *frameDec) reset(br byteBuffer) error {
}
}

if d.WindowSize > d.maxWindowSize {
printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
if d.WindowSize > uint64(d.o.maxWindowSize) {
if debugDecoder {
printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
}
return ErrWindowSizeExceeded
}
// The minimum Window_Size is 1 KB.
if d.WindowSize < MinWindowSize {
println("got window size: ", d.WindowSize)
if debugDecoder {
println("got window size: ", d.WindowSize)
}
return ErrWindowSizeTooSmall
}
d.history.windowSize = int(d.WindowSize)
Expand Down Expand Up @@ -352,8 +354,8 @@ func (d *frameDec) checkCRC() error {

func (d *frameDec) initAsync() {
if !d.o.lowMem && !d.SingleSegment {
// set max extra size history to 10MB.
d.history.maxSize = d.history.windowSize + maxBlockSize*5
// set max extra size history to 2MB.
d.history.maxSize = d.history.windowSize + maxBlockSize
}
// re-alloc if more than one extra block size.
if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize {
Expand Down

0 comments on commit f118b5f

Please sign in to comment.