diff --git a/bson/bsoncodec/default_value_decoders.go b/bson/bsoncodec/default_value_decoders.go index b5e22c498a..e479c3585b 100644 --- a/bson/bsoncodec/default_value_decoders.go +++ b/bson/bsoncodec/default_value_decoders.go @@ -1540,12 +1540,12 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr return err } - fn := val.Convert(tValueUnmarshaler).MethodByName("UnmarshalBSONValue") - errVal := fn.Call([]reflect.Value{reflect.ValueOf(t), reflect.ValueOf(src)})[0] - if !errVal.IsNil() { - return errVal.Interface().(error) + m, ok := val.Interface().(ValueUnmarshaler) + if !ok { + // NB: this error should be unreachable due to the above checks + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } - return nil + return m.UnmarshalBSONValue(t, src) } // UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. @@ -1588,12 +1588,12 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonr val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. } - fn := val.Convert(tUnmarshaler).MethodByName("UnmarshalBSON") - errVal := fn.Call([]reflect.Value{reflect.ValueOf(src)})[0] - if !errVal.IsNil() { - return errVal.Interface().(error) + m, ok := val.Interface().(Unmarshaler) + if !ok { + // NB: this error should be unreachable due to the above checks + return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} } - return nil + return m.UnmarshalBSON(src) } // EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}. diff --git a/bson/bsoncodec/default_value_decoders_test.go b/bson/bsoncodec/default_value_decoders_test.go index d8c8389654..bac92e04f8 100644 --- a/bson/bsoncodec/default_value_decoders_test.go +++ b/bson/bsoncodec/default_value_decoders_test.go @@ -1530,13 +1530,22 @@ func TestDefaultValueDecoders(t *testing.T) { errors.New("copy error"), }, { - "Unmarshaler", + // Only the pointer form of testUnmarshaler implements Unmarshaler + "value does not implement Unmarshaler", testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)}, nil, &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)}, bsonrwtest.ReadDouble, nil, }, + { + "Unmarshaler", + &testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)}, + nil, + &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)}, + bsonrwtest.ReadDouble, + nil, + }, }, }, { diff --git a/bson/bsoncodec/default_value_encoders.go b/bson/bsoncodec/default_value_encoders.go index 7d526c4ef8..4ab14a668c 100644 --- a/bson/bsoncodec/default_value_encoders.go +++ b/bson/bsoncodec/default_value_encoders.go @@ -564,12 +564,14 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw bs return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val} } - fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue") - returns := fn.Call(nil) - if !returns[2].IsNil() { - return returns[2].Interface().(error) + m, ok := val.Interface().(ValueMarshaler) + if !ok { + return vw.WriteNull() + } + t, data, err := m.MarshalBSONValue() + if err != nil { + return err } - t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte) return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data) } @@ -593,12 +595,14 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw bsonrw. return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val} } - fn := val.Convert(tMarshaler).MethodByName("MarshalBSON") - returns := fn.Call(nil) - if !returns[1].IsNil() { - return returns[1].Interface().(error) + m, ok := val.Interface().(Marshaler) + if !ok { + return vw.WriteNull() + } + data, err := m.MarshalBSON() + if err != nil { + return err } - data := returns[0].Interface().([]byte) return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data) } @@ -622,23 +626,31 @@ func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.Val return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} } - fn := val.Convert(tProxy).MethodByName("ProxyBSON") - returns := fn.Call(nil) - if !returns[1].IsNil() { - return returns[1].Interface().(error) + m, ok := val.Interface().(Proxy) + if !ok { + return vw.WriteNull() + } + v, err := m.ProxyBSON() + if err != nil { + return err + } + if v == nil { + encoder, err := ec.LookupEncoder(nil) + if err != nil { + return err + } + return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil)) } - data := returns[0] - var encoder ValueEncoder - var err error - if data.Elem().IsValid() { - encoder, err = ec.LookupEncoder(data.Elem().Type()) - } else { - encoder, err = ec.LookupEncoder(nil) + vv := reflect.ValueOf(v) + switch vv.Kind() { + case reflect.Ptr, reflect.Interface: + vv = vv.Elem() } + encoder, err := ec.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, data.Elem()) + return encoder.EncodeValue(ec, vw, vv) } // JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type. diff --git a/bson/mgocompat/setter_getter.go b/bson/mgocompat/setter_getter.go index 55af549d40..e9c9cae834 100644 --- a/bson/mgocompat/setter_getter.go +++ b/bson/mgocompat/setter_getter.go @@ -7,6 +7,7 @@ package mgocompat import ( + "errors" "reflect" "go.mongodb.org/mongo-driver/bson" @@ -73,16 +74,15 @@ func SetterDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val ref return err } - fn := val.Convert(tSetter).MethodByName("SetBSON") - - errVal := fn.Call([]reflect.Value{reflect.ValueOf(bson.RawValue{Type: t, Value: src})})[0] - if !errVal.IsNil() { - err = errVal.Interface().(error) - if err == ErrSetZero { - val.Set(reflect.Zero(val.Type())) - return nil + m, ok := val.Interface().(Setter) + if !ok { + return bsoncodec.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + } + if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil { + if !errors.Is(err, ErrSetZero) { + return err } - return err + val.Set(reflect.Zero(val.Type())) } return nil } @@ -104,17 +104,23 @@ func GetterEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val re return bsoncodec.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} } - fn := val.Convert(tGetter).MethodByName("GetBSON") - returns := fn.Call(nil) - if !returns[1].IsNil() { - return returns[1].Interface().(error) + m, ok := val.Interface().(Getter) + if !ok { + return vw.WriteNull() + } + x, err := m.GetBSON() + if err != nil { + return err + } + if x == nil { + return vw.WriteNull() } - intermediate := returns[0] - encoder, err := ec.Registry.LookupEncoder(intermediate.Type()) + vv := reflect.ValueOf(x) + encoder, err := ec.Registry.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, intermediate) + return encoder.EncodeValue(ec, vw, vv) } // isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type