Skip to content

Commit

Permalink
Merge branch 'release/1.12' of github.com:mongodb/mongo-go-driver int…
Browse files Browse the repository at this point in the history
…o release/1.12
  • Loading branch information
prestonvasquez committed Aug 1, 2023
2 parents f14bd3a + d219098 commit 730e825
Show file tree
Hide file tree
Showing 33 changed files with 614 additions and 116 deletions.
20 changes: 10 additions & 10 deletions bson/bsoncodec/default_value_decoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -1540,12 +1540,12 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr
return err
}

fn := val.Convert(tValueUnmarshaler).MethodByName("UnmarshalBSONValue")
errVal := fn.Call([]reflect.Value{reflect.ValueOf(t), reflect.ValueOf(src)})[0]
if !errVal.IsNil() {
return errVal.Interface().(error)
m, ok := val.Interface().(ValueUnmarshaler)
if !ok {
// NB: this error should be unreachable due to the above checks
return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val}
}
return nil
return m.UnmarshalBSONValue(t, src)
}

// UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations.
Expand Down Expand Up @@ -1588,12 +1588,12 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonr
val = val.Addr() // If the type doesn't implement the interface, a pointer to it must.
}

fn := val.Convert(tUnmarshaler).MethodByName("UnmarshalBSON")
errVal := fn.Call([]reflect.Value{reflect.ValueOf(src)})[0]
if !errVal.IsNil() {
return errVal.Interface().(error)
m, ok := val.Interface().(Unmarshaler)
if !ok {
// NB: this error should be unreachable due to the above checks
return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val}
}
return nil
return m.UnmarshalBSON(src)
}

// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}.
Expand Down
11 changes: 10 additions & 1 deletion bson/bsoncodec/default_value_decoders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1530,13 +1530,22 @@ func TestDefaultValueDecoders(t *testing.T) {
errors.New("copy error"),
},
{
"Unmarshaler",
// Only the pointer form of testUnmarshaler implements Unmarshaler
"value does not implement Unmarshaler",
testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)},
nil,
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
bsonrwtest.ReadDouble,
nil,
},
{
"Unmarshaler",
&testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)},
nil,
&bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)},
bsonrwtest.ReadDouble,
nil,
},
},
},
{
Expand Down
56 changes: 34 additions & 22 deletions bson/bsoncodec/default_value_encoders.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,12 +564,14 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw bs
return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val}
}

fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue")
returns := fn.Call(nil)
if !returns[2].IsNil() {
return returns[2].Interface().(error)
m, ok := val.Interface().(ValueMarshaler)
if !ok {
return vw.WriteNull()
}
t, data, err := m.MarshalBSONValue()
if err != nil {
return err
}
t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data)
}

Expand All @@ -593,12 +595,14 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw bsonrw.
return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val}
}

fn := val.Convert(tMarshaler).MethodByName("MarshalBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
m, ok := val.Interface().(Marshaler)
if !ok {
return vw.WriteNull()
}
data, err := m.MarshalBSON()
if err != nil {
return err
}
data := returns[0].Interface().([]byte)
return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data)
}

Expand All @@ -622,23 +626,31 @@ func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.Val
return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val}
}

fn := val.Convert(tProxy).MethodByName("ProxyBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
m, ok := val.Interface().(Proxy)
if !ok {
return vw.WriteNull()
}
v, err := m.ProxyBSON()
if err != nil {
return err
}
if v == nil {
encoder, err := ec.LookupEncoder(nil)
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil))
}
data := returns[0]
var encoder ValueEncoder
var err error
if data.Elem().IsValid() {
encoder, err = ec.LookupEncoder(data.Elem().Type())
} else {
encoder, err = ec.LookupEncoder(nil)
vv := reflect.ValueOf(v)
switch vv.Kind() {
case reflect.Ptr, reflect.Interface:
vv = vv.Elem()
}
encoder, err := ec.LookupEncoder(vv.Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, data.Elem())
return encoder.EncodeValue(ec, vw, vv)
}

// JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type.
Expand Down
11 changes: 11 additions & 0 deletions bson/bsontype/bsontype.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,14 @@ func (bt Type) String() string {
return "invalid"
}
}

// IsValid will return true if the Type is valid.
func (bt Type) IsValid() bool {
switch bt {
case Double, String, EmbeddedDocument, Array, Binary, Undefined, ObjectID, Boolean, DateTime, Null, Regex,
DBPointer, JavaScript, Symbol, CodeWithScope, Int32, Timestamp, Int64, Decimal128, MinKey, MaxKey:
return true
default:
return false
}
}
38 changes: 22 additions & 16 deletions bson/mgocompat/setter_getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package mgocompat

import (
"errors"
"reflect"

"go.mongodb.org/mongo-driver/bson"
Expand Down Expand Up @@ -73,16 +74,15 @@ func SetterDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val ref
return err
}

fn := val.Convert(tSetter).MethodByName("SetBSON")

