Merge pull request #115 from DataDog/viq111/decompress-sanity-check
[zstd][#60] Add decompression size sanity Check
Viq111 authored Apr 14, 2022
2 parents 393c3c1 + f973148 commit ed849f7
Showing 4 changed files with 77 additions and 70 deletions.
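In short: Decompress (and its ctx and bulk variants) now sizes its output buffer from the zstd frame header, capped at the larger of 1 MB and 10x the compressed input, so a tiny malicious payload can no longer force a huge allocation. Callers that legitimately expect a ratio above 10x can pass their own buffer, as the new code comment in zstd.go describes. A minimal usage sketch of that caller-side workaround (the payload and the 64 MiB bound are illustrative, not taken from this diff):

package main

import (
    "fmt"

    "github.com/DataDog/zstd"
)

func main() {
    compressed, _ := zstd.Compress(nil, []byte("example payload"))

    // Default path: Decompress allocates its own buffer, now capped at
    // max(1 MB, 10x the compressed size) to guard against zip bombs.
    out, err := zstd.Decompress(nil, compressed)
    fmt.Println(len(out), err)

    // Trusted payloads with a compression ratio above 10x: supply a buffer
    // sized to a known upper bound; the cap then does not constrain the result.
    dst := make([]byte, 64<<20) // hypothetical 64 MiB upper bound
    out, err = zstd.Decompress(dst, compressed)
    fmt.Println(len(out), err)
}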
81 changes: 50 additions & 31 deletions zstd.go
@@ -28,6 +28,17 @@ var (
 	ErrEmptySlice = errors.New("Bytes slice is empty")
 )
 
+const (
+	// decompressSizeBufferLimit is the limit we set on creating a decompression buffer for the Decompress API
+	// This is made to prevent DOS from maliciously-created payloads (aka zipbomb).
+	// For large payloads with a compression ratio > 10, you can do your own allocation and pass it to the method:
+	//   dst := make([]byte, 1GB)
+	//   decompressed, err := zstd.Decompress(dst, src)
+	decompressSizeBufferLimit = 1000 * 1000
+
+	zstdFrameHeaderSizeMax = 18 // From zstd.h. Since it's experimental API, hardcoding it
+)
+
 // CompressBound returns the worst case size needed for a destination buffer,
 // which can be used to preallocate a destination buffer or select a previously
 // allocated buffer from a pool.
@@ -46,6 +57,30 @@ func cCompressBound(srcSize int) int {
 	return int(C.ZSTD_compressBound(C.size_t(srcSize)))
 }
 
+// decompressSizeHint tries to give a hint on how much of the output buffer size we should have
+// based on zstd frame descriptors. To prevent DOS from maliciously-created payloads, limit the size
+func decompressSizeHint(src []byte) int {
+	// 1 MB or 10x input size
+	upperBound := 10 * len(src)
+	if upperBound < decompressSizeBufferLimit {
+		upperBound = decompressSizeBufferLimit
+	}
+
+	hint := upperBound
+	if len(src) >= zstdFrameHeaderSizeMax {
+		hint = int(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
+		if hint < 0 { // On error, just use upperBound
+			hint = upperBound
+		}
+	}
+
+	// Take the minimum of both
+	if hint > upperBound {
+		return upperBound
+	}
+	return hint
+}
+
 // Compress src into dst. If you have a buffer to use, you can pass it to
 // prevent allocation. If it is too small, or if nil is passed, a new buffer
 // will be allocated and returned.
@@ -97,41 +132,25 @@ func Decompress(dst, src []byte) ([]byte, error) {
 	if len(src) == 0 {
 		return []byte{}, ErrEmptySlice
 	}
-	decompress := func(dst, src []byte) ([]byte, error) {
-
-		cWritten := C.ZSTD_decompress(
-			unsafe.Pointer(&dst[0]),
-			C.size_t(len(dst)),
-			unsafe.Pointer(&src[0]),
-			C.size_t(len(src)))
-
-		written := int(cWritten)
-		// Check error
-		if err := getError(written); err != nil {
-			return nil, err
-		}
-		return dst[:written], nil
-	}
-
-	if len(dst) == 0 {
-		// Attempt to use zStd to determine decompressed size (may result in error or 0)
-		size := int(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
-		if err := getError(size); err != nil {
-			return nil, err
-		}
-
-		if size > 0 {
-			dst = make([]byte, size)
-		} else {
-			dst = make([]byte, len(src)*3) // starting guess
-		}
-	}
-	for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
-		result, err := decompress(dst, src)
-		if !IsDstSizeTooSmallError(err) {
-			return result, err
-		}
-		dst = make([]byte, len(dst)*2) // Grow buffer by 2
-	}
+
+	bound := decompressSizeHint(src)
+	if cap(dst) >= bound {
+		dst = dst[0:cap(dst)]
+	} else {
+		dst = make([]byte, bound)
+	}
+
+	written := int(C.ZSTD_decompress(
+		unsafe.Pointer(&dst[0]),
+		C.size_t(len(dst)),
+		unsafe.Pointer(&src[0]),
+		C.size_t(len(src))))
+	err := getError(written)
+	if err == nil {
+		return dst[:written], nil
+	}
+	if !IsDstSizeTooSmallError(err) {
+		return nil, err
+	}
 
 	// We failed getting a dst buffer of correct size, use stream API
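To make the new bound concrete, here is a hedged pure-Go sketch of the capping arithmetic; the real decompressSizeHint reads the frame header through ZSTD_getFrameContentSize via cgo and also falls back to the cap for inputs shorter than the 18-byte frame-header maximum. The sizes below are illustrative:

package main

import "fmt"

// capHint mirrors the logic above without cgo: use the frame's declared
// content size when it is known, but never exceed max(1 MB, 10x srcLen).
func capHint(srcLen int, frameContentSize int64) int {
    upperBound := 10 * srcLen
    if upperBound < 1000*1000 { // decompressSizeBufferLimit
        upperBound = 1000 * 1000
    }
    if frameContentSize < 0 || frameContentSize > int64(upperBound) {
        return upperBound // unknown/error, or a claim above the cap
    }
    return int(frameContentSize)
}

func main() {
    // An 18-byte frame claiming an absurd content size only gets a 1 MB buffer;
    // ZSTD_decompress then fails instead of the process allocating whatever the
    // header claims.
    fmt.Println(capHint(18, 1<<62)) // 1000000
    // A 5 MiB input with an honest 20 MiB content size gets exactly 20 MiB.
    fmt.Println(capHint(5<<20, 20<<20)) // 20971520
}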
9 changes: 2 additions & 7 deletions zstd_bulk.go
@@ -15,8 +15,6 @@ var (
 	ErrEmptyDictionary = errors.New("Dictionary is empty")
 	// ErrBadDictionary is returned when cannot load the given dictionary
 	ErrBadDictionary = errors.New("Cannot load dictionary")
-	// ErrContentSize is returned when cannot determine the content size
-	ErrContentSize = errors.New("Cannot determine the content size")
 )
 
 // BulkProcessor implements Bulk processing dictionary API.
@@ -111,12 +109,9 @@ func (p *BulkProcessor) Decompress(dst, src []byte) ([]byte, error) {
 	if len(src) == 0 {
 		return nil, ErrEmptySlice
 	}
-	contentSize := uint64(C.ZSTD_getFrameContentSize(unsafe.Pointer(&src[0]), C.size_t(len(src))))
-	if contentSize == C.ZSTD_CONTENTSIZE_ERROR || contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN {
-		return nil, ErrContentSize
-	}
 
-	if cap(dst) >= int(contentSize) {
+	contentSize := decompressSizeHint(src)
+	if cap(dst) >= contentSize {
 		dst = dst[0:contentSize]
 	} else {
 		dst = make([]byte, contentSize)
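For the dictionary path, the visible change is that a frame without a declared content size no longer fails with the removed ErrContentSize; the buffer is sized from the same hint. A hedged usage sketch; NewBulkProcessor and DefaultCompression are assumed from the rest of zstd_bulk.go, which this diff does not show, and the dictionary bytes are a placeholder:

package main

import (
    "log"

    "github.com/DataDog/zstd"
)

func main() {
    dictionary := []byte("placeholder dictionary contents") // assumption: any non-empty buffer loads
    proc, err := zstd.NewBulkProcessor(dictionary, zstd.DefaultCompression)
    if err != nil {
        log.Fatal(err) // ErrEmptyDictionary / ErrBadDictionary
    }
    compressed, _ := proc.Compress(nil, []byte("payload"))

    // The output buffer is now sized via decompressSizeHint, so frames that do
    // not declare their content size decompress instead of erroring out.
    plain, err := proc.Decompress(nil, compressed)
    log.Println(len(plain), err)
}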
48 changes: 16 additions & 32 deletions zstd_ctx.go
@@ -96,43 +96,27 @@ func (c *ctx) Decompress(dst, src []byte) ([]byte, error) {
 	if len(src) == 0 {
 		return []byte{}, ErrEmptySlice
 	}
-	decompress := func(dst, src []byte) ([]byte, error) {
-
-		cWritten := C.ZSTD_decompressDCtx(
-			c.dctx,
-			unsafe.Pointer(&dst[0]),
-			C.size_t(len(dst)),
-			unsafe.Pointer(&src[0]),
-			C.size_t(len(src)))
-
-		written := int(cWritten)
-		// Check error
-		if err := getError(written); err != nil {
-			return nil, err
-		}
-		return dst[:written], nil
-	}
-
-	if len(dst) == 0 {
-		// Attempt to use zStd to determine decompressed size (may result in error or 0)
-		size := int(C.size_t(C.ZSTD_getDecompressedSize(unsafe.Pointer(&src[0]), C.size_t(len(src)))))
-
-		if err := getError(size); err != nil {
-			return nil, err
-		}
-
-		if size > 0 {
-			dst = make([]byte, size)
-		} else {
-			dst = make([]byte, len(src)*3) // starting guess
-		}
-	}
-	for i := 0; i < 3; i++ { // 3 tries to allocate a bigger buffer
-		result, err := decompress(dst, src)
-		if !IsDstSizeTooSmallError(err) {
-			return result, err
-		}
-		dst = make([]byte, len(dst)*2) // Grow buffer by 2
-	}
+
+	bound := decompressSizeHint(src)
+	if cap(dst) >= bound {
+		dst = dst[0:cap(dst)]
+	} else {
+		dst = make([]byte, bound)
+	}
+
+	written := int(C.ZSTD_decompressDCtx(
+		c.dctx,
+		unsafe.Pointer(&dst[0]),
+		C.size_t(len(dst)),
+		unsafe.Pointer(&src[0]),
+		C.size_t(len(src))))
+
+	err := getError(written)
+	if err == nil {
+		return dst[:written], nil
+	}
+	if !IsDstSizeTooSmallError(err) {
+		return nil, err
+	}
 
 	// We failed getting a dst buffer of correct size, use stream API
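The context API gets the same treatment: the grow-and-retry loop around ZSTD_decompressDCtx is gone, a single attempt is made into the hinted buffer, and only a genuinely too-small buffer falls through to the streaming path below. A hedged sketch of driving it; NewCtx and its methods are assumed from the rest of zstd_ctx.go:

package main

import (
    "fmt"

    "github.com/DataDog/zstd"
)

func main() {
    ctx := zstd.NewCtx() // reusable compression/decompression context (assumed constructor)
    compressed, _ := ctx.Compress(nil, []byte("example payload"))

    // Same sizing rule as the package-level Decompress: one call into a buffer
    // bounded by decompressSizeHint, then a streaming fallback if it is too small.
    out, err := ctx.Decompress(nil, compressed)
    fmt.Println(len(out), err)
}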
9 changes: 9 additions & 0 deletions zstd_test.go
@@ -2,6 +2,7 @@ package zstd
 
 import (
 	"bytes"
+	b64 "encoding/base64"
 	"errors"
 	"fmt"
 	"io/ioutil"
@@ -284,6 +285,14 @@ func TestLegacy(t *testing.T) {
 	}
 }
 
+func TestBadPayloadZipBomb(t *testing.T) {
+	payload, _ := b64.StdEncoding.DecodeString("KLUv/dcwMDAwMDAwMDAwMAAA")
+	_, err := Decompress(nil, payload)
+	if err.Error() != "Src size is incorrect" {
+		t.Fatal("zstd should detect that the size is incorrect")
+	}
+}
+
 func BenchmarkCompression(b *testing.B) {
 	if raw == nil {
 		b.Fatal(ErrNoPayloadEnv)
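The new test feeds Decompress an 18-byte frame whose header advertises a content size far larger than the input; with the capped buffer, the zstd library rejects it ("Src size is incorrect") rather than the Go wrapper allocating whatever the header claims. A hedged companion sketch, not part of this commit, checking that an honest high-ratio payload still round-trips when the caller supplies the buffer:

package zstd

import (
    "bytes"
    "testing"
)

func TestHighRatioRoundTripSketch(t *testing.T) {
    raw := bytes.Repeat([]byte{'a'}, 8<<20) // ~8 MiB, compresses far below 1/10th of that
    compressed, err := Compress(nil, raw)
    if err != nil {
        t.Fatal(err)
    }
    dst := make([]byte, len(raw)) // caller-supplied upper bound sidesteps the cap
    out, err := Decompress(dst, compressed)
    if err != nil || !bytes.Equal(out, raw) {
        t.Fatal("round trip with caller-supplied buffer failed")
    }
}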
