Skip to content

Commit

Permalink
Enforce maximum reflection allocation size
Browse files Browse the repository at this point in the history
This is meant to protect against XDR doctored to cause a heap explosion.

The decoder won't make any allocation larger than the provided maximum.
  • Loading branch information
2opremio committed Nov 15, 2023
1 parent 6c7b684 commit 30dd2ac
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 15 deletions.
68 changes: 55 additions & 13 deletions xdr3/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ import (
"time"
)

const maxInt32 = int(^uint32(0) >> 1)
const maxInt32 = math.MaxInt32

var errMaxSlice = "data exceeds max slice limit"
var errMaxAllocation = "allocation would exceed the maximum limit"
var errIODecode = "%s while decoding %d bytes"

// DecodeDefaultMaxDepth is the default maximum decoding depth
Expand Down Expand Up @@ -93,8 +94,9 @@ func Unmarshal(r io.Reader, v interface{}) (int, error) {
// won't work.
type Decoder struct {
// used to minimize heap allocations during decoding
scratchBuf [8]byte
r io.Reader
scratchBuf [8]byte
r io.Reader
maxAllocSize int
}

// DecodeInt treats the next 4 bytes as an XDR encoded integer and returns the
Expand Down Expand Up @@ -477,7 +479,7 @@ func (d *Decoder) decodeFixedArray(v reflect.Value, ignoreOpaque bool, maxDepth
// elements of the same type as the array represented by the reflection value.
// The number of elements is obtained by first decoding the unsigned integer
// element count. Then each element is decoded into the passed array. The
// ignoreOpaque flag controls whether or not uint8 (byte) elements should be
// ignoreOpaque flag controls whether uint8 (byte) elements should be
// decoded individually or as a variable sequence of opaque data. It returns
// the number of bytes actually read.
//
Expand Down Expand Up @@ -509,6 +511,10 @@ func (d *Decoder) decodeArray(v reflect.Value, ignoreOpaque bool, maxSize int, m
// existing slice does not have enough capacity.
sliceLen := int(dataLen)
if v.Cap() < sliceLen {
growth := (sliceLen - v.Cap()) * int(v.Type().Size())
if growth > d.maxAllocSize {
return 0, unmarshalError("decodeArray", ErrOverflow, errMaxAllocation, nil, nil)
}
v.Set(reflect.MakeSlice(v.Type(), sliceLen, sliceLen))
}
v.SetLen(sliceLen)
Expand Down Expand Up @@ -591,7 +597,11 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) {

vv := v.FieldByName(arm)

vv.Set(reflect.New(vv.Type().Elem()))
vvet := vv.Type().Elem()
if d.maxAllocSize < int(vvet.Size()) {
return 0, unmarshalError("decode", ErrOverflow, errMaxAllocation, nil, nil)
}
vv.Set(reflect.New(vvet))

field, ok := v.Type().FieldByName(arm)
if !ok {
Expand Down Expand Up @@ -621,8 +631,8 @@ func (d *Decoder) decodeUnion(v reflect.Value, maxDepth uint) (int, error) {

// decodeStruct treats the next bytes as a series of XDR encoded elements
// of the same type as the exported fields of the struct represented by the
// passed reflection value. Pointers are automatically indirected and
// allocated as necessary. It returns the the number of bytes actually read.
// passed reflection value. Pointers are automatically indirected and
// allocated as necessary. It returns the number of bytes actually read.
//
// An UnmarshalError is returned if any issues are encountered while decoding
// the elements.
Expand Down Expand Up @@ -728,12 +738,16 @@ func (d *Decoder) decodeMap(v reflect.Value, maxDepth uint) (int, error) {
// Allocate storage for the underlying map if needed.
vt := v.Type()
if v.IsNil() {
// We assume that the map allocation won't exceed d.maxAllocSize
v.Set(reflect.MakeMap(vt))
}

// Decode each key and value according to their type.
keyType := vt.Key()
elemType := vt.Elem()
if uintptr(d.maxAllocSize) < keyType.Size()+elemType.Size()*uintptr(dataLen) {
return 0, unmarshalError("decode", ErrOverflow, errMaxAllocation, nil, nil)
}
for i := uint32(0); i < dataLen; i++ {
key := reflect.New(keyType).Elem()
n2, err := d.decode(key, 0, maxDepth)
Expand All @@ -756,7 +770,7 @@ func (d *Decoder) decodeMap(v reflect.Value, maxDepth uint) (int, error) {
// decodeInterface examines the interface represented by the passed reflection
// value to detect whether it is an interface that can be decoded into and
// if it is, extracts the underlying value to pass back into the decode function
// for decoding according to its type. It returns the the number of bytes
// for decoding according to its type. It returns the number of bytes
// actually read.
//
// An UnmarshalError is returned if any issues are encountered while decoding
Expand Down Expand Up @@ -786,6 +800,14 @@ func (d *Decoder) decodeInterface(v reflect.Value, maxDepth uint) (int, error) {
return d.decode(ve, 0, maxDepth)
}

func (d *Decoder) mergeMaxAllocSizeAndMaxSize(maxSize int) int {
if maxSize == 0 ||
d.maxAllocSize < maxSize {
return d.maxAllocSize
}
return maxSize
}

// decode is the main workhorse for unmarshalling via reflection. It uses
// the passed reflection value to choose the XDR primitives to decode from
// the encapsulated reader. It is a recursive function,
Expand All @@ -809,6 +831,7 @@ func (d *Decoder) decode(ve reflect.Value, maxSize int, maxDepth uint) (int, err
// since checking a string is much quicker.
if ve.Type().String() == "time.Time" {
// Read the value as a string and parse it.
maxSize = d.mergeMaxAllocSizeAndMaxSize(maxSize)
timeString, n, err := d.DecodeString(maxSize)
if err != nil {
return n, err
Expand Down Expand Up @@ -911,6 +934,7 @@ func (d *Decoder) decode(ve reflect.Value, maxSize int, maxDepth uint) (int, err
maxSize = dest.XDRMaxSize()
}

maxSize = d.mergeMaxAllocSizeAndMaxSize(maxSize)
s, n, err := d.DecodeString(maxSize)
if err != nil {
return n, err
Expand Down Expand Up @@ -993,7 +1017,7 @@ func setPtrToNil(v *reflect.Value) error {
return nil
}

func allocPtrIfNil(v *reflect.Value) error {
func (d *Decoder) allocPtrIfNil(v *reflect.Value) error {
if v.Kind() != reflect.Ptr {
msg := fmt.Sprintf("value is not a pointer: '%v'",
v.Type().String())
Expand All @@ -1010,7 +1034,11 @@ func allocPtrIfNil(v *reflect.Value) error {
return err
}
if isNil {
v.Set(reflect.New(v.Type().Elem()))
vet := v.Type().Elem()
if d.maxAllocSize < int(vet.Size()) {
return unmarshalError("decode", ErrOverflow, errMaxAllocation, nil, nil)
}
v.Set(reflect.New(vet))
}
return nil
}
Expand All @@ -1030,7 +1058,7 @@ func (d *Decoder) decodePtr(v reflect.Value, maxDepth uint) (int, error) {
return n, err
}

if err = allocPtrIfNil(&v); err != nil {
if err = d.allocPtrIfNil(&v); err != nil {
return n, err
}

Expand All @@ -1042,7 +1070,7 @@ func (d *Decoder) decodePtr(v reflect.Value, maxDepth uint) (int, error) {
// otherwise returns the passed value.
func (d *Decoder) indirectIfPtr(v reflect.Value) (reflect.Value, error) {
if v.Kind() == reflect.Ptr {
err := allocPtrIfNil(&v)
err := d.allocPtrIfNil(&v)
return v.Elem(), err
}
return v, nil
Expand All @@ -1053,11 +1081,25 @@ func (d *Decoder) indirectIfPtr(v reflect.Value) (reflect.Value, error) {
// data instead of a user-supplied reader. See the Unmarhsal documentation for
// specifics. Decode(v) is equivalent to DecodeWithMaxDepth(v, DecodeDefaultMaxDepth)
func (d *Decoder) Decode(v interface{}) (int, error) {
return d.DecodeWithMaxDepth(v, DecodeDefaultMaxDepth)
return d.DecodeWithMaxDepthAndMaxAllocSize(v, DecodeDefaultMaxDepth, 0)
}

// DecodeWithMaxDepth behaves like Decode, except an explicit maximum decoding depth is used
func (d *Decoder) DecodeWithMaxDepth(v interface{}, maxDepth uint) (int, error) {
return d.DecodeWithMaxDepthAndMaxAllocSize(v, maxDepth, 0)
}

// DecodeWithMaxDepthAndMaxAllocSize behaves like DecodeWithMaxDepth, except an explicit maximum
// allocation size is used. This is meant to protect against XDR doctored to cause a heap explosion.
// The decoder won't make any allocation larger than maxAllocSize
func (d *Decoder) DecodeWithMaxDepthAndMaxAllocSize(v interface{}, maxDepth uint, maxAllocSize int) (int, error) {
if maxAllocSize == 0 {
maxAllocSize = maxInt32
}
d.maxAllocSize = maxAllocSize
defer func() {
d.maxAllocSize = 0
}()
if v == nil {
msg := "can't unmarshal to nil interface"
return 0, unmarshalError("Unmarshal", ErrNilInterface, msg, nil,
Expand Down
21 changes: 21 additions & 0 deletions xdr3/decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1155,3 +1155,24 @@ func TestDecodeMaxDepth(t *testing.T) {
_, err = decoder.DecodeWithMaxDepth(&s, 2)
assertError(t, "", err, &UnmarshalError{ErrorCode: ErrMaxDecodingDepth})
}

func TestDecodeMaxAllocSize(t *testing.T) {
var buf bytes.Buffer
_, err := Marshal(&buf, "thisstringis23charslong")
if err != nil {
t.Error("unexpected error")
}

bufCopy := buf
decoder := NewDecoder(&bufCopy)
var s string
_, err = decoder.DecodeWithMaxDepthAndMaxAllocSize(&s, DecodeDefaultMaxDepth, 23)
if err != nil {
t.Error("unexpected error")
}

bufCopy = buf
decoder = NewDecoder(&bufCopy)
_, err = decoder.DecodeWithMaxDepthAndMaxAllocSize(&s, DecodeDefaultMaxDepth, 22)
assertError(t, "", err, &UnmarshalError{ErrorCode: ErrOverflow})
}
3 changes: 2 additions & 1 deletion xdr3/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ const (
ErrNotSettable

// ErrOverflow indicates that the data in question is too large to fit
// into the corresponding Go or XDR data type. For example, an integer
// into the corresponding Go or XDR data type or that allocating it exceeds the
// maximum allocation size limit. For example, an integer
// decoded from XDR that is too large to fit into a target type of int8,
// or opaque data that exceeds the max length of a Go slice.
ErrOverflow
Expand Down
2 changes: 1 addition & 1 deletion xdr3/internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func TstEncode(w io.Writer) func(v reflect.Value) (int, error) {

// TstDecode creates a new Decoder for the passed reader and returns the
// internal decode function on the Decoder.
func TstDecode(r io.Reader) func(v reflect.Value, maxSize int, maxDepth uint) (int, error) {
func TstDecode(r io.Reader) func(v reflect.Value, maxLen int, maxDepth uint) (int, error) {
dec := NewDecoder(r)
return dec.decode
}

0 comments on commit 30dd2ac

Please sign in to comment.