errVal := fn.Call([]reflect.Value{reflect.ValueOf(bson.RawValue{Type: t, Value: src})})[0]
if !errVal.IsNil() {
err = errVal.Interface().(error)
if err == ErrSetZero {
val.Set(reflect.Zero(val.Type()))
return nil
m, ok := val.Interface().(Setter)
if !ok {
return bsoncodec.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val}
}
if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil {
if !errors.Is(err, ErrSetZero) {
return err
}
return err
val.Set(reflect.Zero(val.Type()))
}
return nil
}
Expand All @@ -104,17 +104,23 @@ func GetterEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val re
return bsoncodec.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val}
}

fn := val.Convert(tGetter).MethodByName("GetBSON")
returns := fn.Call(nil)
if !returns[1].IsNil() {
return returns[1].Interface().(error)
m, ok := val.Interface().(Getter)
if !ok {
return vw.WriteNull()
}
x, err := m.GetBSON()
if err != nil {
return err
}
if x == nil {
return vw.WriteNull()
}
intermediate := returns[0]
encoder, err := ec.Registry.LookupEncoder(intermediate.Type())
vv := reflect.ValueOf(x)
encoder, err := ec.Registry.LookupEncoder(vv.Type())
if err != nil {
return err
}
return encoder.EncodeValue(ec, vw, intermediate)
return encoder.EncodeValue(ec, vw, vv)
}

// isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type
Expand Down
18 changes: 15 additions & 3 deletions bson/primitive_codecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package bson

import (
"errors"
"fmt"
"reflect"

"go.mongodb.org/mongo-driver/bson/bsoncodec"
Expand Down Expand Up @@ -45,15 +46,26 @@ func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder)

// RawValueEncodeValue is the ValueEncoderFunc for RawValue.
//
// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders
// registered.
// If the RawValue's Type is "invalid" and the RawValue's Value is not empty or
// nil, then this method will return an error.
//
// Deprecated: Use bson.NewRegistry to get a registry with all primitive
// encoders and decoders registered.
func (PrimitiveCodecs) RawValueEncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
if !val.IsValid() || val.Type() != tRawValue {
return bsoncodec.ValueEncoderError{Name: "RawValueEncodeValue", Types: []reflect.Type{tRawValue}, Received: val}
return bsoncodec.ValueEncoderError{
Name: "RawValueEncodeValue",
Types: []reflect.Type{tRawValue},
Received: val,
}
}

rawvalue := val.Interface().(RawValue)

if !rawvalue.Type.IsValid() {
return fmt.Errorf("the RawValue Type specifies an invalid BSON type: %#x", byte(rawvalue.Type))
}

return bsonrw.Copier{}.CopyValueFromBytes(vw, rawvalue.Type, rawvalue.Value)
}

Expand Down
38 changes: 38 additions & 0 deletions bson/primitive_codecs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ func compareErrors(err1, err2 error) bool {
}

func TestDefaultValueEncoders(t *testing.T) {
t.Parallel()

var pc PrimitiveCodecs

var wrong = func(string, string) string { return "wrong" }
Expand Down Expand Up @@ -107,6 +109,28 @@ func TestDefaultValueEncoders(t *testing.T) {
bsonrwtest.WriteDouble,
nil,
},
{
"RawValue Type is zero with non-zero value",
RawValue{
Type: 0x00,
Value: bsoncore.AppendDouble(nil, 3.14159),
},
nil,
nil,
bsonrwtest.Nothing,
fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x0"),
},
{
"RawValue Type is invalid",
RawValue{
Type: 0x8F,
Value: bsoncore.AppendDouble(nil, 3.14159),
},
nil,
nil,
bsonrwtest.Nothing,
fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x8f"),
},
},
},
{
Expand Down Expand Up @@ -166,9 +190,17 @@ func TestDefaultValueEncoders(t *testing.T) {
}

for _, tc := range testCases {
tc := tc // Capture the range variable

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

for _, subtest := range tc.subtests {
subtest := subtest // Capture the range variable

t.Run(subtest.name, func(t *testing.T) {
t.Parallel()

var ec bsoncodec.EncodeContext
if subtest.ectx != nil {
ec = *subtest.ectx
Expand All @@ -192,6 +224,8 @@ func TestDefaultValueEncoders(t *testing.T) {
}

t.Run("success path", func(t *testing.T) {
t.Parallel()

oid := primitive.NewObjectID()
oids := []primitive.ObjectID{primitive.NewObjectID(), primitive.NewObjectID(), primitive.NewObjectID()}
var str = new(string)
Expand Down Expand Up @@ -426,7 +460,11 @@ func TestDefaultValueEncoders(t *testing.T) {
}

for _, tc := range testCases {
tc := tc // Capture the range variable

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

b := make(bsonrw.SliceWriter, 0, 512)
vw, err := bsonrw.NewBSONValueWriter(&b)
noerr(t, err)
Expand Down
6 changes: 6 additions & 0 deletions bson/raw_value.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ type RawValue struct {
r *bsoncodec.Registry
}

// IsZero reports whether the RawValue is zero, i.e. no data is present on
// the RawValue. It returns true if Type is 0 and Value is empty or nil.
func (rv RawValue) IsZero() bool {
return rv.Type == 0x00 && len(rv.Value) == 0
}

// Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val, an
// error is returned. This method will use the registry used to create the RawValue, if the RawValue
// was created from partial BSON processing, or it will use the default registry. Users wishing to
Expand Down
Loading

0 comments on commit 730e825

Please sign in to comment.