Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3260 [master] Code hardening #1691

Merged
merged 4 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bson/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func intDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Va
case reflect.Int64:
return reflect.ValueOf(i64), nil
case reflect.Int:
if int64(int(i64)) != i64 { // Can we fit this inside of an int
if i64 > math.MaxInt { // Can we fit this inside of an int
return emptyValue, fmt.Errorf("%d overflows int", i64)
}

Expand Down
4 changes: 2 additions & 2 deletions bson/extjson_wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func (ejv *extJSONValue) parseBinary() (b []byte, subType byte, err error) {
return nil, 0, fmt.Errorf("$binary subType value should be string, but instead is %s", val.t)
}

i, err := strconv.ParseInt(val.v.(string), 16, 64)
i, err := strconv.ParseUint(val.v.(string), 16, 8)
if err != nil {
return nil, 0, fmt.Errorf("invalid $binary subType string: %s", val.v.(string))
return nil, 0, fmt.Errorf("invalid $binary subType string: %q: %w", val.v.(string), err)
}

subType = byte(i)
Expand Down
8 changes: 6 additions & 2 deletions bson/uint_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,15 @@ func (uic *uintCodec) decodeType(dc DecodeContext, vr ValueReader, t reflect.Typ

return reflect.ValueOf(uint64(i64)), nil
case reflect.Uint:
if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint
if i64 < 0 {
return emptyValue, fmt.Errorf("%d overflows uint", i64)
}
v := uint64(i64)
if v > math.MaxUint { // Can we fit this inside of an uint
return emptyValue, fmt.Errorf("%d overflows uint", i64)
}

return reflect.ValueOf(uint(i64)), nil
return reflect.ValueOf(uint(v)), nil
default:
return emptyValue, ValueDecoderError{
Name: "UintDecodeValue",
Expand Down
12 changes: 5 additions & 7 deletions bson/value_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ func (vr *valueReader) peekLength() (int32, error) {
}

idx := vr.offset
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil
}

func (vr *valueReader) readLength() (int32, error) { return vr.readi32() }
Expand All @@ -851,7 +851,7 @@ func (vr *valueReader) readi32() (int32, error) {

idx := vr.offset
vr.offset += 4
return (int32(vr.d[idx]) | int32(vr.d[idx+1])<<8 | int32(vr.d[idx+2])<<16 | int32(vr.d[idx+3])<<24), nil
return int32(binary.LittleEndian.Uint32(vr.d[idx:])), nil
}

func (vr *valueReader) readu32() (uint32, error) {
Expand All @@ -861,7 +861,7 @@ func (vr *valueReader) readu32() (uint32, error) {

idx := vr.offset
vr.offset += 4
return (uint32(vr.d[idx]) | uint32(vr.d[idx+1])<<8 | uint32(vr.d[idx+2])<<16 | uint32(vr.d[idx+3])<<24), nil
return binary.LittleEndian.Uint32(vr.d[idx:]), nil
}

func (vr *valueReader) readi64() (int64, error) {
Expand All @@ -871,8 +871,7 @@ func (vr *valueReader) readi64() (int64, error) {

idx := vr.offset
vr.offset += 8
return int64(vr.d[idx]) | int64(vr.d[idx+1])<<8 | int64(vr.d[idx+2])<<16 | int64(vr.d[idx+3])<<24 |
int64(vr.d[idx+4])<<32 | int64(vr.d[idx+5])<<40 | int64(vr.d[idx+6])<<48 | int64(vr.d[idx+7])<<56, nil
return int64(binary.LittleEndian.Uint64(vr.d[idx:])), nil
}

func (vr *valueReader) readu64() (uint64, error) {
Expand All @@ -882,6 +881,5 @@ func (vr *valueReader) readu64() (uint64, error) {

idx := vr.offset
vr.offset += 8
return uint64(vr.d[idx]) | uint64(vr.d[idx+1])<<8 | uint64(vr.d[idx+2])<<16 | uint64(vr.d[idx+3])<<24 |
uint64(vr.d[idx+4])<<32 | uint64(vr.d[idx+5])<<40 | uint64(vr.d[idx+6])<<48 | uint64(vr.d[idx+7])<<56, nil
return binary.LittleEndian.Uint64(vr.d[idx:]), nil
}
4 changes: 2 additions & 2 deletions etc/run-atlas-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ set +x
# Get the atlas secrets.
. ${DRIVERS_TOOLS}/.evergreen/secrets_handling/setup-secrets.sh drivers/atlas_connect

echo "Running cmd/testatlas/main.go"
go run ./internal/cmd/testatlas/main.go "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite
echo "Running cmd/testatlas"
go test -v -run ^TestAtlas$ go.mongodb.org/mongo-driver/internal/cmd/testatlas -args "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"errors"
"flag"
"fmt"
"os"
"testing"
"time"

"go.mongodb.org/mongo-driver/bson"
Expand All @@ -19,15 +21,19 @@ import (
"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
func TestMain(m *testing.M) {
flag.Parse()
os.Exit(m.Run())
}

func TestAtlas(t *testing.T) {
uris := flag.Args()
ctx := context.Background()

fmt.Printf("Running atlas tests for %d uris\n", len(uris))
t.Logf("Running atlas tests for %d uris\n", len(uris))

for idx, uri := range uris {
fmt.Printf("Running test %d\n", idx)
t.Logf("Running test %d\n", idx)

// Set a low server selection timeout so we fail fast if there are errors.
clientOpts := options.Client().
Expand All @@ -36,18 +42,18 @@ func main() {

// Run basic connectivity test.
if err := runTest(ctx, clientOpts); err != nil {
panic(fmt.Sprintf("error running test with TLS at index %d: %v", idx, err))
t.Fatalf("error running test with TLS at index %d: %v", idx, err)
}

// Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is
// disabled.
clientOpts.TLSConfig.InsecureSkipVerify = true
if err := runTest(ctx, clientOpts); err != nil {
panic(fmt.Sprintf("error running test with tlsInsecure at index %d: %v", idx, err))
t.Fatalf("error running test with tlsInsecure at index %d: %v", idx, err)
}
}

fmt.Println("Finished!")
t.Logf("Finished!")
}

func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
Expand Down
7 changes: 6 additions & 1 deletion internal/logger/io_sink.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package logger
import (
"encoding/json"
"io"
"math"
"sync"
"time"
)
Expand Down Expand Up @@ -36,7 +37,11 @@ func NewIOSink(out io.Writer) *IOSink {

// Info will write a JSON-encoded message to the io.Writer.
func (sink *IOSink) Info(_ int, msg string, keysAndValues ...interface{}) {
kvMap := make(map[string]interface{}, len(keysAndValues)/2+2)
mapSize := len(keysAndValues) / 2
if math.MaxInt-mapSize >= 2 {
mapSize += 2
}
kvMap := make(map[string]interface{}, mapSize)

kvMap[KeyTimestamp] = time.Now().UnixNano()
kvMap[KeyMessage] = msg
Expand Down
15 changes: 14 additions & 1 deletion mongo/options/clientoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"math"
"net"
"net/http"
"strings"
Expand Down Expand Up @@ -1142,7 +1143,19 @@ func addClientCertFromSeparateFiles(cfg *tls.Config, keyFile, certFile, keyPassw
return "", err
}

data := make([]byte, 0, len(keyData)+len(certData)+1)
keySize := len(keyData)
if keySize > 64*1024*1024 {
return "", errors.New("X.509 key must be less than 64 MiB")
}
certSize := len(certData)
if certSize > 64*1024*1024 {
return "", errors.New("X.509 certificate must be less than 64 MiB")
}
dataSize := int64(keySize) + int64(certSize) + 1
if dataSize > math.MaxInt {
return "", errors.New("size overflow")
}
data := make([]byte, 0, int(dataSize))
data = append(data, keyData...)
data = append(data, '\n')
data = append(data, certData...)
Expand Down
4 changes: 4 additions & 0 deletions mongo/writeconcern/writeconcern.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ package writeconcern
import (
"errors"
"fmt"
"math"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
Expand Down Expand Up @@ -157,6 +158,9 @@ func (wc *WriteConcern) MarshalBSONValue() (bson.Type, []byte, error) {
return 0, nil, ErrInconsistent
}

if w > math.MaxInt32 {
return 0, nil, fmt.Errorf("%d overflows int32", w)
}
Comment on lines +161 to +163
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2.0 updates

elems = bsoncore.AppendInt32Element(elems, "w", int32(w))
case string:
elems = bsoncore.AppendStringElement(elems, "w", w)
Expand Down
40 changes: 18 additions & 22 deletions x/bsonx/bsoncore/bsoncore.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bsoncore

import (
"bytes"
"encoding/binary"
"fmt"
"math"
"strconv"
Expand Down Expand Up @@ -705,17 +706,16 @@ func ReserveLength(dst []byte) (int32, []byte) {

// UpdateLength updates the length at index with length and returns the []byte.
func UpdateLength(dst []byte, index, length int32) []byte {
dst[index] = byte(length)
dst[index+1] = byte(length >> 8)
dst[index+2] = byte(length >> 16)
dst[index+3] = byte(length >> 24)
binary.LittleEndian.PutUint32(dst[index:], uint32(length))
return dst
}

func appendLength(dst []byte, l int32) []byte { return appendi32(dst, l) }

func appendi32(dst []byte, i32 int32) []byte {
return append(dst, byte(i32), byte(i32>>8), byte(i32>>16), byte(i32>>24))
b := []byte{0, 0, 0, 0}
binary.LittleEndian.PutUint32(b, uint32(i32))
return append(dst, b...)
}

// ReadLength reads an int32 length from src and returns the length and the remaining bytes. If
Expand All @@ -733,51 +733,47 @@ func readi32(src []byte) (int32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}
return (int32(src[0]) | int32(src[1])<<8 | int32(src[2])<<16 | int32(src[3])<<24), src[4:], true
return int32(binary.LittleEndian.Uint32(src)), src[4:], true
}

func appendi64(dst []byte, i64 int64) []byte {
return append(dst,
byte(i64), byte(i64>>8), byte(i64>>16), byte(i64>>24),
byte(i64>>32), byte(i64>>40), byte(i64>>48), byte(i64>>56),
)
b := []byte{0, 0, 0, 0, 0, 0, 0, 0}
binary.LittleEndian.PutUint64(b, uint64(i64))
return append(dst, b...)
}

func readi64(src []byte) (int64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
i64 := (int64(src[0]) | int64(src[1])<<8 | int64(src[2])<<16 | int64(src[3])<<24 |
int64(src[4])<<32 | int64(src[5])<<40 | int64(src[6])<<48 | int64(src[7])<<56)
return i64, src[8:], true
return int64(binary.LittleEndian.Uint64(src)), src[8:], true
}

func appendu32(dst []byte, u32 uint32) []byte {
return append(dst, byte(u32), byte(u32>>8), byte(u32>>16), byte(u32>>24))
b := []byte{0, 0, 0, 0}
binary.LittleEndian.PutUint32(b, u32)
return append(dst, b...)
}

func readu32(src []byte) (uint32, []byte, bool) {
if len(src) < 4 {
return 0, src, false
}

return (uint32(src[0]) | uint32(src[1])<<8 | uint32(src[2])<<16 | uint32(src[3])<<24), src[4:], true
return binary.LittleEndian.Uint32(src), src[4:], true
}

func appendu64(dst []byte, u64 uint64) []byte {
return append(dst,
byte(u64), byte(u64>>8), byte(u64>>16), byte(u64>>24),
byte(u64>>32), byte(u64>>40), byte(u64>>48), byte(u64>>56),
)
b := []byte{0, 0, 0, 0, 0, 0, 0, 0}
binary.LittleEndian.PutUint64(b, u64)
return append(dst, b...)
}

func readu64(src []byte) (uint64, []byte, bool) {
if len(src) < 8 {
return 0, src, false
}
u64 := (uint64(src[0]) | uint64(src[1])<<8 | uint64(src[2])<<16 | uint64(src[3])<<24 |
uint64(src[4])<<32 | uint64(src[5])<<40 | uint64(src[6])<<48 | uint64(src[7])<<56)
return u64, src[8:], true
return binary.LittleEndian.Uint64(src), src[8:], true
}

// keep in sync with readcstringbytes
Expand Down
Loading