From 41ebbc3ff73d64383d566cb9be77048c52c156f5 Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Wed, 21 Jun 2023 14:50:19 -0700 Subject: [PATCH] GODRIVER-2725 Allow setting Encoder and Decoder options on a Client. (#1282) Co-authored-by: Preston Vasquez --- bson/bsoncodec/bsoncodec.go | 8 +- bson/decoder.go | 11 +- bson/encoder.go | 4 + bson/unmarshal_test.go | 45 +- internal/assert/assertion_mongo.go | 90 ++ internal/assert/assertion_mongo_test.go | 125 +++ mongo/bulk_write.go | 109 ++- mongo/change_stream.go | 19 +- mongo/client.go | 14 +- mongo/client_encryption.go | 18 +- mongo/client_examples_test.go | 45 + mongo/client_test.go | 2 +- mongo/collection.go | 202 ++-- mongo/collection_test.go | 2 +- mongo/cursor.go | 80 +- mongo/cursor_test.go | 37 +- mongo/database.go | 42 +- mongo/database_test.go | 2 +- mongo/gridfs/bucket.go | 2 + mongo/index_view.go | 18 +- mongo/integration/change_stream_test.go | 39 + mongo/integration/client_test.go | 121 +++ mongo/mongo.go | 231 +++-- mongo/mongo_test.go | 977 +++++++++++--------- mongo/options/clientoptions.go | 92 ++ mongo/options/collectionoptions.go | 10 + mongo/options/dboptions.go | 10 + mongo/options/mongooptions.go | 4 + mongo/single_result.go | 21 +- mongo/single_result_test.go | 29 +- x/mongo/driver/operation/find_and_modify.go | 4 +- 31 files changed, 1716 insertions(+), 697 deletions(-) create mode 100644 internal/assert/assertion_mongo.go create mode 100644 internal/assert/assertion_mongo_test.go diff --git a/bson/bsoncodec/bsoncodec.go b/bson/bsoncodec/bsoncodec.go index 3f30af94e4..0693bd432f 100644 --- a/bson/bsoncodec/bsoncodec.go +++ b/bson/bsoncodec/bsoncodec.go @@ -269,16 +269,16 @@ func (dc *DecodeContext) ZeroStructs() { dc.zeroStructs = true } -// DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as -// "interface{}" or "map[string]interface{}". +// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". // // Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentM] instead. func (dc *DecodeContext) DefaultDocumentM() { dc.defaultDocumentType = reflect.TypeOf(primitive.M{}) } -// DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as -// "interface{}" or "map[string]interface{}". +// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". // // Deprecated: Use [go.mongodb.org/mongo-driver/bson.Decoder.DefaultDocumentD] instead. func (dc *DecodeContext) DefaultDocumentD() { diff --git a/bson/decoder.go b/bson/decoder.go index 10043fc8f9..eac74cd399 100644 --- a/bson/decoder.go +++ b/bson/decoder.go @@ -137,12 +137,14 @@ func (d *Decoder) Decode(val interface{}) error { // Reset will reset the state of the decoder, using the same *DecodeContext used in // the original construction but using vr for reading. func (d *Decoder) Reset(vr bsonrw.ValueReader) error { + // TODO:(GODRIVER-2719): Remove error return value. d.vr = vr return nil } // SetRegistry replaces the current registry of the decoder with r. func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error { + // TODO:(GODRIVER-2719): Remove error return value. d.dc.Registry = r return nil } @@ -151,18 +153,19 @@ func (d *Decoder) SetRegistry(r *bsoncodec.Registry) error { // // Deprecated: Use the Decoder configuration methods to set the desired unmarshal behavior instead. func (d *Decoder) SetContext(dc bsoncodec.DecodeContext) error { + // TODO:(GODRIVER-2719): Remove error return value. d.dc = dc return nil } -// DefaultDocumentM will decode empty documents using the primitive.M type. This behavior is restricted to data typed as -// "interface{}" or "map[string]interface{}". +// DefaultDocumentM causes the Decoder to always unmarshal documents into the primitive.M type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentM() { d.defaultDocumentM = true } -// DefaultDocumentD will decode empty documents using the primitive.D type. This behavior is restricted to data typed as -// "interface{}" or "map[string]interface{}". +// DefaultDocumentD causes the Decoder to always unmarshal documents into the primitive.D type. This +// behavior is restricted to data typed as "interface{}" or "map[string]interface{}". func (d *Decoder) DefaultDocumentD() { d.defaultDocumentD = true } diff --git a/bson/encoder.go b/bson/encoder.go index 93bf014663..0be2a97fbc 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -42,6 +42,7 @@ type Encoder struct { // NewEncoder returns a new encoder that uses the DefaultRegistry to write to vw. func NewEncoder(vw bsonrw.ValueWriter) (*Encoder, error) { + // TODO:(GODRIVER-2719): Remove error return value. if vw == nil { return nil, errors.New("cannot create a new Encoder with a nil ValueWriter") } @@ -121,12 +122,14 @@ func (e *Encoder) Encode(val interface{}) error { // Reset will reset the state of the Encoder, using the same *EncodeContext used in // the original construction but using vw. func (e *Encoder) Reset(vw bsonrw.ValueWriter) error { + // TODO:(GODRIVER-2719): Remove error return value. e.vw = vw return nil } // SetRegistry replaces the current registry of the Encoder with r. func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error { + // TODO:(GODRIVER-2719): Remove error return value. e.ec.Registry = r return nil } @@ -135,6 +138,7 @@ func (e *Encoder) SetRegistry(r *bsoncodec.Registry) error { // // Deprecated: Use the Encoder configuration methods set the desired marshal behavior instead. func (e *Encoder) SetContext(ec bsoncodec.EncodeContext) error { + // TODO:(GODRIVER-2719): Remove error return value. e.ec = ec return nil } diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 84fa26cc24..11452a895c 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -10,7 +10,6 @@ import ( "math/rand" "reflect" "testing" - "unsafe" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsonrw" @@ -770,49 +769,7 @@ func TestUnmarshalByteSlicesUseDistinctArrays(t *testing.T) { // Assert that the byte slice in the unmarshaled value does not share any memory // addresses with the input byte slice. - assertDifferentArrays(t, data, tc.getByteSlice(got)) + assert.DifferentAddressRanges(t, data, tc.getByteSlice(got)) }) } } - -// assertDifferentArrays asserts that two byte slices reference distinct memory ranges, meaning -// they reference different underlying byte arrays. -func assertDifferentArrays(t *testing.T, a, b []byte) { - // Find the start and end memory addresses for the underlying byte array for each input byte - // slice. - sliceAddrRange := func(b []byte) (uintptr, uintptr) { - sh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) - return sh.Data, sh.Data + uintptr(sh.Cap-1) - } - aStart, aEnd := sliceAddrRange(a) - bStart, bEnd := sliceAddrRange(b) - - // If "b" starts after "a" ends or "a" starts after "b" ends, there is no overlap. - if bStart > aEnd || aStart > bEnd { - return - } - - // Otherwise, calculate the overlap start and end and print the memory overlap error message. - min := func(a, b uintptr) uintptr { - if a < b { - return a - } - return b - } - max := func(a, b uintptr) uintptr { - if a > b { - return a - } - return b - } - overlapLow := max(aStart, bStart) - overlapHigh := min(aEnd, bEnd) - - t.Errorf("Byte slices point to the same the same underlying byte array:\n"+ - "\ta addresses:\t%d ... %d\n"+ - "\tb addresses:\t%d ... %d\n"+ - "\toverlap:\t%d ... %d", - aStart, aEnd, - bStart, bEnd, - overlapLow, overlapHigh) -} diff --git a/internal/assert/assertion_mongo.go b/internal/assert/assertion_mongo.go new file mode 100644 index 0000000000..45bcfe47f8 --- /dev/null +++ b/internal/assert/assertion_mongo.go @@ -0,0 +1,90 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +// assertion_mongo.go contains MongoDB-specific extensions to the "assert" +// package. + +package assert + +import ( + "fmt" + "reflect" + "unsafe" +) + +// DifferentAddressRanges asserts that two byte slices reference distinct memory +// address ranges, meaning they reference different underlying byte arrays. +func DifferentAddressRanges(t TestingT, a, b []byte) (ok bool) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + if len(a) == 0 || len(b) == 0 { + return true + } + + // Find the start and end memory addresses for the underlying byte array for + // each input byte slice. + sliceAddrRange := func(b []byte) (uintptr, uintptr) { + sh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + return sh.Data, sh.Data + uintptr(sh.Cap-1) + } + aStart, aEnd := sliceAddrRange(a) + bStart, bEnd := sliceAddrRange(b) + + // If "b" starts after "a" ends or "a" starts after "b" ends, there is no + // overlap. + if bStart > aEnd || aStart > bEnd { + return true + } + + // Otherwise, calculate the overlap start and end and print the memory + // overlap error message. + min := func(a, b uintptr) uintptr { + if a < b { + return a + } + return b + } + max := func(a, b uintptr) uintptr { + if a > b { + return a + } + return b + } + overlapLow := max(aStart, bStart) + overlapHigh := min(aEnd, bEnd) + + t.Errorf("Byte slices point to the same underlying byte array:\n"+ + "\ta addresses:\t%d ... %d\n"+ + "\tb addresses:\t%d ... %d\n"+ + "\toverlap:\t%d ... %d", + aStart, aEnd, + bStart, bEnd, + overlapLow, overlapHigh) + + return false +} + +// EqualBSON asserts that the expected and actual BSON binary values are equal. +// If the values are not equal, it prints both the binary and Extended JSON diff +// of the BSON values. The provided BSON value types must implement the +// fmt.Stringer interface. +func EqualBSON(t TestingT, expected, actual interface{}) bool { + if h, ok := t.(tHelper); ok { + h.Helper() + } + + return Equal(t, + expected, + actual, + `expected and actual BSON values do not match +As Extended JSON: +Expected: %s +Actual : %s`, + expected.(fmt.Stringer).String(), + actual.(fmt.Stringer).String()) +} diff --git a/internal/assert/assertion_mongo_test.go b/internal/assert/assertion_mongo_test.go new file mode 100644 index 0000000000..3f16af10ec --- /dev/null +++ b/internal/assert/assertion_mongo_test.go @@ -0,0 +1,125 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package assert + +import ( + "testing" + + "go.mongodb.org/mongo-driver/bson" +) + +func TestDifferentAddressRanges(t *testing.T) { + t.Parallel() + + slice := []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + testCases := []struct { + name string + a []byte + b []byte + want bool + }{ + { + name: "distinct byte slices", + a: []byte{0, 1, 2, 3}, + b: []byte{0, 1, 2, 3}, + want: true, + }, + { + name: "same byte slice", + a: slice, + b: slice, + want: false, + }, + { + name: "whole and subslice", + a: slice, + b: slice[:4], + want: false, + }, + { + name: "two subslices", + a: slice[1:2], + b: slice[3:4], + want: false, + }, + { + name: "empty", + a: []byte{0, 1, 2, 3}, + b: []byte{}, + want: true, + }, + { + name: "nil", + a: []byte{0, 1, 2, 3}, + b: nil, + want: true, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := DifferentAddressRanges(new(testing.T), tc.a, tc.b) + if got != tc.want { + t.Errorf("DifferentAddressRanges(%p, %p) = %v, want %v", tc.a, tc.b, got, tc.want) + } + }) + } +} + +func TestEqualBSON(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + expected interface{} + actual interface{} + want bool + }{ + { + name: "equal bson.Raw", + expected: bson.Raw{5, 0, 0, 0, 0}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: true, + }, + { + name: "different bson.Raw", + expected: bson.Raw{8, 0, 0, 0, 10, 120, 0, 0}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: false, + }, + { + name: "invalid bson.Raw", + expected: bson.Raw{99, 99, 99, 99}, + actual: bson.Raw{5, 0, 0, 0, 0}, + want: false, + }, + { + name: "nil bson.Raw", + expected: bson.Raw(nil), + actual: bson.Raw(nil), + want: true, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := EqualBSON(new(testing.T), tc.expected, tc.actual) + if got != tc.want { + t.Errorf("EqualBSON(%#v, %#v) = %v, want %v", tc.expected, tc.actual, got, tc.want) + } + }) + } +} diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index 66a3b3f54e..58e64f1d9a 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -10,6 +10,7 @@ import ( "context" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/writeconcern" @@ -165,7 +166,11 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera var i int for _, model := range batch.models { converted := model.(*InsertOneModel) - doc, _, err := transformAndEnsureID(bw.collection.registry, converted.Document) + doc, err := marshal(converted.Document, bw.collection.bsonOpts, bw.collection.registry) + if err != nil { + return operation.InsertResult{}, err + } + doc, _, err = ensureID(doc, primitive.NewObjectID(), bw.collection.bsonOpts, bw.collection.registry) if err != nil { return operation.InsertResult{}, err } @@ -182,7 +187,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). Logger(bw.collection.client.logger) if bw.comment != nil { - comment, err := transformValue(bw.collection.registry, bw.comment, true, "comment") + comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return op.Result(), err } @@ -217,10 +222,22 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera switch converted := model.(type) { case *DeleteOneModel: - doc, err = createDeleteDoc(converted.Filter, converted.Collation, converted.Hint, true, bw.collection.registry) + doc, err = createDeleteDoc( + converted.Filter, + converted.Collation, + converted.Hint, + true, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) case *DeleteManyModel: - doc, err = createDeleteDoc(converted.Filter, converted.Collation, converted.Hint, false, bw.collection.registry) + doc, err = createDeleteDoc( + converted.Filter, + converted.Collation, + converted.Hint, + false, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) } @@ -240,14 +257,14 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). Logger(bw.collection.client.logger) if bw.comment != nil { - comment, err := transformValue(bw.collection.registry, bw.comment, true, "comment") + comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return op.Result(), err } op.Comment(comment) } if bw.let != nil { - let, err := transformBsoncoreDocument(bw.collection.registry, bw.let, true, "let") + let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return operation.DeleteResult{}, err } @@ -267,10 +284,15 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera return op.Result(), err } -func createDeleteDoc(filter interface{}, collation *options.Collation, hint interface{}, deleteOne bool, - registry *bsoncodec.Registry) (bsoncore.Document, error) { - - f, err := transformBsoncoreDocument(registry, filter, true, "filter") +func createDeleteDoc( + filter interface{}, + collation *options.Collation, + hint interface{}, + deleteOne bool, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (bsoncore.Document, error) { + f, err := marshal(filter, bsonOpts, registry) if err != nil { return nil, err } @@ -286,7 +308,10 @@ func createDeleteDoc(filter interface{}, collation *options.Collation, hint inte doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument()) } if hint != nil { - hintVal, err := transformValue(registry, hint, false, "hint") + if isUnorderedMap(hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(hint, bsonOpts, registry) if err != nil { return nil, err } @@ -307,17 +332,44 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera switch converted := model.(type) { case *ReplaceOneModel: - doc, err = createUpdateDoc(converted.Filter, converted.Replacement, converted.Hint, nil, converted.Collation, converted.Upsert, false, - false, bw.collection.registry) + doc, err = createUpdateDoc( + converted.Filter, + converted.Replacement, + converted.Hint, + nil, + converted.Collation, + converted.Upsert, + false, + false, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) case *UpdateOneModel: - doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, false, - true, bw.collection.registry) + doc, err = createUpdateDoc( + converted.Filter, + converted.Update, + converted.Hint, + converted.ArrayFilters, + converted.Collation, + converted.Upsert, + false, + true, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil) case *UpdateManyModel: - doc, err = createUpdateDoc(converted.Filter, converted.Update, converted.Hint, converted.ArrayFilters, converted.Collation, converted.Upsert, true, - true, bw.collection.registry) + doc, err = createUpdateDoc( + converted.Filter, + converted.Update, + converted.Hint, + converted.ArrayFilters, + converted.Collation, + converted.Upsert, + true, + true, + bw.collection.bsonOpts, + bw.collection.registry) hasHint = hasHint || (converted.Hint != nil) hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil) } @@ -336,14 +388,14 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI). Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger) if bw.comment != nil { - comment, err := transformValue(bw.collection.registry, bw.comment, true, "comment") + comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return op.Result(), err } op.Comment(comment) } if bw.let != nil { - let, err := transformBsoncoreDocument(bw.collection.registry, bw.let, true, "let") + let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry) if err != nil { return operation.UpdateResult{}, err } @@ -365,6 +417,7 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera return op.Result(), err } + func createUpdateDoc( filter interface{}, update interface{}, @@ -374,9 +427,10 @@ func createUpdateDoc( upsert *bool, multi bool, checkDollarKey bool, + bsonOpts *options.BSONOptions, registry *bsoncodec.Registry, ) (bsoncore.Document, error) { - f, err := transformBsoncoreDocument(registry, filter, true, "filter") + f, err := marshal(filter, bsonOpts, registry) if err != nil { return nil, err } @@ -384,7 +438,7 @@ func createUpdateDoc( uidx, updateDoc := bsoncore.AppendDocumentStart(nil) updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f) - u, err := transformUpdateValue(registry, update, checkDollarKey) + u, err := marshalUpdateValue(update, bsonOpts, registry, checkDollarKey) if err != nil { return nil, err } @@ -396,11 +450,15 @@ func createUpdateDoc( } if arrayFilters != nil { - arr, err := arrayFilters.ToArrayDocument() + reg := registry + if arrayFilters.Registry != nil { + reg = arrayFilters.Registry + } + arr, err := marshalValue(arrayFilters.Filters, bsonOpts, reg) if err != nil { return nil, err } - updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr) + updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr.Data) } if collation != nil { @@ -412,7 +470,10 @@ func createUpdateDoc( } if hint != nil { - hintVal, err := transformValue(registry, hint, false, "hint") + if isUnorderedMap(hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(hint, bsonOpts, registry) if err != nil { return nil, err } diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 6857e1e3cd..76fe86f000 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -80,6 +80,7 @@ type ChangeStream struct { err error sess *session.Client client *Client + bsonOpts *options.BSONOptions registry *bsoncodec.Registry streamType StreamType options *options.ChangeStreamOptions @@ -92,6 +93,7 @@ type changeStreamConfig struct { readConcern *readconcern.ReadConcern readPreference *readpref.ReadPref client *Client + bsonOpts *options.BSONOptions registry *bsoncodec.Registry streamType StreamType collectionName string @@ -107,6 +109,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in cs := &ChangeStream{ client: config.client, + bsonOpts: config.bsonOpts, registry: config.registry, streamType: config.streamType, options: options.MergeChangeStreamOptions(opts...), @@ -138,7 +141,7 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in if comment := cs.options.Comment; comment != nil { cs.aggregate.Comment(*comment) - commentVal, err := transformValue(cs.registry, comment, true, "comment") + commentVal, err := marshalValue(comment, cs.bsonOpts, cs.registry) if err != nil { return nil, err } @@ -389,7 +392,7 @@ func (cs *ChangeStream) storeResumeToken() error { func (cs *ChangeStream) buildPipelineSlice(pipeline interface{}) error { val := reflect.ValueOf(pipeline) if !val.IsValid() || !(val.Kind() == reflect.Slice) { - cs.err = errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid") + cs.err = errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") return cs.err } @@ -410,7 +413,7 @@ func (cs *ChangeStream) buildPipelineSlice(pipeline interface{}) error { for i := 0; i < val.Len(); i++ { var elem []byte - elem, cs.err = transformBsoncoreDocument(cs.registry, val.Index(i).Interface(), true, fmt.Sprintf("pipeline stage :%v", i)) + elem, cs.err = marshal(val.Index(i).Interface(), cs.bsonOpts, cs.registry) if cs.err != nil { return cs.err } @@ -438,7 +441,7 @@ func (cs *ChangeStream) createPipelineOptionsDoc() (bsoncore.Document, error) { if cs.options.ResumeAfter != nil { var raDoc bsoncore.Document - raDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.ResumeAfter, true, "resumeAfter") + raDoc, cs.err = marshal(cs.options.ResumeAfter, cs.bsonOpts, cs.registry) if cs.err != nil { return nil, cs.err } @@ -452,7 +455,7 @@ func (cs *ChangeStream) createPipelineOptionsDoc() (bsoncore.Document, error) { if cs.options.StartAfter != nil { var saDoc bsoncore.Document - saDoc, cs.err = transformBsoncoreDocument(cs.registry, cs.options.StartAfter, true, "startAfter") + saDoc, cs.err = marshal(cs.options.StartAfter, cs.bsonOpts, cs.registry) if cs.err != nil { return nil, cs.err } @@ -531,7 +534,11 @@ func (cs *ChangeStream) Decode(val interface{}) error { return ErrNilCursor } - return bson.UnmarshalWithRegistry(cs.registry, cs.Current, val) + dec, err := getDecoder(cs.Current, cs.bsonOpts, cs.registry) + if err != nil { + return fmt.Errorf("error configuring BSON decoder: %w", err) + } + return dec.Decode(val) } // Err returns the last error seen by the change stream, or nil if no errors has occurred. diff --git a/mongo/client.go b/mongo/client.go index 8bd36640eb..588d741fa2 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -61,6 +61,7 @@ type Client struct { readPreference *readpref.ReadPref readConcern *readconcern.ReadConcern writeConcern *writeconcern.WriteConcern + bsonOpts *options.BSONOptions registry *bsoncodec.Registry monitor *event.CommandMonitor serverAPI *driver.ServerAPIOptions @@ -164,6 +165,10 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { if clientOpt.ReadPreference != nil { client.readPreference = clientOpt.ReadPreference } + // BSONOptions + if clientOpt.BSONOptions != nil { + client.bsonOpts = clientOpt.BSONOptions + } // Registry client.registry = bson.DefaultRegistry if clientOpt.Registry != nil { @@ -531,7 +536,7 @@ func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt // convert schemas in SchemaMap to bsoncore documents cryptSchemaMap := make(map[string]bsoncore.Document) for k, v := range opts.SchemaMap { - schema, err := transformBsoncoreDocument(c.registry, v, true, "schemaMap") + schema, err := marshal(v, c.bsonOpts, c.registry) if err != nil { return nil, err } @@ -541,14 +546,14 @@ func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt // convert schemas in EncryptedFieldsMap to bsoncore documents cryptEncryptedFieldsMap := make(map[string]bsoncore.Document) for k, v := range opts.EncryptedFieldsMap { - encryptedFields, err := transformBsoncoreDocument(c.registry, v, true, "encryptedFieldsMap") + encryptedFields, err := marshal(v, c.bsonOpts, c.registry) if err != nil { return nil, err } cryptEncryptedFieldsMap[k] = encryptedFields } - kmsProviders, err := transformBsoncoreDocument(c.registry, opts.KmsProviders, true, "kmsProviders") + kmsProviders, err := marshal(opts.KmsProviders, c.bsonOpts, c.registry) if err != nil { return nil, fmt.Errorf("error creating KMS providers document: %v", err) } @@ -674,7 +679,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... return ListDatabasesResult{}, err } - filterDoc, err := transformBsoncoreDocument(c.registry, filter, true, "filter") + filterDoc, err := marshal(filter, c.bsonOpts, c.registry) if err != nil { return ListDatabasesResult{}, err } @@ -805,6 +810,7 @@ func (c *Client) Watch(ctx context.Context, pipeline interface{}, readConcern: c.readConcern, readPreference: c.readPreference, client: c, + bsonOpts: c.bsonOpts, registry: c.registry, streamType: ClientStream, crypt: c.cryptFLE, diff --git a/mongo/client_encryption.go b/mongo/client_encryption.go index 91c259c312..01c2ec3193 100644 --- a/mongo/client_encryption.go +++ b/mongo/client_encryption.go @@ -44,7 +44,7 @@ func NewClientEncryption(keyVaultClient *Client, opts ...*options.ClientEncrypti db, coll := splitNamespace(ceo.KeyVaultNamespace) ce.keyVaultColl = ce.keyVaultClient.Database(db).Collection(coll, keyVaultCollOpts) - kmsProviders, err := transformBsoncoreDocument(bson.DefaultRegistry, ceo.KmsProviders, true, "kmsProviders") + kmsProviders, err := marshal(ceo.KmsProviders, nil, nil) if err != nil { return nil, fmt.Errorf("error creating KMS providers map: %v", err) } @@ -86,7 +86,7 @@ func (ce *ClientEncryption) CreateEncryptedCollection(ctx context.Context, return nil, nil, errors.New("no EncryptedFields defined for the collection") } - efBSON, err := transformBsoncoreDocument(db.registry, ef, true, "encryptedFields") + efBSON, err := marshal(ef, db.bsonOpts, db.registry) if err != nil { return nil, nil, err } @@ -147,7 +147,10 @@ func (ce *ClientEncryption) CreateDataKey(ctx context.Context, kmsProvider strin dko := options.MergeDataKeyOptions(opts...) co := mcopts.DataKey().SetKeyAltNames(dko.KeyAltNames) if dko.MasterKey != nil { - keyDoc, err := transformBsoncoreDocument(ce.keyVaultClient.registry, dko.MasterKey, true, "masterKey") + keyDoc, err := marshal( + dko.MasterKey, + ce.keyVaultClient.bsonOpts, + ce.keyVaultClient.registry) if err != nil { return primitive.Binary{}, err } @@ -232,7 +235,7 @@ func (ce *ClientEncryption) Encrypt(ctx context.Context, val bson.RawValue, func (ce *ClientEncryption) EncryptExpression(ctx context.Context, expr interface{}, result interface{}, opts ...*options.EncryptOptions) error { transformed := transformExplicitEncryptionOptions(opts...) - exprDoc, err := transformBsoncoreDocument(bson.DefaultRegistry, expr, true, "expr") + exprDoc, err := marshal(expr, nil, nil) if err != nil { return err } @@ -380,7 +383,10 @@ func (ce *ClientEncryption) RewrapManyDataKey(ctx context.Context, filter interf // Transfer rmdko options to /x/ package options to publish the mongocrypt feed. co := mcopts.RewrapManyDataKey() if rmdko.MasterKey != nil { - keyDoc, err := transformBsoncoreDocument(ce.keyVaultClient.registry, rmdko.MasterKey, true, "masterKey") + keyDoc, err := marshal( + rmdko.MasterKey, + ce.keyVaultClient.bsonOpts, + ce.keyVaultClient.registry) if err != nil { return nil, err } @@ -391,7 +397,7 @@ func (ce *ClientEncryption) RewrapManyDataKey(ctx context.Context, filter interf } // Prepare the filters and rewrap the data key using mongocrypt. - filterdoc, err := transformBsoncoreDocument(ce.keyVaultClient.registry, filter, true, "filter") + filterdoc, err := marshal(filter, ce.keyVaultClient.bsonOpts, ce.keyVaultClient.registry) if err != nil { return nil, err } diff --git a/mongo/client_examples_test.go b/mongo/client_examples_test.go index d5db825984..4123e8e0e4 100644 --- a/mongo/client_examples_test.go +++ b/mongo/client_examples_test.go @@ -423,3 +423,48 @@ func ExampleConnect_stableAPI() { } _ = serverAPIDeprecationClient } + +func ExampleConnect_bSONOptions() { + // Configure a client that customizes the BSON marshal and unmarshal + // behavior. + + // Specify BSON options that cause the driver to fallback to "json" + // struct tags if "bson" struct tags are missing, marshal nil Go maps as + // empty BSON documents, and marshals nil Go slices as empty BSON + // arrays. + bsonOpts := &options.BSONOptions{ + UseJSONStructTags: true, + NilMapAsEmpty: true, + NilSliceAsEmpty: true, + } + + clientOpts := options.Client(). + ApplyURI("mongodb://localhost:27017"). + SetBSONOptions(bsonOpts) + + client, err := mongo.Connect(context.TODO(), clientOpts) + if err != nil { + panic(err) + } + defer func() { + if err := client.Disconnect(context.TODO()); err != nil { + panic(err) + } + }() + + coll := client.Database("db").Collection("coll") + + // Define a struct that contains a map and a slice and uses "json" struct + // tags to specify field names. + type myDocument struct { + MyMap map[string]interface{} `json:"a"` + MySlice []string `json:"b"` + } + + // Insert an instance of the struct with all empty fields. Expect the + // resulting BSON document to have a structure like {"a": {}, "b": []} + _, err = coll.InsertOne(context.TODO(), myDocument{}) + if err != nil { + panic(err) + } +} diff --git a/mongo/client_test.go b/mongo/client_test.go index bbfa13efef..c279de7939 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -75,7 +75,7 @@ func TestClient(t *testing.T) { client.sessionPool = &session.Pool{} _, err := client.Watch(bgCtx, nil) - watchErr := errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid") + watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") assert.Equal(t, watchErr, err, "expected error %v, got %v", watchErr, err) _, err = client.ListDatabases(bgCtx, nil) diff --git a/mongo/collection.go b/mongo/collection.go index e26bca31be..1e696ded96 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -10,12 +10,14 @@ import ( "context" "errors" "fmt" + "reflect" "strings" "time" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal" "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" @@ -38,6 +40,7 @@ type Collection struct { readPreference *readpref.ReadPref readSelector description.ServerSelector writeSelector description.ServerSelector + bsonOpts *options.BSONOptions registry *bsoncodec.Registry } @@ -46,6 +49,7 @@ type aggregateParams struct { ctx context.Context pipeline interface{} client *Client + bsonOpts *options.BSONOptions registry *bsoncodec.Registry readConcern *readconcern.ReadConcern writeConcern *writeconcern.WriteConcern @@ -82,6 +86,11 @@ func newCollection(db *Database, name string, opts ...*options.CollectionOptions rp = collOpt.ReadPreference } + bsonOpts := db.bsonOpts + if collOpt.BSONOptions != nil { + bsonOpts = collOpt.BSONOptions + } + reg := db.registry if collOpt.Registry != nil { reg = collOpt.Registry @@ -106,6 +115,7 @@ func newCollection(db *Database, name string, opts ...*options.CollectionOptions writeConcern: wc, readSelector: readSelector, writeSelector: writeSelector, + bsonOpts: bsonOpts, registry: reg, } @@ -242,11 +252,17 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, docs := make([]bsoncore.Document, len(documents)) for i, doc := range documents { - var err error - docs[i], result[i], err = transformAndEnsureID(coll.registry, doc) + bsoncoreDoc, err := marshal(doc, coll.bsonOpts, coll.registry) + if err != nil { + return nil, err + } + bsoncoreDoc, id, err := ensureID(bsoncoreDoc, primitive.NewObjectID(), coll.bsonOpts, coll.registry) if err != nil { return nil, err } + + docs[i] = bsoncoreDoc + result[i] = id } sess := sessionFromContext(ctx) @@ -281,7 +297,7 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, op = op.BypassDocumentValidation(*imo.BypassDocumentValidation) } if imo.Comment != nil { - comment, err := transformValue(coll.registry, imo.Comment, true, "comment") + comment, err := marshalValue(imo.Comment, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -400,7 +416,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -438,7 +454,10 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn doc = bsoncore.AppendDocumentElement(doc, "collation", do.Collation.ToDocument()) } if do.Hint != nil { - hint, err := transformValue(coll.registry, do.Hint, false, "hint") + if isUnorderedMap(do.Hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hint, err := marshalValue(do.Hint, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -454,7 +473,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) if do.Comment != nil { - comment, err := transformValue(coll.registry, do.Comment, true, "comment") + comment, err := marshalValue(do.Comment, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -464,7 +483,7 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn op = op.Hint(true) } if do.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, do.Let, true, "let") + let, err := marshal(do.Let, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -527,8 +546,17 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc // collation, arrayFilters, upsert, and hint are included on the individual update documents rather than as part of the // command - updateDoc, err := createUpdateDoc(filter, update, uo.Hint, uo.ArrayFilters, uo.Collation, uo.Upsert, multi, - checkDollarKey, coll.registry) + updateDoc, err := createUpdateDoc( + filter, + update, + uo.Hint, + uo.ArrayFilters, + uo.Collation, + uo.Upsert, + multi, + checkDollarKey, + coll.bsonOpts, + coll.registry) if err != nil { return nil, err } @@ -562,7 +590,7 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc ArrayFilters(uo.ArrayFilters != nil).Ordered(true).ServerAPI(coll.client.serverAPI). Timeout(coll.client.timeout).Logger(coll.client.logger) if uo.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, uo.Let, true, "let") + let, err := marshal(uo.Let, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -573,7 +601,7 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc op = op.BypassDocumentValidation(*uo.BypassDocumentValidation) } if uo.Comment != nil { - comment, err := transformValue(coll.registry, uo.Comment, true, "comment") + comment, err := marshalValue(uo.Comment, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -648,7 +676,7 @@ func (coll *Collection) UpdateOne(ctx context.Context, filter interface{}, updat ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -676,7 +704,7 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -704,12 +732,12 @@ func (coll *Collection) ReplaceOne(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } - r, err := transformBsoncoreDocument(coll.registry, replacement, true, "replacement") + r, err := marshal(replacement, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -756,6 +784,7 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{}, registry: coll.registry, readConcern: coll.readConcern, writeConcern: coll.writeConcern, + bsonOpts: coll.bsonOpts, retryRead: coll.client.retryReads, db: coll.db.name, col: coll.name, @@ -773,7 +802,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { a.ctx = context.Background() } - pipelineArr, hasOutputStage, err := transformAggregatePipeline(a.registry, a.pipeline) + pipelineArr, hasOutputStage, err := marshalAggregatePipeline(a.pipeline, a.bsonOpts, a.registry) if err != nil { return nil, err } @@ -851,21 +880,24 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { if ao.Comment != nil { op.Comment(*ao.Comment) - commentVal, err := transformValue(a.registry, ao.Comment, true, "comment") + commentVal, err := marshalValue(ao.Comment, a.bsonOpts, a.registry) if err != nil { return nil, err } cursorOpts.Comment = commentVal } if ao.Hint != nil { - hintVal, err := transformValue(a.registry, ao.Hint, false, "hint") + if isUnorderedMap(ao.Hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(ao.Hint, a.bsonOpts, a.registry) if err != nil { return nil, err } op.Hint(hintVal) } if ao.Let != nil { - let, err := transformBsoncoreDocument(a.registry, ao.Let, true, "let") + let, err := marshal(ao.Let, a.bsonOpts, a.registry) if err != nil { return nil, err } @@ -904,7 +936,7 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { if err != nil { return nil, replaceErrors(err) } - cursor, err := newCursorWithSession(bc, a.registry, sess) + cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess) return cursor, replaceErrors(err) } @@ -925,7 +957,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, countOpts := options.MergeCountOptions(opts...) - pipelineArr, err := countDocumentsAggregatePipeline(coll.registry, filter, countOpts) + pipelineArr, err := countDocumentsAggregatePipeline(filter, coll.bsonOpts, coll.registry, countOpts) if err != nil { return 0, err } @@ -956,7 +988,10 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op.Comment(*countOpts.Comment) } if countOpts.Hint != nil { - hintVal, err := transformValue(coll.registry, countOpts.Hint, false, "hint") + if isUnorderedMap(countOpts.Hint) { + return 0, ErrMapForOrderedArgument{"hint"} + } + hintVal, err := marshalValue(countOpts.Hint, coll.bsonOpts, coll.registry) if err != nil { return 0, err } @@ -1033,7 +1068,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, Timeout(coll.client.timeout).MaxTime(co.MaxTime) if co.Comment != nil { - comment, err := transformValue(coll.registry, co.Comment, false, "comment") + comment, err := marshalValue(co.Comment, coll.bsonOpts, coll.registry) if err != nil { return 0, err } @@ -1067,7 +1102,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1103,7 +1138,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i op.Collation(bsoncore.Document(option.Collation.ToDocument())) } if option.Comment != nil { - comment, err := transformValue(coll.registry, option.Comment, true, "comment") + comment, err := marshalValue(option.Comment, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1158,7 +1193,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1211,7 +1246,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, if fo.Comment != nil { op.Comment(*fo.Comment) - commentVal, err := transformValue(coll.registry, fo.Comment, true, "comment") + commentVal, err := marshalValue(fo.Comment, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1227,14 +1262,17 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, } } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint, false, "hint") + if isUnorderedMap(fo.Hint) { + return nil, ErrMapForOrderedArgument{"hint"} + } + hint, err := marshalValue(fo.Hint, coll.bsonOpts, coll.registry) if err != nil { return nil, err } op.Hint(hint) } if fo.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, fo.Let, true, "let") + let, err := marshal(fo.Let, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1250,7 +1288,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.Limit(limit) } if fo.Max != nil { - max, err := transformBsoncoreDocument(coll.registry, fo.Max, true, "max") + max, err := marshal(fo.Max, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1260,7 +1298,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, cursorOpts.MaxTimeMS = int64(*fo.MaxAwaitTime / time.Millisecond) } if fo.Min != nil { - min, err := transformBsoncoreDocument(coll.registry, fo.Min, true, "min") + min, err := marshal(fo.Min, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1273,7 +1311,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.OplogReplay(*fo.OplogReplay) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") + proj, err := marshal(fo.Projection, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1292,7 +1330,10 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, op.Snapshot(*fo.Snapshot) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") + if isUnorderedMap(fo.Sort) { + return nil, ErrMapForOrderedArgument{"sort"} + } + sort, err := marshal(fo.Sort, coll.bsonOpts, coll.registry) if err != nil { return nil, err } @@ -1312,7 +1353,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{}, if err != nil { return nil, replaceErrors(err) } - return newCursorWithSession(bc, coll.registry, sess) + return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess) } // FindOne executes a find command and returns a SingleResult for one document in the collection. @@ -1362,7 +1403,13 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{}, findOpts = append(findOpts, options.Find().SetLimit(-1)) cursor, err := coll.Find(ctx, filter, findOpts...) - return &SingleResult{ctx: ctx, cur: cursor, reg: coll.registry, err: replaceErrors(err)} + return &SingleResult{ + ctx: ctx, + cur: cursor, + bsonOpts: coll.bsonOpts, + reg: coll.registry, + err: replaceErrors(err), + } } func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAndModify) *SingleResult { @@ -1413,7 +1460,12 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd return &SingleResult{err: err} } - return &SingleResult{ctx: ctx, rdr: bson.Raw(op.Result().Value), reg: coll.registry} + return &SingleResult{ + ctx: ctx, + rdr: bson.Raw(op.Result().Value), + bsonOpts: coll.bsonOpts, + reg: coll.registry, + } } // FindOneAndDelete executes a findAndModify command to delete at most one document in the collection. and returns the @@ -1430,7 +1482,7 @@ func (coll *Collection) findAndModify(ctx context.Context, op *operation.FindAnd func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}, opts ...*options.FindOneAndDeleteOptions) *SingleResult { - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1441,35 +1493,41 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} op = op.Collation(bsoncore.Document(fod.Collation.ToDocument())) } if fod.Comment != nil { - comment, err := transformValue(coll.registry, fod.Comment, true, "comment") + comment, err := marshalValue(fod.Comment, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Comment(comment) } if fod.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fod.Projection, true, "projection") + proj, err := marshal(fod.Projection, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Fields(proj) } if fod.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fod.Sort, false, "sort") + if isUnorderedMap(fod.Sort) { + return &SingleResult{err: ErrMapForOrderedArgument{"sort"}} + } + sort, err := marshal(fod.Sort, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Sort(sort) } if fod.Hint != nil { - hint, err := transformValue(coll.registry, fod.Hint, false, "hint") + if isUnorderedMap(fod.Hint) { + return &SingleResult{err: ErrMapForOrderedArgument{"hint"}} + } + hint, err := marshalValue(fod.Hint, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Hint(hint) } if fod.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, fod.Let, true, "let") + let, err := marshal(fod.Let, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1496,11 +1554,11 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{}, replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *SingleResult { - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } - r, err := transformBsoncoreDocument(coll.registry, replacement, true, "replacement") + r, err := marshal(replacement, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1518,14 +1576,14 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ op = op.Collation(bsoncore.Document(fo.Collation.ToDocument())) } if fo.Comment != nil { - comment, err := transformValue(coll.registry, fo.Comment, true, "comment") + comment, err := marshalValue(fo.Comment, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Comment(comment) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") + proj, err := marshal(fo.Projection, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1535,7 +1593,10 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ op = op.NewDocument(*fo.ReturnDocument == options.After) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") + if isUnorderedMap(fo.Sort) { + return &SingleResult{err: ErrMapForOrderedArgument{"sort"}} + } + sort, err := marshal(fo.Sort, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1545,14 +1606,17 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ op = op.Upsert(*fo.Upsert) } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint, false, "hint") + if isUnorderedMap(fo.Hint) { + return &SingleResult{err: ErrMapForOrderedArgument{"hint"}} + } + hint, err := marshalValue(fo.Hint, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Hint(hint) } if fo.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, fo.Let, true, "let") + let, err := marshal(fo.Let, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1584,7 +1648,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} ctx = context.Background() } - f, err := transformBsoncoreDocument(coll.registry, filter, true, "filter") + f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1593,18 +1657,23 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). MaxTime(fo.MaxTime) - u, err := transformUpdateValue(coll.registry, update, true) + u, err := marshalUpdateValue(update, coll.bsonOpts, coll.registry, true) if err != nil { return &SingleResult{err: err} } op = op.Update(u) if fo.ArrayFilters != nil { - filtersDoc, err := fo.ArrayFilters.ToArrayDocument() + af := fo.ArrayFilters + reg := coll.registry + if af.Registry != nil { + reg = af.Registry + } + filtersDoc, err := marshalValue(af.Filters, coll.bsonOpts, reg) if err != nil { return &SingleResult{err: err} } - op = op.ArrayFilters(bsoncore.Document(filtersDoc)) + op = op.ArrayFilters(filtersDoc.Data) } if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) @@ -1613,14 +1682,14 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op = op.Collation(bsoncore.Document(fo.Collation.ToDocument())) } if fo.Comment != nil { - comment, err := transformValue(coll.registry, fo.Comment, true, "comment") + comment, err := marshalValue(fo.Comment, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Comment(comment) } if fo.Projection != nil { - proj, err := transformBsoncoreDocument(coll.registry, fo.Projection, true, "projection") + proj, err := marshal(fo.Projection, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1630,7 +1699,10 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op = op.NewDocument(*fo.ReturnDocument == options.After) } if fo.Sort != nil { - sort, err := transformBsoncoreDocument(coll.registry, fo.Sort, false, "sort") + if isUnorderedMap(fo.Sort) { + return &SingleResult{err: ErrMapForOrderedArgument{"sort"}} + } + sort, err := marshal(fo.Sort, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1640,14 +1712,17 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} op = op.Upsert(*fo.Upsert) } if fo.Hint != nil { - hint, err := transformValue(coll.registry, fo.Hint, false, "hint") + if isUnorderedMap(fo.Hint) { + return &SingleResult{err: ErrMapForOrderedArgument{"hint"}} + } + hint, err := marshalValue(fo.Hint, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } op = op.Hint(hint) } if fo.Let != nil { - let, err := transformBsoncoreDocument(coll.registry, fo.Let, true, "let") + let, err := marshal(fo.Let, coll.bsonOpts, coll.registry) if err != nil { return &SingleResult{err: err} } @@ -1677,6 +1752,7 @@ func (coll *Collection) Watch(ctx context.Context, pipeline interface{}, readConcern: coll.readConcern, readPreference: coll.readPreference, client: coll.client, + bsonOpts: coll.bsonOpts, registry: coll.registry, streamType: CollectionStream, collectionName: coll.Name(), @@ -1715,7 +1791,7 @@ func (coll *Collection) Drop(ctx context.Context) error { // dropEncryptedCollection drops a collection with EncryptedFields. func (coll *Collection) dropEncryptedCollection(ctx context.Context, ef interface{}) error { - efBSON, err := transformBsoncoreDocument(coll.registry, ef, true /* mapAllowed */, "encryptedFields") + efBSON, err := marshal(ef, coll.bsonOpts, coll.registry) if err != nil { return fmt.Errorf("error transforming document: %v", err) } @@ -1828,3 +1904,11 @@ func makeOutputAggregateSelector(sess *session.Client, rp *readpref.ReadPref, lo }) return makePinnedSelector(sess, selector) } + +// isUnorderedMap returns true if val is a map with more than 1 element. It is typically used to +// check for unordered Go values that are used in nested command documents where different field +// orders mean different things. Examples are the "sort" and "hint" fields. +func isUnorderedMap(val interface{}) bool { + refValue := reflect.ValueOf(val) + return refValue.Kind() == reflect.Map && refValue.Len() > 1 +} diff --git a/mongo/collection_test.go b/mongo/collection_test.go index 91c197a746..f17e6dfda7 100644 --- a/mongo/collection_test.go +++ b/mongo/collection_test.go @@ -206,7 +206,7 @@ func TestCollection(t *testing.T) { _, err = coll.BulkWrite(bgCtx, []WriteModel{nil}) assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) - aggErr := errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid") + aggErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") _, err = coll.Aggregate(bgCtx, nil) assert.Equal(t, aggErr, err, "expected error %v, got %v", aggErr, err) diff --git a/mongo/cursor.go b/mongo/cursor.go index 26842eec53..9b348cb46a 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -15,6 +15,8 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" + "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/session" @@ -31,17 +33,27 @@ type Cursor struct { bc batchCursor batch *bsoncore.DocumentSequence batchLength int + bsonOpts *options.BSONOptions registry *bsoncodec.Registry clientSession *session.Client err error } -func newCursor(bc batchCursor, registry *bsoncodec.Registry) (*Cursor, error) { - return newCursorWithSession(bc, registry, nil) +func newCursor( + bc batchCursor, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (*Cursor, error) { + return newCursorWithSession(bc, bsonOpts, registry, nil) } -func newCursorWithSession(bc batchCursor, registry *bsoncodec.Registry, clientSession *session.Client) (*Cursor, error) { +func newCursorWithSession( + bc batchCursor, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, + clientSession *session.Client, +) (*Cursor, error) { if registry == nil { registry = bson.DefaultRegistry } @@ -50,6 +62,7 @@ func newCursorWithSession(bc batchCursor, registry *bsoncodec.Registry, clientSe } c := &Cursor{ bc: bc, + bsonOpts: bsonOpts, registry: registry, clientSession: clientSession, } @@ -203,10 +216,62 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { } } +func getDecoder( + data []byte, + opts *options.BSONOptions, + reg *bsoncodec.Registry, +) (*bson.Decoder, error) { + dec, err := bson.NewDecoder(bsonrw.NewBSONDocumentReader(data)) + if err != nil { + return nil, err + } + + if opts != nil { + if opts.AllowTruncatingDoubles { + dec.AllowTruncatingDoubles() + } + if opts.BinaryAsSlice { + dec.BinaryAsSlice() + } + if opts.DefaultDocumentD { + dec.DefaultDocumentD() + } + if opts.DefaultDocumentM { + dec.DefaultDocumentM() + } + if opts.UseJSONStructTags { + dec.UseJSONStructTags() + } + if opts.UseLocalTimeZone { + dec.UseLocalTimeZone() + } + if opts.ZeroMaps { + dec.ZeroMaps() + } + if opts.ZeroStructs { + dec.ZeroStructs() + } + } + + if reg != nil { + // TODO:(GODRIVER-2719): Remove error handling. + if err := dec.SetRegistry(reg); err != nil { + return nil, err + } + } + + return dec, nil +} + // Decode will unmarshal the current document into val and return any errors from the unmarshalling process without any // modification. If val is nil or is a typed nil, an error will be returned. func (c *Cursor) Decode(val interface{}) error { - return bson.UnmarshalWithRegistry(c.registry, c.Current, val) + dec, err := getDecoder(c.Current, c.bsonOpts, c.registry) + if err != nil { + return fmt.Errorf("error configuring BSON decoder: %w", err) + } + + return dec.Decode(val) } // Err returns the last error seen by the Cursor, or nil if no error has occurred. @@ -295,7 +360,12 @@ func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, bat } currElem := sliceVal.Index(index).Addr().Interface() - if err = bson.UnmarshalWithRegistry(c.registry, doc, currElem); err != nil { + dec, err := getDecoder(doc, c.bsonOpts, c.registry) + if err != nil { + return sliceVal, index, fmt.Errorf("error configuring BSON decoder: %w", err) + } + err = dec.Decode(currElem) + if err != nil { return sliceVal, index, err } diff --git a/mongo/cursor_test.go b/mongo/cursor_test.go index ee16075574..c83d269e35 100644 --- a/mongo/cursor_test.go +++ b/mongo/cursor_test.go @@ -13,6 +13,8 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" + "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" ) @@ -96,14 +98,14 @@ func TestCursor(t *testing.T) { t.Run("TestAll", func(t *testing.T) { t.Run("errors if argument is not pointer to slice", func(t *testing.T) { - cursor, err := newCursor(newTestBatchCursor(1, 5), nil) + cursor, err := newCursor(newTestBatchCursor(1, 5), nil, nil) assert.Nil(t, err, "newCursor error: %v", err) err = cursor.All(context.Background(), []bson.D{}) assert.NotNil(t, err, "expected error, got nil") }) t.Run("fills slice with all documents", func(t *testing.T) { - cursor, err := newCursor(newTestBatchCursor(1, 5), nil) + cursor, err := newCursor(newTestBatchCursor(1, 5), nil, nil) assert.Nil(t, err, "newCursor error: %v", err) var docs []bson.D @@ -118,7 +120,7 @@ func TestCursor(t *testing.T) { }) t.Run("decodes each document into slice type", func(t *testing.T) { - cursor, err := newCursor(newTestBatchCursor(1, 5), nil) + cursor, err := newCursor(newTestBatchCursor(1, 5), nil, nil) assert.Nil(t, err, "newCursor error: %v", err) type Document struct { @@ -136,7 +138,7 @@ func TestCursor(t *testing.T) { }) t.Run("multiple batches are included", func(t *testing.T) { - cursor, err := newCursor(newTestBatchCursor(2, 5), nil) + cursor, err := newCursor(newTestBatchCursor(2, 5), nil, nil) assert.Nil(t, err, "newCursor error: %v", err) var docs []bson.D err = cursor.All(context.Background(), &docs) @@ -153,7 +155,7 @@ func TestCursor(t *testing.T) { var docs []bson.D tbc := newTestBatchCursor(1, 5) - cursor, err := newCursor(tbc, nil) + cursor, err := newCursor(tbc, nil, nil) assert.Nil(t, err, "newCursor error: %v", err) err = cursor.All(context.Background(), &docs) @@ -164,7 +166,7 @@ func TestCursor(t *testing.T) { t.Run("does not error given interface as parameter", func(t *testing.T) { var docs interface{} = []bson.D{} - cursor, err := newCursor(newTestBatchCursor(1, 5), nil) + cursor, err := newCursor(newTestBatchCursor(1, 5), nil, nil) assert.Nil(t, err, "newCursor error: %v", err) err = cursor.All(context.Background(), &docs) @@ -174,12 +176,33 @@ func TestCursor(t *testing.T) { t.Run("errors when not given pointer to slice", func(t *testing.T) { var docs interface{} = "test" - cursor, err := newCursor(newTestBatchCursor(1, 5), nil) + cursor, err := newCursor(newTestBatchCursor(1, 5), nil, nil) assert.Nil(t, err, "newCursor error: %v", err) err = cursor.All(context.Background(), &docs) assert.NotNil(t, err, "expected error, got: %v", err) }) + t.Run("with BSONOptions", func(t *testing.T) { + cursor, err := newCursor( + newTestBatchCursor(1, 5), + &options.BSONOptions{ + UseJSONStructTags: true, + }, + nil) + require.NoError(t, err, "newCursor error") + + type myDocument struct { + A int32 `json:"foo"` + } + var got []myDocument + + err = cursor.All(context.Background(), &got) + require.NoError(t, err, "All error") + + want := []myDocument{{A: 0}, {A: 1}, {A: 2}, {A: 3}, {A: 4}} + + assert.Equal(t, want, got, "expected and actual All results are different") + }) }) } diff --git a/mongo/database.go b/mongo/database.go index 26b216c943..8dd0352aed 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -39,6 +39,7 @@ type Database struct { readPreference *readpref.ReadPref readSelector description.ServerSelector writeSelector description.ServerSelector + bsonOpts *options.BSONOptions registry *bsoncodec.Registry } @@ -60,6 +61,11 @@ func newDatabase(client *Client, name string, opts ...*options.DatabaseOptions) wc = dbOpt.WriteConcern } + bsonOpts := client.bsonOpts + if dbOpt.BSONOptions != nil { + bsonOpts = dbOpt.BSONOptions + } + reg := client.registry if dbOpt.Registry != nil { reg = dbOpt.Registry @@ -71,6 +77,7 @@ func newDatabase(client *Client, name string, opts ...*options.DatabaseOptions) readPreference: rp, readConcern: rc, writeConcern: wc, + bsonOpts: bsonOpts, registry: reg, } @@ -150,7 +157,11 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, return nil, sess, errors.New("read preference in a transaction must be primary") } - runCmdDoc, err := transformBsoncoreDocument(db.registry, cmd, false, "cmd") + if isUnorderedMap(cmd) { + return nil, sess, ErrMapForOrderedArgument{"cmd"} + } + + runCmdDoc, err := marshal(cmd, db.bsonOpts, db.registry) if err != nil { return nil, sess, err } @@ -208,10 +219,11 @@ func (db *Database) RunCommand(ctx context.Context, runCommand interface{}, opts // RunCommand can be used to run a write, thus execute may return a write error _, convErr := processWriteError(err) return &SingleResult{ - ctx: ctx, - err: convErr, - rdr: bson.Raw(op.Result()), - reg: db.registry, + ctx: ctx, + err: convErr, + rdr: bson.Raw(op.Result()), + bsonOpts: db.bsonOpts, + reg: db.registry, } } @@ -250,7 +262,7 @@ func (db *Database) RunCommandCursor(ctx context.Context, runCommand interface{} closeImplicitSession(sess) return nil, replaceErrors(err) } - cursor, err := newCursorWithSession(bc, db.registry, sess) + cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) return cursor, replaceErrors(err) } @@ -353,7 +365,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt ctx = context.Background() } - filterDoc, err := transformBsoncoreDocument(db.registry, filter, true, "filter") + filterDoc, err := marshal(filter, db.bsonOpts, db.registry) if err != nil { return nil, err } @@ -411,7 +423,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt closeImplicitSession(sess) return nil, replaceErrors(err) } - cursor, err := newCursorWithSession(bc, db.registry, sess) + cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) return cursor, replaceErrors(err) } @@ -575,7 +587,7 @@ func (db *Database) getEncryptedFieldsFromMap(collectionName string) interface{} // createCollectionWithEncryptedFields creates a collection with an EncryptedFields. func (db *Database) createCollectionWithEncryptedFields(ctx context.Context, name string, ef interface{}, opts ...*options.CreateCollectionOptions) error { - efBSON, err := transformBsoncoreDocument(db.registry, ef, true /* mapAllowed */, "encryptedFields") + efBSON, err := marshal(ef, db.bsonOpts, db.registry) if err != nil { return fmt.Errorf("error transforming document: %v", err) } @@ -663,7 +675,7 @@ func (db *Database) createCollectionOperation(name string, opts ...*options.Crea op.Collation(bsoncore.Document(cco.Collation.ToDocument())) } if cco.ChangeStreamPreAndPostImages != nil { - csppi, err := transformBsoncoreDocument(db.registry, cco.ChangeStreamPreAndPostImages, true, "changeStreamPreAndPostImages") + csppi, err := marshal(cco.ChangeStreamPreAndPostImages, db.bsonOpts, db.registry) if err != nil { return nil, err } @@ -672,7 +684,7 @@ func (db *Database) createCollectionOperation(name string, opts ...*options.Crea if cco.DefaultIndexOptions != nil { idx, doc := bsoncore.AppendDocumentStart(nil) if cco.DefaultIndexOptions.StorageEngine != nil { - storageEngine, err := transformBsoncoreDocument(db.registry, cco.DefaultIndexOptions.StorageEngine, true, "storageEngine") + storageEngine, err := marshal(cco.DefaultIndexOptions.StorageEngine, db.bsonOpts, db.registry) if err != nil { return nil, err } @@ -693,7 +705,7 @@ func (db *Database) createCollectionOperation(name string, opts ...*options.Crea op.Size(*cco.SizeInBytes) } if cco.StorageEngine != nil { - storageEngine, err := transformBsoncoreDocument(db.registry, cco.StorageEngine, true, "storageEngine") + storageEngine, err := marshal(cco.StorageEngine, db.bsonOpts, db.registry) if err != nil { return nil, err } @@ -706,7 +718,7 @@ func (db *Database) createCollectionOperation(name string, opts ...*options.Crea op.ValidationLevel(*cco.ValidationLevel) } if cco.Validator != nil { - validator, err := transformBsoncoreDocument(db.registry, cco.Validator, true, "validator") + validator, err := marshal(cco.Validator, db.bsonOpts, db.registry) if err != nil { return nil, err } @@ -746,7 +758,7 @@ func (db *Database) createCollectionOperation(name string, opts ...*options.Crea op.TimeSeries(doc) } if cco.ClusteredIndex != nil { - clusteredIndex, err := transformBsoncoreDocument(db.registry, cco.ClusteredIndex, true, "clusteredIndex") + clusteredIndex, err := marshal(cco.ClusteredIndex, db.bsonOpts, db.registry) if err != nil { return nil, err } @@ -772,7 +784,7 @@ func (db *Database) createCollectionOperation(name string, opts ...*options.Crea func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pipeline interface{}, opts ...*options.CreateViewOptions) error { - pipelineArray, _, err := transformAggregatePipeline(db.registry, pipeline) + pipelineArray, _, err := marshalAggregatePipeline(pipeline, db.bsonOpts, db.registry) if err != nil { return err } diff --git a/mongo/database_test.go b/mongo/database_test.go index 3232fb70a9..46e1fc3f19 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -139,7 +139,7 @@ func TestDatabase(t *testing.T) { assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err) _, err = db.Watch(context.Background(), nil) - watchErr := errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid") + watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid") assert.Equal(t, watchErr, err, "expected error %v, got %v", watchErr, err) _, err = db.ListCollections(context.Background(), nil) diff --git a/mongo/gridfs/bucket.go b/mongo/gridfs/bucket.go index f2c6b00b0e..e4b8db5245 100644 --- a/mongo/gridfs/bucket.go +++ b/mongo/gridfs/bucket.go @@ -650,6 +650,8 @@ func (b *Bucket) parseUploadOptions(opts ...*options.UploadOptions) (*Upload, er uo.Registry = bson.DefaultRegistry } if uo.Metadata != nil { + // TODO(GODRIVER-2726): Replace with marshal() and unmarshal() once the + // TODO gridfs package is merged into the mongo package. raw, err := bson.MarshalWithRegistry(uo.Registry, uo.Metadata) if err != nil { return nil, err diff --git a/mongo/index_view.go b/mongo/index_view.go index 3500b775fb..502de2f2f1 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -122,7 +122,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption closeImplicitSession(sess) return nil, replaceErrors(err) } - cursor, err := newCursorWithSession(bc, iv.coll.registry, sess) + cursor, err := newCursorWithSession(bc, iv.coll.bsonOpts, iv.coll.registry, sess) return cursor, replaceErrors(err) } @@ -181,7 +181,11 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. return nil, fmt.Errorf("index model keys cannot be nil") } - keys, err := transformBsoncoreDocument(iv.coll.registry, model.Keys, false, "keys") + if isUnorderedMap(model.Keys) { + return nil, ErrMapForOrderedArgument{"keys"} + } + + keys, err := marshal(model.Keys, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -250,7 +254,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime) if option.CommitQuorum != nil { - commitQuorum, err := transformValue(iv.coll.registry, option.CommitQuorum, true, "commitQuorum") + commitQuorum, err := marshalValue(option.CommitQuorum, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -282,7 +286,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendBooleanElement(optsDoc, "sparse", *opts.Sparse) } if opts.StorageEngine != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.StorageEngine, true, "storageEngine") + doc, err := marshal(opts.StorageEngine, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -305,7 +309,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendInt32Element(optsDoc, "textIndexVersion", *opts.TextVersion) } if opts.Weights != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.Weights, true, "weights") + doc, err := marshal(opts.Weights, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -328,7 +332,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendInt32Element(optsDoc, "bucketSize", *opts.BucketSize) } if opts.PartialFilterExpression != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.PartialFilterExpression, true, "partialFilterExpression") + doc, err := marshal(opts.PartialFilterExpression, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } @@ -339,7 +343,7 @@ func (iv IndexView) createOptionsDoc(opts *options.IndexOptions) (bsoncore.Docum optsDoc = bsoncore.AppendDocumentElement(optsDoc, "collation", bsoncore.Document(opts.Collation.ToDocument())) } if opts.WildcardProjection != nil { - doc, err := transformBsoncoreDocument(iv.coll.registry, opts.WildcardProjection, true, "wildcardProjection") + doc, err := marshal(opts.WildcardProjection, iv.coll.bsonOpts, iv.coll.registry) if err != nil { return nil, err } diff --git a/mongo/integration/change_stream_test.go b/mongo/integration/change_stream_test.go index 8abf3bd3bf..bcdab4a3c3 100644 --- a/mongo/integration/change_stream_test.go +++ b/mongo/integration/change_stream_test.go @@ -8,6 +8,7 @@ package integration import ( "context" + "sync" "testing" "time" @@ -15,6 +16,7 @@ import ( "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/internal/testutil/monitor" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/integration/mtest" @@ -682,6 +684,43 @@ func TestChangeStream_ReplicaSet(t *testing.T) { assert.True(mt, ok, "expected field 'allChangesForCluster' to be boolean, got %v", acfcVal.Type.String()) assert.False(mt, acfc, "expected field 'allChangesForCluster' to be false, got %v", acfc) }) + + withBSONOpts := mtest.NewOptions().ClientOptions( + options.Client().SetBSONOptions(&options.BSONOptions{ + UseJSONStructTags: true, + })) + mt.RunOpts("with BSONOptions", withBSONOpts, func(mt *mtest.T) { + cs, err := mt.Coll.Watch(context.Background(), mongo.Pipeline{}) + require.NoError(mt, err, "Watch error") + defer closeStream(cs) + + type myDocument struct { + A string `json:"x"` + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _, err := mt.Coll.InsertOne(context.Background(), myDocument{A: "foo"}) + require.NoError(mt, err, "InsertOne error") + }() + + cs.Next(context.Background()) + + var got struct { + FullDocument myDocument `bson:"fullDocument"` + } + err = cs.Decode(&got) + require.NoError(mt, err, "Decode error") + + want := myDocument{ + A: "foo", + } + assert.Equal(mt, want, got.FullDocument, "expected and actual Decode results are different") + + wg.Wait() + }) } func closeStream(cs *mongo.ChangeStream) { diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index e34df34a4f..267bd4aa7c 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -23,6 +23,7 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/internal/testutil" "go.mongodb.org/mongo-driver/internal/testutil/helpers" "go.mongodb.org/mongo-driver/internal/testutil/monitor" @@ -792,6 +793,126 @@ func TestClient(t *testing.T) { }) } +func TestClient_BSONOptions(t *testing.T) { + t.Parallel() + + mt := mtest.New(t, noClientOpts) + defer mt.Close() + + type jsonTagsTest struct { + A string + B string `json:"x"` + C string `json:"y" bson:"3"` + } + + testCases := []struct { + name string + bsonOpts *options.BSONOptions + doc interface{} + decodeInto func() interface{} + want interface{} + wantRaw bson.Raw + }{ + { + name: "UseJSONStructTags", + bsonOpts: &options.BSONOptions{ + UseJSONStructTags: true, + }, + doc: jsonTagsTest{ + A: "apple", + B: "banana", + C: "carrot", + }, + decodeInto: func() interface{} { return &jsonTagsTest{} }, + want: &jsonTagsTest{ + A: "apple", + B: "banana", + C: "carrot", + }, + wantRaw: bson.Raw(bsoncore.NewDocumentBuilder(). + AppendString("a", "apple"). + AppendString("x", "banana"). + AppendString("3", "carrot"). + Build()), + }, + { + name: "IntMinSize", + bsonOpts: &options.BSONOptions{ + IntMinSize: true, + }, + doc: bson.D{{Key: "x", Value: int64(1)}}, + decodeInto: func() interface{} { return &bson.D{} }, + want: &bson.D{{Key: "x", Value: int32(1)}}, + wantRaw: bson.Raw(bsoncore.NewDocumentBuilder(). + AppendInt32("x", 1). + Build()), + }, + { + name: "DefaultDocumentM", + bsonOpts: &options.BSONOptions{ + DefaultDocumentM: true, + }, + doc: bson.D{{Key: "doc", Value: bson.D{{Key: "a", Value: int64(1)}}}}, + decodeInto: func() interface{} { return &bson.D{} }, + want: &bson.D{{Key: "doc", Value: bson.M{"a": int64(1)}}}, + }, + } + + for _, tc := range testCases { + opts := mtest.NewOptions().ClientOptions( + options.Client().SetBSONOptions(tc.bsonOpts)) + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + res, err := mt.Coll.InsertOne(context.Background(), tc.doc) + require.NoError(mt, err, "InsertOne error") + + sr := mt.Coll.FindOne( + context.Background(), + bson.D{{Key: "_id", Value: res.InsertedID}}, + // Exclude the auto-generated "_id" field so we can make simple + // assertions on the return value. + options.FindOne().SetProjection(bson.D{{Key: "_id", Value: 0}})) + + if tc.want != nil { + got := tc.decodeInto() + err := sr.Decode(got) + require.NoError(mt, err, "Decode error") + + assert.Equal(mt, tc.want, got, "expected and actual decoded result are different") + } + + if tc.wantRaw != nil { + got, err := sr.DecodeBytes() + require.NoError(mt, err, "DecodeBytes error") + + assert.EqualBSON(mt, tc.wantRaw, got) + } + }) + } + + opts := mtest.NewOptions().ClientOptions( + options.Client().SetBSONOptions(&options.BSONOptions{ + ErrorOnInlineDuplicates: true, + })) + mt.RunOpts("ErrorOnInlineDuplicates", opts, func(mt *mtest.T) { + type inlineDupInner struct { + A string + } + + type inlineDupOuter struct { + A string + B *inlineDupInner `bson:"b,inline"` + } + + _, err := mt.Coll.InsertOne(context.Background(), inlineDupOuter{ + A: "outer", + B: &inlineDupInner{ + A: "inner", + }, + }) + require.Error(mt, err, "expected InsertOne to return an error") + }) +} + func TestClientStress(t *testing.T) { mtOpts := mtest.NewOptions().CreateClient(false) mt := mtest.New(t, mtOpts) diff --git a/mongo/mongo.go b/mongo/mongo.go index 6de404bb1a..ded99e4e2b 100644 --- a/mongo/mongo.go +++ b/mongo/mongo.go @@ -7,9 +7,11 @@ package mongo // import "go.mongodb.org/mongo-driver/mongo" import ( + "bytes" "context" "errors" "fmt" + "io" "net" "reflect" "strconv" @@ -20,6 +22,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/bson/bsonrw" "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" ) @@ -53,7 +56,7 @@ func (baf BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, err return baf(dst, val) } -// MarshalError is returned when attempting to transform a value into a document +// MarshalError is returned when attempting to marshal a value into a document // results in an error. type MarshalError struct { Value interface{} @@ -62,7 +65,7 @@ type MarshalError struct { // Error implements the error interface. func (me MarshalError) Error() string { - return fmt.Sprintf("cannot transform type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err) + return fmt.Sprintf("cannot marshal type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err) } // Pipeline is a type that makes creating aggregation pipelines easier. It is a @@ -76,56 +79,69 @@ func (me MarshalError) Error() string { // } type Pipeline []bson.D -// transformAndEnsureID is a hack that makes it easy to get a RawValue as the _id value. -// It will also add an ObjectID _id as the first key if it not already present in the passed-in val. -func transformAndEnsureID(registry *bsoncodec.Registry, val interface{}) (bsoncore.Document, interface{}, error) { - if registry == nil { - registry = bson.NewRegistryBuilder().Build() - } - switch tt := val.(type) { - case nil: - return nil, nil, ErrNilDocument - case []byte: - // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. - val = bson.Raw(tt) - } - - // TODO(skriptble): Use a pool of these instead. - doc := make(bsoncore.Document, 0, 256) - doc, err := bson.MarshalAppendWithRegistry(registry, doc, val) +// bvwPool is a pool of BSON value writers. BSON value writers +var bvwPool = bsonrw.NewBSONValueWriterPool() + +// getEncoder takes a writer, BSON options, and a BSON registry and returns a properly configured +// bson.Encoder that writes to the given writer. +func getEncoder( + w io.Writer, + opts *options.BSONOptions, + reg *bsoncodec.Registry, +) (*bson.Encoder, error) { + vw := bvwPool.Get(w) + enc, err := bson.NewEncoder(vw) if err != nil { - return nil, nil, MarshalError{Value: val, Err: err} + return nil, err } - var id interface{} - - value := doc.Lookup("_id") - switch value.Type { - case bsontype.Type(0): - value = bsoncore.Value{Type: bsontype.ObjectID, Data: bsoncore.AppendObjectID(nil, primitive.NewObjectID())} - olddoc := doc - doc = make(bsoncore.Document, 0, len(olddoc)+17) // type byte + _id + null byte + object ID - _, doc = bsoncore.ReserveLength(doc) - doc = bsoncore.AppendValueElement(doc, "_id", value) - doc = append(doc, olddoc[4:]...) // remove the length - doc = bsoncore.UpdateLength(doc, 0, int32(len(doc))) - default: - // We copy the bytes here to ensure that any bytes returned to the user aren't modified - // later. - buf := make([]byte, len(value.Data)) - copy(buf, value.Data) - value.Data = buf + if opts != nil { + if opts.ErrorOnInlineDuplicates { + enc.ErrorOnInlineDuplicates() + } + if opts.IntMinSize { + enc.IntMinSize() + } + if opts.NilByteSliceAsEmpty { + enc.NilByteSliceAsEmpty() + } + if opts.NilMapAsEmpty { + enc.NilMapAsEmpty() + } + if opts.NilSliceAsEmpty { + enc.NilSliceAsEmpty() + } + if opts.OmitZeroStruct { + enc.OmitZeroStruct() + } + if opts.StringifyMapKeysWithFmt { + enc.StringifyMapKeysWithFmt() + } + if opts.UseJSONStructTags { + enc.UseJSONStructTags() + } } - err = bson.RawValue{Type: value.Type, Value: value.Data}.UnmarshalWithRegistry(registry, &id) - if err != nil { - return nil, nil, err + if reg != nil { + // TODO:(GODRIVER-2719): Remove error handling. + if err := enc.SetRegistry(reg); err != nil { + return nil, err + } } - return doc, id, nil + return enc, nil } -func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}, mapAllowed bool, paramName string) (bsoncore.Document, error) { +// marshal marshals the given value as a BSON document. Byte slices are always converted to a +// bson.Raw before marshaling. +// +// If bsonOpts and registry are specified, the encoder is configured with the requested behaviors. +// If they are nil, the default behaviors are used. +func marshal( + val interface{}, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (bsoncore.Document, error) { if registry == nil { registry = bson.DefaultRegistry } @@ -136,20 +152,72 @@ func transformBsoncoreDocument(registry *bsoncodec.Registry, val interface{}, ma // Slight optimization so we'll just use MarshalBSON and not go through the codec machinery. val = bson.Raw(bs) } - if !mapAllowed { - refValue := reflect.ValueOf(val) - if refValue.Kind() == reflect.Map && refValue.Len() > 1 { - return nil, ErrMapForOrderedArgument{paramName} - } + + buf := new(bytes.Buffer) + enc, err := getEncoder(buf, bsonOpts, registry) + if err != nil { + return nil, fmt.Errorf("error configuring BSON encoder: %w", err) } - // TODO(skriptble): Use a pool of these instead. - buf := make([]byte, 0, 256) - b, err := bson.MarshalAppendWithRegistry(registry, buf[:0], val) + err = enc.Encode(val) if err != nil { return nil, MarshalError{Value: val, Err: err} } - return b, nil + + return buf.Bytes(), nil +} + +// ensureID inserts the given ObjectID as an element named "_id" at the +// beginning of the given BSON document if there is not an "_id" already. If +// there is already an element named "_id", the document is not modified. It +// returns the resulting document and the decoded Go value of the "_id" element. +func ensureID( + doc bsoncore.Document, + oid primitive.ObjectID, + bsonOpts *options.BSONOptions, + reg *bsoncodec.Registry, +) (bsoncore.Document, interface{}, error) { + if reg == nil { + reg = bson.DefaultRegistry + } + + // Try to find the "_id" element. If it exists, try to unmarshal just the + // "_id" field as an interface{} and return it along with the unmodified + // BSON document. + if _, err := doc.LookupErr("_id"); err == nil { + var id struct { + ID interface{} `bson:"_id"` + } + dec, err := getDecoder(doc, bsonOpts, reg) + if err != nil { + return nil, nil, fmt.Errorf("error configuring BSON decoder: %w", err) + } + err = dec.Decode(&id) + if err != nil { + return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err) + } + + return doc, id.ID, nil + } + + // We couldn't find an "_id" element, so add one with the value of the + // provided ObjectID. + + olddoc := doc + + // Reserve an extra 17 bytes for the "_id" field we're about to add: + // type (1) + "_id" (3) + terminator (1) + object ID (12) + const extraSpace = 17 + doc = make(bsoncore.Document, 0, len(olddoc)+extraSpace) + _, doc = bsoncore.ReserveLength(doc) + doc = bsoncore.AppendObjectIDElement(doc, "_id", oid) + + // Remove and re-write the BSON document length header. + const int32Len = 4 + doc = append(doc, olddoc[int32Len:]...) + doc = bsoncore.UpdateLength(doc, 0, int32(len(doc))) + + return doc, oid, nil } func ensureDollarKey(doc bsoncore.Document) error { @@ -172,7 +240,11 @@ func ensureNoDollarKey(doc bsoncore.Document) error { return nil } -func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface{}) (bsoncore.Document, bool, error) { +func marshalAggregatePipeline( + pipeline interface{}, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (bsoncore.Document, bool, error) { switch t := pipeline.(type) { case bsoncodec.ValueMarshaler: btype, val, err := t.MarshalBSONValue() @@ -198,7 +270,7 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface default: val := reflect.ValueOf(t) if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) { - return nil, false, fmt.Errorf("can only transform slices and arrays into aggregation pipelines, but got %v", val.Kind()) + return nil, false, fmt.Errorf("can only marshal slices and arrays into aggregation pipelines, but got %v", val.Kind()) } var hasOutputStage bool @@ -212,7 +284,7 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface return nil, false, fmt.Errorf("%T is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead", t) } - // bsoncore.Arrays do not need to be transformed. Only check validity and presence of output stage. + // bsoncore.Arrays do not need to be marshaled. Only check validity and presence of output stage. case bsoncore.Array: if err := t.Validate(); err != nil { return nil, false, err @@ -239,7 +311,7 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface aidx, arr := bsoncore.AppendArrayStart(nil) for idx := 0; idx < valLen; idx++ { - doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface(), true, fmt.Sprintf("pipeline stage :%v", idx)) + doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry) if err != nil { return nil, false, err } @@ -256,7 +328,12 @@ func transformAggregatePipeline(registry *bsoncodec.Registry, pipeline interface } } -func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, dollarKeysAllowed bool) (bsoncore.Value, error) { +func marshalUpdateValue( + update interface{}, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, + dollarKeysAllowed bool, +) (bsoncore.Value, error) { documentCheckerFunc := ensureDollarKey if !dollarKeysAllowed { documentCheckerFunc = ensureNoDollarKey @@ -269,7 +346,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll return u, ErrNilDocument case primitive.D: u.Type = bsontype.EmbeddedDocument - u.Data, err = transformBsoncoreDocument(registry, update, true, "update") + u.Data, err = marshal(update, bsonOpts, registry) if err != nil { return u, err } @@ -307,11 +384,11 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll default: val := reflect.ValueOf(t) if !val.IsValid() { - return u, fmt.Errorf("can only transform slices and arrays into update pipelines, but got %v", val.Kind()) + return u, fmt.Errorf("can only marshal slices and arrays into update pipelines, but got %v", val.Kind()) } if val.Kind() != reflect.Slice && val.Kind() != reflect.Array { u.Type = bsontype.EmbeddedDocument - u.Data, err = transformBsoncoreDocument(registry, update, true, "update") + u.Data, err = marshal(update, bsonOpts, registry) if err != nil { return u, err } @@ -323,7 +400,7 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll aidx, arr := bsoncore.AppendArrayStart(nil) valLen := val.Len() for idx := 0; idx < valLen; idx++ { - doc, err := transformBsoncoreDocument(registry, val.Index(idx).Interface(), true, "update") + doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry) if err != nil { return u, err } @@ -339,7 +416,11 @@ func transformUpdateValue(registry *bsoncodec.Registry, update interface{}, doll } } -func transformValue(registry *bsoncodec.Registry, val interface{}, mapAllowed bool, paramName string) (bsoncore.Value, error) { +func marshalValue( + val interface{}, + bsonOpts *options.BSONOptions, + registry *bsoncodec.Registry, +) (bsoncore.Value, error) { if registry == nil { registry = bson.DefaultRegistry } @@ -347,25 +428,29 @@ func transformValue(registry *bsoncodec.Registry, val interface{}, mapAllowed bo return bsoncore.Value{}, ErrNilValue } - if !mapAllowed { - refValue := reflect.ValueOf(val) - if refValue.Kind() == reflect.Map && refValue.Len() > 1 { - return bsoncore.Value{}, ErrMapForOrderedArgument{paramName} - } + buf := new(bytes.Buffer) + enc, err := getEncoder(buf, bsonOpts, registry) + if err != nil { + return bsoncore.Value{}, fmt.Errorf("error configuring BSON encoder: %w", err) } - buf := make([]byte, 0, 256) - bsonType, bsonValue, err := bson.MarshalValueAppendWithRegistry(registry, buf[:0], val) + // Encode the value in a single-element document with an empty key. Use bsoncore to extract the + // first element and return the BSON value. + err = enc.Encode(bson.D{{Key: "", Value: val}}) if err != nil { return bsoncore.Value{}, MarshalError{Value: val, Err: err} } - - return bsoncore.Value{Type: bsonType, Data: bsonValue}, nil + return bsoncore.Document(buf.Bytes()).Index(0).Value(), nil } // Build the aggregation pipeline for the CountDocument command. -func countDocumentsAggregatePipeline(registry *bsoncodec.Registry, filter interface{}, opts *options.CountOptions) (bsoncore.Document, error) { - filterDoc, err := transformBsoncoreDocument(registry, filter, true, "filter") +func countDocumentsAggregatePipeline( + filter interface{}, + encOpts *options.BSONOptions, + registry *bsoncodec.Registry, + opts *options.CountOptions, +) (bsoncore.Document, error) { + filterDoc, err := marshal(filter, encOpts, registry) if err != nil { return nil, err } diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go index 0189f438d7..74cf7594ca 100644 --- a/mongo/mongo_test.go +++ b/mongo/mongo_test.go @@ -16,452 +16,575 @@ import ( "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" + "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) -func TestMongoHelpers(t *testing.T) { - t.Run("transform and ensure ID", func(t *testing.T) { - t.Run("newly added _id should be first element", func(t *testing.T) { - doc := bson.D{{"foo", "bar"}, {"baz", "qux"}, {"hello", "world"}} - got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc) - assert.Nil(t, err, "transformAndEnsureID error: %v", err) - oid, ok := id.(primitive.ObjectID) - assert.True(t, ok, "expected returned id type %T, got %T", primitive.ObjectID{}, id) - wantDoc := bson.D{ - {"_id", oid}, {"foo", "bar"}, - {"baz", "qux"}, {"hello", "world"}, +func TestEnsureID(t *testing.T) { + t.Parallel() + + oid := primitive.NewObjectID() + + testCases := []struct { + description string + // TODO: Registry? DecodeOptions? + doc bsoncore.Document + oid primitive.ObjectID + want bsoncore.Document + wantID interface{} + }{ + { + description: "missing _id should be first element", + doc: bsoncore.NewDocumentBuilder(). + AppendString("foo", "bar"). + AppendString("baz", "quix"). + AppendString("hello", "world"). + Build(), + want: bsoncore.NewDocumentBuilder(). + AppendObjectID("_id", oid). + AppendString("foo", "bar"). + AppendString("baz", "quix"). + AppendString("hello", "world"). + Build(), + wantID: oid, + }, + { + description: "existing ObjectID _id as should remain in place", + doc: bsoncore.NewDocumentBuilder(). + AppendString("foo", "bar"). + AppendObjectID("_id", oid). + AppendString("baz", "quix"). + AppendString("hello", "world"). + Build(), + want: bsoncore.NewDocumentBuilder(). + AppendString("foo", "bar"). + AppendObjectID("_id", oid). + AppendString("baz", "quix"). + AppendString("hello", "world"). + Build(), + wantID: oid, + }, + { + description: "existing float _id as should remain in place", + doc: bsoncore.NewDocumentBuilder(). + AppendString("foo", "bar"). + AppendDouble("_id", 3.14159). + AppendString("baz", "quix"). + AppendString("hello", "world"). + Build(), + want: bsoncore.NewDocumentBuilder(). + AppendString("foo", "bar"). + AppendDouble("_id", 3.14159). + AppendString("baz", "quix"). + AppendString("hello", "world"). + Build(), + wantID: 3.14159, + }, + { + description: "existing float _id as first element should remain first element", + doc: bsoncore.NewDocumentBuilder(). + AppendDouble("_id", 3.14159). + AppendString("foo", "bar"). + AppendString("baz", "quix"). + AppendString("hello", "world"). + Build(), + want: bsoncore.NewDocumentBuilder(). + AppendDouble("_id", 3.14159). + AppendString("foo", "bar"). + AppendString("baz", "quix"). + AppendString("hello", "world"). + Build(), + wantID: 3.14159, + }, + { + description: "existing binary _id as first field should not be overwritten", + doc: bsoncore.NewDocumentBuilder(). + AppendBinary("bin", 0, []byte{0, 0, 0}). + AppendString("_id", "LongEnoughIdentifier"). + Build(), + want: bsoncore.NewDocumentBuilder(). + AppendBinary("bin", 0, []byte{0, 0, 0}). + AppendString("_id", "LongEnoughIdentifier"). + Build(), + wantID: "LongEnoughIdentifier", + }, + } + + for _, tc := range testCases { + tc := tc // Capture range varible. + + t.Run(tc.description, func(t *testing.T) { + t.Parallel() + + got, gotID, err := ensureID(tc.doc, oid, nil, nil) + require.NoError(t, err, "ensureID error") + + assert.Equal(t, tc.want, got, "expected and actual documents are different") + assert.Equal(t, tc.wantID, gotID, "expected and actual IDs are different") + + // Ensure that if the unmarshaled "_id" value is a + // primitive.ObjectID that it is a deep copy and does not share any + // memory with the document byte slice. + if oid, ok := gotID.(primitive.ObjectID); ok { + assert.DifferentAddressRanges(t, tc.doc, oid[:]) } - _, wantBSON, err := bson.MarshalValue(wantDoc) - assert.Nil(t, err, "MarshalValue error: %v", err) - want := bsoncore.Document(wantBSON) - assert.Equal(t, want, got, "expected document %v, got %v", want, got) }) - t.Run("existing _id as should remain in place", func(t *testing.T) { - doc := bson.D{{"foo", "bar"}, {"_id", 3.14159}, {"baz", "qux"}, {"hello", "world"}} - got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc) - assert.Nil(t, err, "transformAndEnsureID error: %v", err) - _, ok := id.(float64) - assert.True(t, ok, "expected returned id type %T, got %T", float64(0), id) - _, wantBSON, err := bson.MarshalValue(doc) - assert.Nil(t, err, "MarshalValue error: %v", err) - want := bsoncore.Document(wantBSON) - assert.Equal(t, want, got, "expected document %v, got %v", want, got) - }) - t.Run("existing _id as first element should remain first element", func(t *testing.T) { - doc := bson.D{{"_id", 3.14159}, {"foo", "bar"}, {"baz", "qux"}, {"hello", "world"}} - got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc) - assert.Nil(t, err, "transformAndEnsureID error: %v", err) - _, ok := id.(float64) - assert.True(t, ok, "expected returned id type %T, got %T", float64(0), id) - _, wantBSON, err := bson.MarshalValue(doc) - assert.Nil(t, err, "MarshalValue error: %v", err) - want := bsoncore.Document(wantBSON) - assert.Equal(t, want, got, "expected document %v, got %v", want, got) - }) - t.Run("existing _id should not overwrite a first binary field", func(t *testing.T) { - doc := bson.D{{"bin", []byte{0, 0, 0}}, {"_id", "LongEnoughIdentifier"}} - got, id, err := transformAndEnsureID(bson.DefaultRegistry, doc) - assert.Nil(t, err, "transformAndEnsureID error: %v", err) - _, ok := id.(string) - assert.True(t, ok, "expected returned id type string, got %T", id) - _, wantBSON, err := bson.MarshalValue(doc) - assert.Nil(t, err, "MarshalValue error: %v", err) - want := bsoncore.Document(wantBSON) - assert.Equal(t, want, got, "expected document %v, got %v", want, got) - }) - }) - t.Run("transform aggregate pipeline", func(t *testing.T) { - // []byte of [{{"$limit", 12345}}] - index, arr := bsoncore.AppendArrayStart(nil) - dindex, arr := bsoncore.AppendDocumentElementStart(arr, "0") - arr = bsoncore.AppendInt32Element(arr, "$limit", 12345) - arr, _ = bsoncore.AppendDocumentEnd(arr, dindex) - arr, _ = bsoncore.AppendArrayEnd(arr, index) + } +} - // []byte of {{"x", 1}} - index, doc := bsoncore.AppendDocumentStart(nil) - doc = bsoncore.AppendInt32Element(doc, "x", 1) - doc, _ = bsoncore.AppendDocumentEnd(doc, index) +func TestMarshalAggregatePipeline(t *testing.T) { + // []byte of [{{"$limit", 12345}}] + index, arr := bsoncore.AppendArrayStart(nil) + dindex, arr := bsoncore.AppendDocumentElementStart(arr, "0") + arr = bsoncore.AppendInt32Element(arr, "$limit", 12345) + arr, _ = bsoncore.AppendDocumentEnd(arr, dindex) + arr, _ = bsoncore.AppendArrayEnd(arr, index) - // bsoncore.Array of [{{"$merge", {}}}] - mergeStage := bsoncore.NewDocumentBuilder(). - StartDocument("$merge"). - FinishDocument(). - Build() - arrMergeStage := bsoncore.NewArrayBuilder().AppendDocument(mergeStage).Build() + // []byte of {{"x", 1}} + index, doc := bsoncore.AppendDocumentStart(nil) + doc = bsoncore.AppendInt32Element(doc, "x", 1) + doc, _ = bsoncore.AppendDocumentEnd(doc, index) - fooStage := bsoncore.NewDocumentBuilder().AppendString("foo", "bar").Build() - bazStage := bsoncore.NewDocumentBuilder().AppendString("baz", "qux").Build() - outStage := bsoncore.NewDocumentBuilder().AppendString("$out", "myColl").Build() + // bsoncore.Array of [{{"$merge", {}}}] + mergeStage := bsoncore.NewDocumentBuilder(). + StartDocument("$merge"). + FinishDocument(). + Build() + arrMergeStage := bsoncore.NewArrayBuilder().AppendDocument(mergeStage).Build() - // bsoncore.Array of [{{"foo", "bar"}}, {{"baz", "qux"}}, {{"$out", "myColl"}}] - arrOutStage := bsoncore.NewArrayBuilder(). - AppendDocument(fooStage). - AppendDocument(bazStage). - AppendDocument(outStage). - Build() + fooStage := bsoncore.NewDocumentBuilder().AppendString("foo", "bar").Build() + bazStage := bsoncore.NewDocumentBuilder().AppendString("baz", "qux").Build() + outStage := bsoncore.NewDocumentBuilder().AppendString("$out", "myColl").Build() - // bsoncore.Array of [{{"foo", "bar"}}, {{"$out", "myColl"}}, {{"baz", "qux"}}] - arrMiddleOutStage := bsoncore.NewArrayBuilder(). - AppendDocument(fooStage). - AppendDocument(outStage). - AppendDocument(bazStage). - Build() + // bsoncore.Array of [{{"foo", "bar"}}, {{"baz", "qux"}}, {{"$out", "myColl"}}] + arrOutStage := bsoncore.NewArrayBuilder(). + AppendDocument(fooStage). + AppendDocument(bazStage). + AppendDocument(outStage). + Build() - testCases := []struct { - name string - pipeline interface{} - arr bson.A - hasOutputStage bool - err error - }{ - { - "Pipeline/error", - Pipeline{{{"hello", func() {}}}}, - nil, - false, - MarshalError{Value: primitive.D{}, Err: errors.New("no encoder found for func()")}, - }, - { - "Pipeline/success", - Pipeline{{{"hello", "world"}}, {{"pi", 3.14159}}}, - bson.A{ - bson.D{{"hello", "world"}}, - bson.D{{"pi", 3.14159}}, - }, - false, - nil, - }, - { - "bson.A", - bson.A{ - bson.D{{"$limit", 12345}}, - }, - bson.A{ - bson.D{{"$limit", 12345}}, - }, - false, - nil, - }, - { - "[]bson.D", - []bson.D{{{"$limit", 12345}}}, - bson.A{ - bson.D{{"$limit", 12345}}, - }, - false, - nil, - }, - { - "primitive.A/error", - primitive.A{"5"}, - nil, - false, - MarshalError{Value: "", Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")}, - }, - { - "primitive.A/success", - primitive.A{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}}, - bson.A{ - bson.D{{"$limit", int(12345)}}, - bson.D{{"$count", "foobar"}}, - }, - false, - nil, - }, - { - "bson.A/error", - bson.A{"5"}, - nil, - false, - MarshalError{Value: "", Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")}, - }, - { - "bson.A/success", - bson.A{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}}, - bson.A{ - bson.D{{"$limit", int32(12345)}}, - bson.D{{"$count", "foobar"}}, - }, - false, - nil, - }, - { - "[]interface{}/error", - []interface{}{"5"}, - nil, - false, - MarshalError{Value: "", Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")}, - }, - { - "[]interface{}/success", - []interface{}{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}}, - bson.A{ - bson.D{{"$limit", int32(12345)}}, - bson.D{{"$count", "foobar"}}, - }, - false, - nil, - }, - { - "bsoncodec.ValueMarshaler/MarshalBSONValue error", - bvMarsh{err: errors.New("MarshalBSONValue error")}, - nil, - false, - errors.New("MarshalBSONValue error"), - }, - { - "bsoncodec.ValueMarshaler/not array", - bvMarsh{t: bsontype.String}, - nil, - false, - fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", bsontype.String, bsontype.Array), - }, - { - "bsoncodec.ValueMarshaler/UnmarshalBSONValue error", - bvMarsh{err: errors.New("UnmarshalBSONValue error")}, - nil, - false, - errors.New("UnmarshalBSONValue error"), - }, - { - "bsoncodec.ValueMarshaler/success", - bvMarsh{t: bsontype.Array, data: arr}, - bson.A{ - bson.D{{"$limit", int32(12345)}}, - }, - false, - nil, - }, - { - "bsoncodec.ValueMarshaler/success nil", - bvMarsh{t: bsontype.Array}, - nil, - false, - nil, - }, - { - "nil", - nil, - nil, - false, - errors.New("can only transform slices and arrays into aggregation pipelines, but got invalid"), - }, - { - "not array or slice", - int64(42), - nil, - false, - errors.New("can only transform slices and arrays into aggregation pipelines, but got int64"), - }, - { - "array/error", - [1]interface{}{int64(42)}, - nil, - false, - MarshalError{Value: int64(0), Err: errors.New("WriteInt64 can only write while positioned on a Element or Value but is positioned on a TopLevel")}, - }, - { - "array/success", - [1]interface{}{primitive.D{{"$limit", int64(12345)}}}, - bson.A{ - bson.D{{"$limit", int64(12345)}}, - }, - false, - nil, - }, - { - "slice/error", - []interface{}{int64(42)}, - nil, - false, - MarshalError{Value: int64(0), Err: errors.New("WriteInt64 can only write while positioned on a Element or Value but is positioned on a TopLevel")}, - }, - { - "slice/success", - []interface{}{primitive.D{{"$limit", int64(12345)}}}, - bson.A{ - bson.D{{"$limit", int64(12345)}}, - }, - false, - nil, - }, - { - "hasOutputStage/out", - bson.A{ - bson.D{{"$out", bson.D{ + // bsoncore.Array of [{{"foo", "bar"}}, {{"$out", "myColl"}}, {{"baz", "qux"}}] + arrMiddleOutStage := bsoncore.NewArrayBuilder(). + AppendDocument(fooStage). + AppendDocument(outStage). + AppendDocument(bazStage). + Build() + + testCases := []struct { + name string + pipeline interface{} + arr bson.A + hasOutputStage bool + err error + }{ + { + "Pipeline/error", + Pipeline{{{"hello", func() {}}}}, + nil, + false, + MarshalError{Value: primitive.D{}, Err: errors.New("no encoder found for func()")}, + }, + { + "Pipeline/success", + Pipeline{{{"hello", "world"}}, {{"pi", 3.14159}}}, + bson.A{ + bson.D{{"hello", "world"}}, + bson.D{{"pi", 3.14159}}, + }, + false, + nil, + }, + { + "bson.A", + bson.A{ + bson.D{{"$limit", 12345}}, + }, + bson.A{ + bson.D{{"$limit", 12345}}, + }, + false, + nil, + }, + { + "[]bson.D", + []bson.D{{{"$limit", 12345}}}, + bson.A{ + bson.D{{"$limit", 12345}}, + }, + false, + nil, + }, + { + "primitive.A/error", + primitive.A{"5"}, + nil, + false, + MarshalError{Value: "", Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")}, + }, + { + "primitive.A/success", + primitive.A{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}}, + bson.A{ + bson.D{{"$limit", int(12345)}}, + bson.D{{"$count", "foobar"}}, + }, + false, + nil, + }, + { + "bson.A/error", + bson.A{"5"}, + nil, + false, + MarshalError{Value: "", Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")}, + }, + { + "bson.A/success", + bson.A{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}}, + bson.A{ + bson.D{{"$limit", int32(12345)}}, + bson.D{{"$count", "foobar"}}, + }, + false, + nil, + }, + { + "[]interface{}/error", + []interface{}{"5"}, + nil, + false, + MarshalError{Value: "", Err: errors.New("WriteString can only write while positioned on a Element or Value but is positioned on a TopLevel")}, + }, + { + "[]interface{}/success", + []interface{}{bson.D{{"$limit", int32(12345)}}, map[string]interface{}{"$count": "foobar"}}, + bson.A{ + bson.D{{"$limit", int32(12345)}}, + bson.D{{"$count", "foobar"}}, + }, + false, + nil, + }, + { + "bsoncodec.ValueMarshaler/MarshalBSONValue error", + bvMarsh{err: errors.New("MarshalBSONValue error")}, + nil, + false, + errors.New("MarshalBSONValue error"), + }, + { + "bsoncodec.ValueMarshaler/not array", + bvMarsh{t: bsontype.String}, + nil, + false, + fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", bsontype.String, bsontype.Array), + }, + { + "bsoncodec.ValueMarshaler/UnmarshalBSONValue error", + bvMarsh{err: errors.New("UnmarshalBSONValue error")}, + nil, + false, + errors.New("UnmarshalBSONValue error"), + }, + { + "bsoncodec.ValueMarshaler/success", + bvMarsh{t: bsontype.Array, data: arr}, + bson.A{ + bson.D{{"$limit", int32(12345)}}, + }, + false, + nil, + }, + { + "bsoncodec.ValueMarshaler/success nil", + bvMarsh{t: bsontype.Array}, + nil, + false, + nil, + }, + { + "nil", + nil, + nil, + false, + errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid"), + }, + { + "not array or slice", + int64(42), + nil, + false, + errors.New("can only marshal slices and arrays into aggregation pipelines, but got int64"), + }, + { + "array/error", + [1]interface{}{int64(42)}, + nil, + false, + MarshalError{Value: int64(0), Err: errors.New("WriteInt64 can only write while positioned on a Element or Value but is positioned on a TopLevel")}, + }, + { + "array/success", + [1]interface{}{primitive.D{{"$limit", int64(12345)}}}, + bson.A{ + bson.D{{"$limit", int64(12345)}}, + }, + false, + nil, + }, + { + "slice/error", + []interface{}{int64(42)}, + nil, + false, + MarshalError{Value: int64(0), Err: errors.New("WriteInt64 can only write while positioned on a Element or Value but is positioned on a TopLevel")}, + }, + { + "slice/success", + []interface{}{primitive.D{{"$limit", int64(12345)}}}, + bson.A{ + bson.D{{"$limit", int64(12345)}}, + }, + false, + nil, + }, + { + "hasOutputStage/out", + bson.A{ + bson.D{{"$out", bson.D{ + {"db", "output-db"}, + {"coll", "output-collection"}, + }}}, + }, + bson.A{ + bson.D{{"$out", bson.D{ + {"db", "output-db"}, + {"coll", "output-collection"}, + }}}, + }, + true, + nil, + }, + { + "hasOutputStage/merge", + bson.A{ + bson.D{{"$merge", bson.D{ + {"into", bson.D{ {"db", "output-db"}, {"coll", "output-collection"}, - }}}, - }, - bson.A{ - bson.D{{"$out", bson.D{ + }}, + }}}, + }, + bson.A{ + bson.D{{"$merge", bson.D{ + {"into", bson.D{ {"db", "output-db"}, {"coll", "output-collection"}, - }}}, - }, - true, - nil, - }, - { - "hasOutputStage/merge", - bson.A{ - bson.D{{"$merge", bson.D{ - {"into", bson.D{ - {"db", "output-db"}, - {"coll", "output-collection"}, - }}, - }}}, - }, - bson.A{ - bson.D{{"$merge", bson.D{ - {"into", bson.D{ - {"db", "output-db"}, - {"coll", "output-collection"}, - }}, - }}}, - }, - true, - nil, - }, - { - "semantic single document/bson.D", - bson.D{{"x", 1}}, - nil, - false, - errors.New("primitive.D is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead"), - }, - { - "semantic single document/bson.Raw", - bson.Raw(doc), - nil, - false, - errors.New("bson.Raw is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead"), - }, - { - "semantic single document/bsoncore.Document", - bsoncore.Document(doc), - nil, - false, - errors.New("bsoncore.Document is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead"), - }, - { - "semantic single document/empty bson.D", - bson.D{}, - bson.A{}, - false, - nil, - }, - { - "semantic single document/empty bson.Raw", - bson.Raw{}, - bson.A{}, - false, - nil, - }, - { - "semantic single document/empty bsoncore.Document", - bsoncore.Document{}, - bson.A{}, - false, - nil, - }, - { - "bsoncore.Array/success", - bsoncore.Array(arr), - bson.A{ - bson.D{{"$limit", int32(12345)}}, - }, - false, - nil, - }, - { - "bsoncore.Array/mergeStage", - arrMergeStage, - bson.A{ - bson.D{{"$merge", bson.D{}}}, - }, - true, - nil, - }, - { - "bsoncore.Array/outStage", - arrOutStage, - bson.A{ - bson.D{{"foo", "bar"}}, - bson.D{{"baz", "qux"}}, - bson.D{{"$out", "myColl"}}, - }, - true, - nil, - }, - { - "bsoncore.Array/middleOutStage", - arrMiddleOutStage, - bson.A{ - bson.D{{"foo", "bar"}}, - bson.D{{"$out", "myColl"}}, - bson.D{{"baz", "qux"}}, - }, - false, - nil, - }, - } + }}, + }}}, + }, + true, + nil, + }, + { + "semantic single document/bson.D", + bson.D{{"x", 1}}, + nil, + false, + errors.New("primitive.D is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead"), + }, + { + "semantic single document/bson.Raw", + bson.Raw(doc), + nil, + false, + errors.New("bson.Raw is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead"), + }, + { + "semantic single document/bsoncore.Document", + bsoncore.Document(doc), + nil, + false, + errors.New("bsoncore.Document is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead"), + }, + { + "semantic single document/empty bson.D", + bson.D{}, + bson.A{}, + false, + nil, + }, + { + "semantic single document/empty bson.Raw", + bson.Raw{}, + bson.A{}, + false, + nil, + }, + { + "semantic single document/empty bsoncore.Document", + bsoncore.Document{}, + bson.A{}, + false, + nil, + }, + { + "bsoncore.Array/success", + bsoncore.Array(arr), + bson.A{ + bson.D{{"$limit", int32(12345)}}, + }, + false, + nil, + }, + { + "bsoncore.Array/mergeStage", + arrMergeStage, + bson.A{ + bson.D{{"$merge", bson.D{}}}, + }, + true, + nil, + }, + { + "bsoncore.Array/outStage", + arrOutStage, + bson.A{ + bson.D{{"foo", "bar"}}, + bson.D{{"baz", "qux"}}, + bson.D{{"$out", "myColl"}}, + }, + true, + nil, + }, + { + "bsoncore.Array/middleOutStage", + arrMiddleOutStage, + bson.A{ + bson.D{{"foo", "bar"}}, + bson.D{{"$out", "myColl"}}, + bson.D{{"baz", "qux"}}, + }, + false, + nil, + }, + } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - arr, hasOutputStage, err := transformAggregatePipeline(bson.NewRegistryBuilder().Build(), tc.pipeline) - assert.Equal(t, tc.hasOutputStage, hasOutputStage, "expected hasOutputStage %v, got %v", - tc.hasOutputStage, hasOutputStage) - if tc.err != nil { - assert.NotNil(t, err) - assert.EqualError(t, err, tc.err.Error()) - } else { - assert.Nil(t, err) - } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + arr, hasOutputStage, err := marshalAggregatePipeline(tc.pipeline, nil, nil) + assert.Equal(t, tc.hasOutputStage, hasOutputStage, "expected hasOutputStage %v, got %v", + tc.hasOutputStage, hasOutputStage) + if tc.err != nil { + assert.NotNil(t, err) + assert.EqualError(t, err, tc.err.Error()) + } else { + assert.Nil(t, err) + } - var expected bsoncore.Document - if tc.arr != nil { - _, expectedBSON, err := bson.MarshalValue(tc.arr) - assert.Nil(t, err, "MarshalValue error: %v", err) - expected = bsoncore.Document(expectedBSON) - } - assert.Equal(t, expected, arr, "expected array %v, got %v", expected, arr) - }) - } - }) - t.Run("transform value", func(t *testing.T) { - valueMarshaler := bvMarsh{ - t: bsontype.String, - data: bsoncore.AppendString(nil, "foo"), - } - doc := bson.D{{"x", 1}} - docBytes, _ := bson.Marshal(doc) + var expected bsoncore.Document + if tc.arr != nil { + _, expectedBSON, err := bson.MarshalValue(tc.arr) + assert.Nil(t, err, "MarshalValue error: %v", err) + expected = bsoncore.Document(expectedBSON) + } + assert.Equal(t, expected, arr, "expected array %v, got %v", expected, arr) + }) + } +} + +func TestMarshalValue(t *testing.T) { + t.Parallel() - testCases := []struct { - name string - value interface{} - err error - bsonType bsontype.Type - bsonValue []byte - }{ - {"nil document", nil, ErrNilValue, 0, nil}, - {"value marshaler", valueMarshaler, nil, valueMarshaler.t, valueMarshaler.data}, - {"document", doc, nil, bsontype.EmbeddedDocument, docBytes}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - res, err := transformValue(nil, tc.value, true, "") - if tc.err != nil { - assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) - return - } + valueMarshaler := bvMarsh{ + t: bson.TypeString, + data: bsoncore.AppendString(nil, "foo"), + } - assert.Equal(t, tc.bsonType, res.Type, "expected BSON type %s, got %s", tc.bsonType, res.Type) - assert.Equal(t, tc.bsonValue, res.Data, "expected BSON data %v, got %v", tc.bsonValue, res.Data) - }) - } - }) + testCases := []struct { + name string + value interface{} + bsonOpts *options.BSONOptions + registry *bsoncodec.Registry + want bsoncore.Value + wantErr error + }{ + { + name: "nil document", + value: nil, + wantErr: ErrNilValue, + }, + { + name: "value marshaler", + value: valueMarshaler, + want: bsoncore.Value{ + Type: valueMarshaler.t, + Data: valueMarshaler.data, + }, + }, + { + name: "document", + value: bson.D{{Key: "x", Value: int64(1)}}, + want: bsoncore.Value{ + Type: bson.TypeEmbeddedDocument, + Data: bsoncore.NewDocumentBuilder(). + AppendInt64("x", 1). + Build(), + }, + }, + { + name: "custom encode options", + value: struct { + Int int64 + NilBytes []byte + NilMap map[string]interface{} + NilStrings []string + ZeroStruct struct{ X int } `bson:"_,omitempty"` + StringerMap map[*bson.RawValue]bool + BSONField string `json:"jsonField"` + }{ + Int: 1, + NilBytes: nil, + NilMap: nil, + NilStrings: nil, + StringerMap: map[*bson.RawValue]bool{{}: true}, + }, + bsonOpts: &options.BSONOptions{ + IntMinSize: true, + NilByteSliceAsEmpty: true, + NilMapAsEmpty: true, + NilSliceAsEmpty: true, + OmitZeroStruct: true, + StringifyMapKeysWithFmt: true, + UseJSONStructTags: true, + }, + want: bsoncore.Value{ + Type: bson.TypeEmbeddedDocument, + Data: bsoncore.NewDocumentBuilder(). + AppendInt32("int", 1). + AppendBinary("nilbytes", 0, []byte{}). + AppendDocument("nilmap", bsoncore.NewDocumentBuilder().Build()). + AppendArray("nilstrings", bsoncore.NewArrayBuilder().Build()). + AppendDocument("stringermap", bsoncore.NewDocumentBuilder(). + AppendBoolean("", true). + Build()). + AppendString("jsonField", ""). + Build(), + }, + }, + } + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got, err := marshalValue(tc.value, tc.bsonOpts, tc.registry) + assert.EqualBSON(t, tc.want, got) + assert.Equal(t, tc.wantErr, err, "expected and actual error do not match") + }) + } } var _ bsoncodec.ValueMarshaler = bvMarsh{} diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index ee63be1728..1c2e5bed51 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -92,6 +92,88 @@ type Credential struct { PasswordSet bool } +// BSONOptions are optional BSON marshaling and unmarshaling behaviors. +type BSONOptions struct { + // UseJSONStructTags causes the driver to fall back to using the "json" + // struct tag if a "bson" struct tag is not specified. + UseJSONStructTags bool + + // ErrorOnInlineDuplicates causes the driver to return an error if there is + // a duplicate field in the marshaled BSON when the "inline" struct tag + // option is set. + ErrorOnInlineDuplicates bool + + // IntMinSize causes the driver to marshal Go integer values (int, int8, + // int16, int32, int64, uint, uint8, uint16, uint32, or uint64) as the + // minimum BSON int size (either 32 or 64 bits) that can represent the + // integer value. + IntMinSize bool + + // NilMapAsEmpty causes the driver to marshal nil Go maps as empty BSON + // documents instead of BSON null. + // + // Empty BSON documents take up slightly more space than BSON null, but + // preserve the ability to use document update operations like "$set" that + // do not work on BSON null. + NilMapAsEmpty bool + + // NilSliceAsEmpty causes the driver to marshal nil Go slices as empty BSON + // arrays instead of BSON null. + // + // Empty BSON arrays take up slightly more space than BSON null, but + // preserve the ability to use array update operations like "$push" or + // "$addToSet" that do not work on BSON null. + NilSliceAsEmpty bool + + // NilByteSliceAsEmpty causes the driver to marshal nil Go byte slices as + // empty BSON binary values instead of BSON null. + NilByteSliceAsEmpty bool + + // OmitZeroStruct causes the driver to consider the zero value for a struct + // (e.g. MyStruct{}) as empty and omit it from the marshaled BSON when the + // "omitempty" struct tag option is set. + OmitZeroStruct bool + + // StringifyMapKeysWithFmt causes the driver to convert Go map keys to BSON + // document field name strings using fmt.Sprint instead of the default + // string conversion logic. + StringifyMapKeysWithFmt bool + + // AllowTruncatingDoubles causes the driver to truncate the fractional part + // of BSON "double" values when attempting to unmarshal them into a Go + // integer (int, int8, int16, int32, or int64) struct field. The truncation + // logic does not apply to BSON "decimal128" values. + AllowTruncatingDoubles bool + + // BinaryAsSlice causes the driver to unmarshal BSON binary field values + // that are the "Generic" or "Old" BSON binary subtype as a Go byte slice + // instead of a primitive.Binary. + BinaryAsSlice bool + + // DefaultDocumentD causes the driver to always unmarshal documents into the + // primitive.D type. This behavior is restricted to data typed as + // "interface{}" or "map[string]interface{}". + DefaultDocumentD bool + + // DefaultDocumentM causes the driver to always unmarshal documents into the + // primitive.M type. This behavior is restricted to data typed as + // "interface{}" or "map[string]interface{}". + DefaultDocumentM bool + + // UseLocalTimeZone causes the driver to unmarshal time.Time values in the + // local timezone instead of the UTC timezone. + UseLocalTimeZone bool + + // ZeroMaps causes the driver to delete any existing values from Go maps in + // the destination value before unmarshaling BSON documents into them. + ZeroMaps bool + + // ZeroStructs causes the driver to delete any existing values from Go + // structs in the destination value before unmarshaling BSON documents into + // them. + ZeroStructs bool +} + // ClientOptions contains options to configure a Client instance. Each option can be set through setter functions. See // documentation for each setter function for an explanation of the option. type ClientOptions struct { @@ -118,6 +200,7 @@ type ClientOptions struct { ServerMonitor *event.ServerMonitor ReadConcern *readconcern.ReadConcern ReadPreference *readpref.ReadPref + BSONOptions *BSONOptions Registry *bsoncodec.Registry ReplicaSet *string RetryReads *bool @@ -669,6 +752,12 @@ func (c *ClientOptions) SetReadPreference(rp *readpref.ReadPref) *ClientOptions return c } +// SetBSONOptions configures optional BSON marshaling and unmarshaling behavior. +func (c *ClientOptions) SetBSONOptions(opts *BSONOptions) *ClientOptions { + c.BSONOptions = opts + return c +} + // SetRegistry specifies the BSON registry to use for BSON marshalling/unmarshalling operations. The default is // bson.DefaultRegistry. func (c *ClientOptions) SetRegistry(registry *bsoncodec.Registry) *ClientOptions { @@ -953,6 +1042,9 @@ func MergeClientOptions(opts ...*ClientOptions) *ClientOptions { if opt.ReadPreference != nil { c.ReadPreference = opt.ReadPreference } + if opt.BSONOptions != nil { + c.BSONOptions = opt.BSONOptions + } if opt.Registry != nil { c.Registry = opt.Registry } diff --git a/mongo/options/collectionoptions.go b/mongo/options/collectionoptions.go index 2ed5e6fa8e..04fda6d779 100644 --- a/mongo/options/collectionoptions.go +++ b/mongo/options/collectionoptions.go @@ -27,6 +27,10 @@ type CollectionOptions struct { // the read preference of the Database used to configure the Collection will be used. ReadPreference *readpref.ReadPref + // BSONOptions configures optional BSON marshaling and unmarshaling + // behavior. + BSONOptions *BSONOptions + // Registry is the BSON registry to marshal and unmarshal documents for operations executed on the Collection. The default value // is nil, which means that the registry of the Database used to configure the Collection will be used. Registry *bsoncodec.Registry @@ -55,6 +59,12 @@ func (c *CollectionOptions) SetReadPreference(rp *readpref.ReadPref) *Collection return c } +// SetBSONOptions configures optional BSON marshaling and unmarshaling behavior. +func (c *CollectionOptions) SetBSONOptions(opts *BSONOptions) *CollectionOptions { + c.BSONOptions = opts + return c +} + // SetRegistry sets the value for the Registry field. func (c *CollectionOptions) SetRegistry(r *bsoncodec.Registry) *CollectionOptions { c.Registry = r diff --git a/mongo/options/dboptions.go b/mongo/options/dboptions.go index e50d9bdaf0..8a380d2168 100644 --- a/mongo/options/dboptions.go +++ b/mongo/options/dboptions.go @@ -27,6 +27,10 @@ type DatabaseOptions struct { // the read preference of the Client used to configure the Database will be used. ReadPreference *readpref.ReadPref + // BSONOptions configures optional BSON marshaling and unmarshaling + // behavior. + BSONOptions *BSONOptions + // Registry is the BSON registry to marshal and unmarshal documents for operations executed on the Database. The default value // is nil, which means that the registry of the Client used to configure the Database will be used. Registry *bsoncodec.Registry @@ -55,6 +59,12 @@ func (d *DatabaseOptions) SetReadPreference(rp *readpref.ReadPref) *DatabaseOpti return d } +// SetBSONOptions configures optional BSON marshaling and unmarshaling behavior. +func (d *DatabaseOptions) SetBSONOptions(opts *BSONOptions) *DatabaseOptions { + d.BSONOptions = opts + return d +} + // SetRegistry sets the value for the Registry field. func (d *DatabaseOptions) SetRegistry(r *bsoncodec.Registry) *DatabaseOptions { d.Registry = r diff --git a/mongo/options/mongooptions.go b/mongo/options/mongooptions.go index aa1795f9dd..fd17ce44e1 100644 --- a/mongo/options/mongooptions.go +++ b/mongo/options/mongooptions.go @@ -109,6 +109,10 @@ const ( WhenAvailable FullDocument = "whenAvailable" ) +// TODO(GODRIVER-2617): Once Registry is removed, ArrayFilters doesn't need to +// TODO be a separate type. Remove the type and update all ArrayFilters fields +// TODO to be type []interface{}. + // ArrayFilters is used to hold filters for the array filters CRUD option. If a registry is nil, bson.DefaultRegistry // will be used when converting the filter interfaces to BSON. type ArrayFilters struct { diff --git a/mongo/single_result.go b/mongo/single_result.go index 89a73a46e0..9c9b4f4fc6 100644 --- a/mongo/single_result.go +++ b/mongo/single_result.go @@ -9,9 +9,11 @@ package mongo import ( "context" "errors" + "fmt" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsoncodec" + "go.mongodb.org/mongo-driver/mongo/options" ) // ErrNoDocuments is returned by SingleResult methods when the operation that created the SingleResult did not return @@ -22,11 +24,12 @@ var ErrNoDocuments = errors.New("mongo: no documents in result") // SingleResult methods will return that error. If the operation did not return any documents, all SingleResult methods // will return ErrNoDocuments. type SingleResult struct { - ctx context.Context - err error - cur *Cursor - rdr bson.Raw - reg *bsoncodec.Registry + ctx context.Context + err error + cur *Cursor + rdr bson.Raw + bsonOpts *options.BSONOptions + reg *bsoncodec.Registry } // NewSingleResultFromDocument creates a SingleResult with the provided error, registry, and an underlying Cursor pre-loaded with @@ -71,7 +74,13 @@ func (sr *SingleResult) Decode(v interface{}) error { if sr.err = sr.setRdrContents(); sr.err != nil { return sr.err } - return bson.UnmarshalWithRegistry(sr.reg, sr.rdr, v) + + dec, err := getDecoder(sr.rdr, sr.bsonOpts, sr.reg) + if err != nil { + return fmt.Errorf("error configuring BSON decoder: %w", err) + } + + return dec.Decode(v) } // DecodeBytes will return the document represented by this SingleResult as a bson.Raw. If there was an error from the diff --git a/mongo/single_result_test.go b/mongo/single_result_test.go index 0fba2705ce..3e561208ce 100644 --- a/mongo/single_result_test.go +++ b/mongo/single_result_test.go @@ -14,13 +14,15 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" + "go.mongodb.org/mongo-driver/mongo/options" ) func TestSingleResult(t *testing.T) { t.Run("Decode", func(t *testing.T) { t.Run("decode twice", func(t *testing.T) { // Test that Decode and DecodeBytes can be called more than once - c, err := newCursor(newTestBatchCursor(1, 1), bson.DefaultRegistry) + c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.DefaultRegistry) assert.Nil(t, err, "newCursor error: %v", err) sr := &SingleResult{cur: c, reg: bson.DefaultRegistry} @@ -44,6 +46,31 @@ func TestSingleResult(t *testing.T) { assert.Equal(t, r, resBytes, "expected contents %v, got %v", r, resBytes) assert.Equal(t, sr.err, err, "expected error %v, got %v", sr.err, err) }) + t.Run("with BSONOptions", func(t *testing.T) { + c, err := newCursor(newTestBatchCursor(1, 1), nil, bson.DefaultRegistry) + require.NoError(t, err, "newCursor error") + + sr := &SingleResult{ + cur: c, + bsonOpts: &options.BSONOptions{ + UseJSONStructTags: true, + }, + reg: bson.DefaultRegistry, + } + + type myDocument struct { + A *int32 `json:"foo"` + } + + var got myDocument + err = sr.Decode(&got) + require.NoError(t, err, "Decode error") + + i := int32(0) + want := myDocument{A: &i} + + assert.Equal(t, want, got, "expected and actual Decode results are different") + }) }) t.Run("Err", func(t *testing.T) { diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 7c4cb527bf..38d001c716 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -24,7 +24,7 @@ import ( // FindAndModify performs a findAndModify operation. type FindAndModify struct { - arrayFilters bsoncore.Document + arrayFilters bsoncore.Array bypassDocumentValidation *bool collation bsoncore.Document comment bsoncore.Value @@ -215,7 +215,7 @@ func (fam *FindAndModify) command(dst []byte, desc description.SelectedServer) ( } // ArrayFilters specifies an array of filter documents that determines which array elements to modify for an update operation on an array field. -func (fam *FindAndModify) ArrayFilters(arrayFilters bsoncore.Document) *FindAndModify { +func (fam *FindAndModify) ArrayFilters(arrayFilters bsoncore.Array) *FindAndModify { if fam == nil { fam = new(FindAndModify) }