From e9e975d3cc3e980f431e09d76963055ca859b694 Mon Sep 17 00:00:00 2001 From: Matt Dale <9760375+matthewdale@users.noreply.github.com> Date: Tue, 25 Jul 2023 21:59:55 -0700 Subject: [PATCH 01/12] GODRIVER-2867 Unpin connections when ending a session. (#1330) --- mongo/integration/load_balancer_prose_test.go | 29 +++++++++++++++++++ x/mongo/driver/session/client_session.go | 13 +++++++-- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/mongo/integration/load_balancer_prose_test.go b/mongo/integration/load_balancer_prose_test.go index 06b45fa814..b9c5c84d4a 100644 --- a/mongo/integration/load_balancer_prose_test.go +++ b/mongo/integration/load_balancer_prose_test.go @@ -15,6 +15,7 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/integration/mtest" "go.mongodb.org/mongo-driver/mongo/options" @@ -99,5 +100,33 @@ func TestLoadBalancerSupport(t *testing.T) { _, err := mt.Coll.InsertOne(ctx, bson.M{"x": 1}) assertErrorHasInfo(mt, err, 0, 1, 0) }) + + // GODRIVER-2867: Test that connections are unpinned from transactions + // when the transaction session is ended. Create a Client with + // maxPoolSize=1 and expect that it can start and commit 5 transactions + // with that 1 connection. + mt.RunOpts("transaction connections are unpinned", maxPoolSizeMtOpts, func(mt *mtest.T) { + { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + + for i := 0; i < 5; i++ { + sess, err := mt.Client.StartSession() + require.NoError(mt, err, "StartSession error") + + err = sess.StartTransaction() + require.NoError(mt, err, "StartTransaction error") + + ctx := mongo.NewSessionContext(ctx, sess) + _, err = mt.Coll.InsertOne(ctx, bson.M{"x": 1}) + assert.NoError(mt, err, "InsertOne error") + + err = sess.CommitTransaction(ctx) + assert.NoError(mt, err, "CommitTransaction error") + + sess.EndSession(ctx) + } + } + }) }) } diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index ba244b101e..8dac0932de 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -331,9 +331,10 @@ func (c *Client) ClearPinnedResources() error { return nil } -// UnpinConnection gracefully unpins the connection associated with the session if there is one. This is done via -// the pinned connection's UnpinFromTransaction function. -func (c *Client) UnpinConnection() error { +// unpinConnection gracefully unpins the connection associated with the session +// if there is one. This is done via the pinned connection's +// UnpinFromTransaction function. +func (c *Client) unpinConnection() error { if c == nil || c.PinnedConnection == nil { return nil } @@ -353,6 +354,12 @@ func (c *Client) EndSession() { return } c.Terminated = true + + // Ignore the error when unpinning the connection because we can't do + // anything about it if it doesn't work. Typically the only errors that can + // happen here indicate that something went wrong with the connection state, + // like it wasn't marked as pinned or attempted to return to the wrong pool. + _ = c.unpinConnection() c.pool.ReturnSession(c.Server) } From 6c3474948935af10dc57388763d12de6837f1e3a Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 28 Jul 2023 10:50:04 -0600 Subject: [PATCH 02/12] GODRIVER-2896 Add IsZero to BSON RawValue (#1332) Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> --- bson/bsontype/bsontype.go | 11 +++++ bson/primitive_codecs.go | 18 ++++++-- bson/primitive_codecs_test.go | 38 ++++++++++++++++ bson/raw_value.go | 6 +++ bson/raw_value_test.go | 84 +++++++++++++++++++++++++++++++++++ 5 files changed, 154 insertions(+), 3 deletions(-) diff --git a/bson/bsontype/bsontype.go b/bson/bsontype/bsontype.go index f38c263a4c..8cff5492d1 100644 --- a/bson/bsontype/bsontype.go +++ b/bson/bsontype/bsontype.go @@ -102,3 +102,14 @@ func (bt Type) String() string { return "invalid" } } + +// IsValid will return true if the Type is valid. +func (bt Type) IsValid() bool { + switch bt { + case Double, String, EmbeddedDocument, Array, Binary, Undefined, ObjectID, Boolean, DateTime, Null, Regex, + DBPointer, JavaScript, Symbol, CodeWithScope, Int32, Timestamp, Int64, Decimal128, MinKey, MaxKey: + return true + default: + return false + } +} diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index 6b9602589c..ff32a87a79 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -8,6 +8,7 @@ package bson import ( "errors" + "fmt" "reflect" "go.mongodb.org/mongo-driver/bson/bsoncodec" @@ -45,15 +46,26 @@ func (pc PrimitiveCodecs) RegisterPrimitiveCodecs(rb *bsoncodec.RegistryBuilder) // RawValueEncodeValue is the ValueEncoderFunc for RawValue. // -// Deprecated: Use bson.NewRegistry to get a registry with all primitive encoders and decoders -// registered. +// If the RawValue's Type is "invalid" and the RawValue's Value is not empty or +// nil, then this method will return an error. +// +// Deprecated: Use bson.NewRegistry to get a registry with all primitive +// encoders and decoders registered. func (PrimitiveCodecs) RawValueEncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { - return bsoncodec.ValueEncoderError{Name: "RawValueEncodeValue", Types: []reflect.Type{tRawValue}, Received: val} + return bsoncodec.ValueEncoderError{ + Name: "RawValueEncodeValue", + Types: []reflect.Type{tRawValue}, + Received: val, + } } rawvalue := val.Interface().(RawValue) + if !rawvalue.Type.IsValid() { + return fmt.Errorf("the RawValue Type specifies an invalid BSON type: %#x", byte(rawvalue.Type)) + } + return bsonrw.Copier{}.CopyValueFromBytes(vw, rawvalue.Type, rawvalue.Value) } diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index 3fb606d2f4..466f135e83 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -65,6 +65,8 @@ func compareErrors(err1, err2 error) bool { } func TestDefaultValueEncoders(t *testing.T) { + t.Parallel() + var pc PrimitiveCodecs var wrong = func(string, string) string { return "wrong" } @@ -107,6 +109,28 @@ func TestDefaultValueEncoders(t *testing.T) { bsonrwtest.WriteDouble, nil, }, + { + "RawValue Type is zero with non-zero value", + RawValue{ + Type: 0x00, + Value: bsoncore.AppendDouble(nil, 3.14159), + }, + nil, + nil, + bsonrwtest.Nothing, + fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x0"), + }, + { + "RawValue Type is invalid", + RawValue{ + Type: 0x8F, + Value: bsoncore.AppendDouble(nil, 3.14159), + }, + nil, + nil, + bsonrwtest.Nothing, + fmt.Errorf("the RawValue Type specifies an invalid BSON type: 0x8f"), + }, }, }, { @@ -166,9 +190,17 @@ func TestDefaultValueEncoders(t *testing.T) { } for _, tc := range testCases { + tc := tc // Capture the range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + for _, subtest := range tc.subtests { + subtest := subtest // Capture the range variable + t.Run(subtest.name, func(t *testing.T) { + t.Parallel() + var ec bsoncodec.EncodeContext if subtest.ectx != nil { ec = *subtest.ectx @@ -192,6 +224,8 @@ func TestDefaultValueEncoders(t *testing.T) { } t.Run("success path", func(t *testing.T) { + t.Parallel() + oid := primitive.NewObjectID() oids := []primitive.ObjectID{primitive.NewObjectID(), primitive.NewObjectID(), primitive.NewObjectID()} var str = new(string) @@ -426,7 +460,11 @@ func TestDefaultValueEncoders(t *testing.T) { } for _, tc := range testCases { + tc := tc // Capture the range variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + b := make(bsonrw.SliceWriter, 0, 512) vw, err := bsonrw.NewBSONValueWriter(&b) noerr(t, err) diff --git a/bson/raw_value.go b/bson/raw_value.go index 6627294c4d..4d1bfb3160 100644 --- a/bson/raw_value.go +++ b/bson/raw_value.go @@ -37,6 +37,12 @@ type RawValue struct { r *bsoncodec.Registry } +// IsZero reports whether the RawValue is zero, i.e. no data is present on +// the RawValue. It returns true if Type is 0 and Value is empty or nil. +func (rv RawValue) IsZero() bool { + return rv.Type == 0x00 && len(rv.Value) == 0 +} + // Unmarshal deserializes BSON into the provided val. If RawValue cannot be unmarshaled into val, an // error is returned. This method will use the registry used to create the RawValue, if the RawValue // was created from partial BSON processing, or it will use the default registry. Users wishing to diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index fbc0715600..87f08c4a55 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -13,12 +13,19 @@ import ( "go.mongodb.org/mongo-driver/bson/bsoncodec" "go.mongodb.org/mongo-driver/bson/bsontype" + "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) func TestRawValue(t *testing.T) { + t.Parallel() + t.Run("Unmarshal", func(t *testing.T) { + t.Parallel() + t.Run("Uses registry attached to value", func(t *testing.T) { + t.Parallel() + reg := bsoncodec.NewRegistryBuilder().Build() val := RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, "foobar"), r: reg} var s string @@ -29,6 +36,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Uses default registry if no registry attached", func(t *testing.T) { + t.Parallel() + want := "foobar" val := RawValue{Type: bsontype.String, Value: bsoncore.AppendString(nil, want)} var got string @@ -40,7 +49,11 @@ func TestRawValue(t *testing.T) { }) }) t.Run("UnmarshalWithRegistry", func(t *testing.T) { + t.Parallel() + t.Run("Returns error when registry is nil", func(t *testing.T) { + t.Parallel() + want := ErrNilRegistry var val RawValue got := val.UnmarshalWithRegistry(nil, &D{}) @@ -49,6 +62,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Returns lookup error", func(t *testing.T) { + t.Parallel() + reg := bsoncodec.NewRegistryBuilder().Build() var val RawValue var s string @@ -59,6 +74,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Returns DecodeValue error", func(t *testing.T) { + t.Parallel() + reg := NewRegistryBuilder().Build() val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string @@ -69,6 +86,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Success", func(t *testing.T) { + t.Parallel() + reg := NewRegistryBuilder().Build() want := float64(3.14159) val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)} @@ -81,7 +100,11 @@ func TestRawValue(t *testing.T) { }) }) t.Run("UnmarshalWithContext", func(t *testing.T) { + t.Parallel() + t.Run("Returns error when DecodeContext is nil", func(t *testing.T) { + t.Parallel() + want := ErrNilContext var val RawValue got := val.UnmarshalWithContext(nil, &D{}) @@ -90,6 +113,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Returns lookup error", func(t *testing.T) { + t.Parallel() + dc := bsoncodec.DecodeContext{Registry: bsoncodec.NewRegistryBuilder().Build()} var val RawValue var s string @@ -100,6 +125,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Returns DecodeValue error", func(t *testing.T) { + t.Parallel() + dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string @@ -110,6 +137,8 @@ func TestRawValue(t *testing.T) { } }) t.Run("Success", func(t *testing.T) { + t.Parallel() + dc := bsoncodec.DecodeContext{Registry: NewRegistryBuilder().Build()} want := float64(3.14159) val := RawValue{Type: bsontype.Double, Value: bsoncore.AppendDouble(nil, want)} @@ -121,4 +150,59 @@ func TestRawValue(t *testing.T) { } }) }) + + t.Run("IsZero", func(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + val RawValue + want bool + }{ + { + name: "empty", + val: RawValue{}, + want: true, + }, + { + name: "zero type but non-zero value", + val: RawValue{ + Type: 0x00, + Value: bsoncore.AppendInt32(nil, 0), + }, + want: false, + }, + { + name: "zero type and zero value", + val: RawValue{ + Type: 0x00, + Value: bsoncore.AppendInt32(nil, 0), + }, + }, + { + name: "non-zero type and non-zero value", + val: RawValue{ + Type: bsontype.String, + Value: bsoncore.AppendString(nil, "foobar"), + }, + want: false, + }, + { + name: "non-zero type and zero value", + val: RawValue{ + Type: bsontype.String, + Value: bsoncore.AppendString(nil, "foobar"), + }, + }, + } + + for _, tt := range tests { + tt := tt // Capture the range variable + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.want, tt.val.IsZero()) + }) + } + }) } From e353cb917be81561233395903a5c594b9730b45f Mon Sep 17 00:00:00 2001 From: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> Date: Fri, 28 Jul 2023 13:34:14 -0400 Subject: [PATCH 03/12] GODRIVER-2139 Fix invalid field name in regex parse error test. (#1335) --- testdata/bson-corpus/top.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/testdata/bson-corpus/top.json b/testdata/bson-corpus/top.json index 7eca5aa70c..9c649b5e3f 100644 --- a/testdata/bson-corpus/top.json +++ b/testdata/bson-corpus/top.json @@ -96,11 +96,11 @@ }, { "description": "Bad $regularExpression (pattern is number, not string)", - "string": "{\"x\" : {\"$regularExpression\" : { \"pattern\": 42, \"$options\" : \"\"}}}" + "string": "{\"x\" : {\"$regularExpression\" : { \"pattern\": 42, \"options\" : \"\"}}}" }, { "description": "Bad $regularExpression (options are number, not string)", - "string": "{\"x\" : {\"$regularExpression\" : { \"pattern\": \"a\", \"$options\" : 0}}}" + "string": "{\"x\" : {\"$regularExpression\" : { \"pattern\": \"a\", \"options\" : 0}}}" }, { "description" : "Bad $regularExpression (missing pattern field)", From 8759478052ab26194211a4554fab7fdc04d499b4 Mon Sep 17 00:00:00 2001 From: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> Date: Tue, 1 Aug 2023 09:02:01 -0400 Subject: [PATCH 04/12] GODRIVER-2246 Sync spec test for prefer-error-code. (#1339) --- .../errors/prefer-error-code.json | 4 ++-- .../errors/prefer-error-code.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/testdata/server-discovery-and-monitoring/errors/prefer-error-code.json b/testdata/server-discovery-and-monitoring/errors/prefer-error-code.json index 21d123f429..eb00b69613 100644 --- a/testdata/server-discovery-and-monitoring/errors/prefer-error-code.json +++ b/testdata/server-discovery-and-monitoring/errors/prefer-error-code.json @@ -52,7 +52,7 @@ } }, { - "description": "errmsg \"not writable primary\" gets ignored when error code exists", + "description": "errmsg \"not master\" gets ignored when error code exists", "applicationErrors": [ { "address": "a:27017", @@ -61,7 +61,7 @@ "type": "command", "response": { "ok": 0, - "errmsg": "not writable primary", + "errmsg": "not master", "code": 1 } } diff --git a/testdata/server-discovery-and-monitoring/errors/prefer-error-code.yml b/testdata/server-discovery-and-monitoring/errors/prefer-error-code.yml index dcbe0da41a..6bd98386bb 100644 --- a/testdata/server-discovery-and-monitoring/errors/prefer-error-code.yml +++ b/testdata/server-discovery-and-monitoring/errors/prefer-error-code.yml @@ -29,7 +29,7 @@ phases: logicalSessionTimeoutMinutes: null setName: rs -- description: errmsg "not writable primary" gets ignored when error code exists +- description: errmsg "not master" gets ignored when error code exists applicationErrors: - address: a:27017 when: afterHandshakeCompletes @@ -37,7 +37,7 @@ phases: type: command response: ok: 0 - errmsg: "not writable primary" + errmsg: "not master" # NOTE: This needs to be "not master" and not "not writable primary". code: 1 # Not a "not writable primary" error code. outcome: *outcome From fb6660f21896a520cf40d2903e54ac2dc9de59ca Mon Sep 17 00:00:00 2001 From: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> Date: Fri, 28 Jul 2023 13:34:34 -0400 Subject: [PATCH 05/12] GODRIVER-2253 Remove srvMaxHosts tests expecting an error for invalid values. (#1336) --- .../replica-set/srvMaxHosts-invalid_integer.json | 7 ------- .../replica-set/srvMaxHosts-invalid_integer.yml | 5 ----- .../replica-set/srvMaxHosts-invalid_type.json | 7 ------- .../replica-set/srvMaxHosts-invalid_type.yml | 5 ----- .../sharded/srvMaxHosts-invalid_integer.json | 7 ------- .../sharded/srvMaxHosts-invalid_integer.yml | 5 ----- .../sharded/srvMaxHosts-invalid_type.json | 7 ------- .../sharded/srvMaxHosts-invalid_type.yml | 5 ----- 8 files changed, 48 deletions(-) delete mode 100644 testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_integer.json delete mode 100644 testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_integer.yml delete mode 100644 testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_type.json delete mode 100644 testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_type.yml delete mode 100644 testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_integer.json delete mode 100644 testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_integer.yml delete mode 100644 testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_type.json delete mode 100644 testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_type.yml diff --git a/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_integer.json b/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_integer.json deleted file mode 100644 index 5ba1a3b540..0000000000 --- a/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_integer.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "uri": "mongodb+srv://test1.test.build.10gen.cc/?replicaSet=repl0&srvMaxHosts=-1", - "seeds": [], - "hosts": [], - "error": true, - "comment": "Should fail because srvMaxHosts is not greater than or equal to zero" -} diff --git a/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_integer.yml b/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_integer.yml deleted file mode 100644 index c813e95765..0000000000 --- a/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_integer.yml +++ /dev/null @@ -1,5 +0,0 @@ -uri: "mongodb+srv://test1.test.build.10gen.cc/?replicaSet=repl0&srvMaxHosts=-1" -seeds: [] -hosts: [] -error: true -comment: Should fail because srvMaxHosts is not greater than or equal to zero diff --git a/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_type.json b/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_type.json deleted file mode 100644 index 79e75b9b15..0000000000 --- a/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_type.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "uri": "mongodb+srv://test1.test.build.10gen.cc/?replicaSet=repl0&srvMaxHosts=foo", - "seeds": [], - "hosts": [], - "error": true, - "comment": "Should fail because srvMaxHosts is not an integer" -} diff --git a/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_type.yml b/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_type.yml deleted file mode 100644 index 27109b8856..0000000000 --- a/testdata/initial-dns-seedlist-discovery/replica-set/srvMaxHosts-invalid_type.yml +++ /dev/null @@ -1,5 +0,0 @@ -uri: "mongodb+srv://test1.test.build.10gen.cc/?replicaSet=repl0&srvMaxHosts=foo" -seeds: [] -hosts: [] -error: true -comment: Should fail because srvMaxHosts is not an integer diff --git a/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_integer.json b/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_integer.json deleted file mode 100644 index 0939624fc3..0000000000 --- a/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_integer.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "uri": "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=-1", - "seeds": [], - "hosts": [], - "error": true, - "comment": "Should fail because srvMaxHosts is not greater than or equal to zero" -} diff --git a/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_integer.yml b/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_integer.yml deleted file mode 100644 index 836e0191fa..0000000000 --- a/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_integer.yml +++ /dev/null @@ -1,5 +0,0 @@ -uri: "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=-1" -seeds: [] -hosts: [] -error: true -comment: Should fail because srvMaxHosts is not greater than or equal to zero diff --git a/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_type.json b/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_type.json deleted file mode 100644 index c228d26612..0000000000 --- a/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_type.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "uri": "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=foo", - "seeds": [], - "hosts": [], - "error": true, - "comment": "Should fail because srvMaxHosts is not an integer" -} diff --git a/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_type.yml b/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_type.yml deleted file mode 100644 index b852934b81..0000000000 --- a/testdata/initial-dns-seedlist-discovery/sharded/srvMaxHosts-invalid_type.yml +++ /dev/null @@ -1,5 +0,0 @@ -uri: "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=foo" -seeds: [] -hosts: [] -error: true -comment: Should fail because srvMaxHosts is not an integer From b687278eeb70d57d0a210a91fb3e87bfe9edd714 Mon Sep 17 00:00:00 2001 From: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> Date: Thu, 27 Jul 2023 16:21:40 -0400 Subject: [PATCH 06/12] GODRIVER-2822 Add error on empty for ReadConcern marshaler. (#1327) --- mongo/readconcern/readconcern.go | 6 +++ mongo/readconcern/readconcern_test.go | 64 +++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 mongo/readconcern/readconcern_test.go diff --git a/mongo/readconcern/readconcern.go b/mongo/readconcern/readconcern.go index 987f416055..51408e142d 100644 --- a/mongo/readconcern/readconcern.go +++ b/mongo/readconcern/readconcern.go @@ -11,6 +11,8 @@ package readconcern // import "go.mongodb.org/mongo-driver/mongo/readconcern" import ( + "errors" + "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -106,6 +108,10 @@ func New(options ...Option) *ReadConcern { // // Deprecated: Marshaling a ReadConcern to BSON will not be supported in Go Driver 2.0. func (rc *ReadConcern) MarshalBSONValue() (bsontype.Type, []byte, error) { + if rc == nil { + return 0, nil, errors.New("cannot marshal nil ReadConcern") + } + var elems []byte if len(rc.Level) > 0 { diff --git a/mongo/readconcern/readconcern_test.go b/mongo/readconcern/readconcern_test.go new file mode 100644 index 0000000000..2f6ea79f3e --- /dev/null +++ b/mongo/readconcern/readconcern_test.go @@ -0,0 +1,64 @@ +// Copyright (C) MongoDB, Inc. 2023-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package readconcern_test + +import ( + "testing" + + "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/mongo/readconcern" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" +) + +func TestReadConcern_MarshalBSONValue(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + rc *readconcern.ReadConcern + bytes []byte + wantErrorMsg *string + }{ + { + name: "local", + rc: readconcern.Local(), + bytes: bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "level", "local")), + wantErrorMsg: nil, + }, + { + name: "empty", + rc: readconcern.New(), + bytes: bsoncore.BuildDocument(nil, nil), + wantErrorMsg: nil, + }, + { + name: "nil", + rc: nil, + bytes: nil, + wantErrorMsg: func() *string { + msg := "cannot marshal nil ReadConcern" + return &msg + }(), + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, b, err := tc.rc.MarshalBSONValue() + assert.Equal(t, tc.bytes, b, "expected and actual outputs do not match") + if tc.wantErrorMsg == nil { + assert.NoError(t, err, "an unexpected error is returned") + } else { + assert.ErrorContains(t, err, *tc.wantErrorMsg, "expected and actual errors do not match") + } + }) + } +} From 989b3e1c8f5e8e07bfcc479eaa1e9e6bfa1eb8cf Mon Sep 17 00:00:00 2001 From: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> Date: Fri, 21 Jul 2023 18:16:11 -0400 Subject: [PATCH 07/12] GODRIVER-2851 Fix failing "TestCMAPProse" test. (#1316) --- x/mongo/driver/topology/cmap_prose_test.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/x/mongo/driver/topology/cmap_prose_test.go b/x/mongo/driver/topology/cmap_prose_test.go index 1187a31f7f..53c6e9ba36 100644 --- a/x/mongo/driver/topology/cmap_prose_test.go +++ b/x/mongo/driver/topology/cmap_prose_test.go @@ -106,9 +106,6 @@ func TestCMAPProse(t *testing.T) { }) t.Run("checkOut", func(t *testing.T) { t.Run("connection error publishes events", func(t *testing.T) { - // TODO(GODRIVER-2851): Fix and unskip this test case. - t.Skip("Test fails frequently, skipping. See GODRIVER-2851") - // If checkOut() creates a connection that encounters an error while connecting, // the pool should publish connection created and closed events and checkOut should // return the error. @@ -131,8 +128,7 @@ func TestCMAPProse(t *testing.T) { _, err := pool.checkOut(context.Background()) assert.NotNil(t, err, "expected checkOut() error, got nil") - assert.Equal(t, 1, len(created), "expected 1 opened events, got %d", len(created)) - assert.Equal(t, 1, len(closed), "expected 1 closed events, got %d", len(closed)) + assertConnectionCounts(t, pool, 1, 1) }) t.Run("pool is empty", func(t *testing.T) { // If a checkOut() has to create a new connection and that connection encounters an From 436a9821764514d48feb3362d67133e82df05963 Mon Sep 17 00:00:00 2001 From: Mike Jensen Date: Thu, 22 Jun 2023 15:22:13 -0600 Subject: [PATCH 08/12] GODRIVER-2869 Two protocol validations to reduce client denial of service risks (#1291) Co-authored-by: Alan Parra Co-authored-by: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> --- x/bsonx/bsoncore/bsoncore.go | 3 +++ x/bsonx/bsoncore/bsoncore_test.go | 9 +++++++++ x/mongo/driver/compression.go | 6 ++++++ x/mongo/driver/compression_test.go | 17 +++++++++++++++++ 4 files changed, 35 insertions(+) diff --git a/x/bsonx/bsoncore/bsoncore.go b/x/bsonx/bsoncore/bsoncore.go index 94d479428f..e52674aacf 100644 --- a/x/bsonx/bsoncore/bsoncore.go +++ b/x/bsonx/bsoncore/bsoncore.go @@ -825,6 +825,9 @@ func readLengthBytes(src []byte) ([]byte, []byte, bool) { if !ok { return nil, src, false } + if l < 4 { + return nil, src, false + } if len(src) < int(l) { return nil, src, false } diff --git a/x/bsonx/bsoncore/bsoncore_test.go b/x/bsonx/bsoncore/bsoncore_test.go index b7d91a7715..ba2688ebe4 100644 --- a/x/bsonx/bsoncore/bsoncore_test.go +++ b/x/bsonx/bsoncore/bsoncore_test.go @@ -14,6 +14,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/internal/assert" @@ -943,6 +944,14 @@ func TestNullBytes(t *testing.T) { }) } +func TestInvalidBytes(t *testing.T) { + t.Run("read length less than 4 int bytes", func(t *testing.T) { + _, src, ok := readLengthBytes([]byte{0x00, 0x00, 0x00, 0x01}) + assert.False(t, ok, "expected not ok response for invalid length read") + assert.Equal(t, 4, len(src), "expected src to contain the size parameter still") + }) +} + func compareDecimal128(d1, d2 primitive.Decimal128) bool { d1H, d1L := d1.GetBytes() d2H, d2L := d2.GetBytes() diff --git a/x/mongo/driver/compression.go b/x/mongo/driver/compression.go index c474714ff4..7f355f61a4 100644 --- a/x/mongo/driver/compression.go +++ b/x/mongo/driver/compression.go @@ -111,6 +111,12 @@ func DecompressPayload(in []byte, opts CompressionOpts) (uncompressed []byte, er case wiremessage.CompressorNoOp: return in, nil case wiremessage.CompressorSnappy: + l, err := snappy.DecodedLen(in) + if err != nil { + return nil, fmt.Errorf("decoding compressed length %w", err) + } else if int32(l) != opts.UncompressedSize { + return nil, fmt.Errorf("unexpected decompression size, expected %v but got %v", opts.UncompressedSize, l) + } uncompressed = make([]byte, opts.UncompressedSize) return snappy.Decode(uncompressed, in) case wiremessage.CompressorZLib: diff --git a/x/mongo/driver/compression_test.go b/x/mongo/driver/compression_test.go index 5c74bef9ae..8a37de65c2 100644 --- a/x/mongo/driver/compression_test.go +++ b/x/mongo/driver/compression_test.go @@ -41,6 +41,23 @@ func TestCompression(t *testing.T) { } } +func TestDecompressFailures(t *testing.T) { + t.Run("snappy decompress huge size", func(t *testing.T) { + opts := CompressionOpts{ + Compressor: wiremessage.CompressorSnappy, + UncompressedSize: 100, // reasonable size + } + // Compressed data is twice as large as declared above. + // In test we use actual compression so that the decompress action would pass without fix (thus failing test). + // When decompression starts it allocates a buffer of the defined size, regardless of a valid compressed body following. + compressedData, err := CompressPayload(make([]byte, opts.UncompressedSize*2), opts) + assert.NoError(t, err, "premature error making compressed example") + + _, err = DecompressPayload(compressedData, opts) + assert.Error(t, err) + }) +} + func BenchmarkCompressPayload(b *testing.B) { payload := func() []byte { buf, err := os.ReadFile("compression.go") From 9318bc286d4ae3c2618fb3b17cac16ce548bc836 Mon Sep 17 00:00:00 2001 From: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> Date: Wed, 28 Jun 2023 15:20:20 -0400 Subject: [PATCH 09/12] GODRIVER-2869 Test touchup (#1307) --- x/bsonx/bsoncore/bsoncore_test.go | 6 +++++- x/mongo/driver/compression_test.go | 4 ++++ x/mongo/driver/operation_test.go | 19 +++++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/x/bsonx/bsoncore/bsoncore_test.go b/x/bsonx/bsoncore/bsoncore_test.go index ba2688ebe4..ace784c4a8 100644 --- a/x/bsonx/bsoncore/bsoncore_test.go +++ b/x/bsonx/bsoncore/bsoncore_test.go @@ -945,8 +945,12 @@ func TestNullBytes(t *testing.T) { } func TestInvalidBytes(t *testing.T) { + t.Parallel() + t.Run("read length less than 4 int bytes", func(t *testing.T) { - _, src, ok := readLengthBytes([]byte{0x00, 0x00, 0x00, 0x01}) + t.Parallel() + + _, src, ok := readLengthBytes([]byte{0x01, 0x00, 0x00, 0x00}) assert.False(t, ok, "expected not ok response for invalid length read") assert.Equal(t, 4, len(src), "expected src to contain the size parameter still") }) diff --git a/x/mongo/driver/compression_test.go b/x/mongo/driver/compression_test.go index 8a37de65c2..b477cb32c1 100644 --- a/x/mongo/driver/compression_test.go +++ b/x/mongo/driver/compression_test.go @@ -42,7 +42,11 @@ func TestCompression(t *testing.T) { } func TestDecompressFailures(t *testing.T) { + t.Parallel() + t.Run("snappy decompress huge size", func(t *testing.T) { + t.Parallel() + opts := CompressionOpts{ Compressor: wiremessage.CompressorSnappy, UncompressedSize: 100, // reasonable size diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index f52425fe51..49bd46d8fc 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -892,3 +892,22 @@ func TestConvertI64PtrToI32Ptr(t *testing.T) { }) } } + +func TestDecodeOpReply(t *testing.T) { + t.Parallel() + + // GODRIVER-2869: Prevent infinite loop caused by malformatted wiremessage with length of 0. + t.Run("malformatted wiremessage with length of 0", func(t *testing.T) { + t.Parallel() + + var wm []byte + wm = wiremessage.AppendReplyFlags(wm, 0) + wm = wiremessage.AppendReplyCursorID(wm, int64(0)) + wm = wiremessage.AppendReplyStartingFrom(wm, 0) + wm = wiremessage.AppendReplyNumberReturned(wm, 0) + idx, wm := bsoncore.ReserveLength(wm) + wm = bsoncore.UpdateLength(wm, idx, 0) + reply := Operation{}.decodeOpReply(wm) + assert.Equal(t, []bsoncore.Document(nil), reply.documents) + }) +} From 8857a04b3ad7c9060a2335e72561562c64a52db1 Mon Sep 17 00:00:00 2001 From: Charlie Vieth Date: Thu, 6 Jul 2023 14:24:30 -0400 Subject: [PATCH 10/12] GODRIVER-2887 Remove use of reflect.Value.MethodByName in bson (#1308) --- bson/bsoncodec/default_value_decoders.go | 20 +++---- bson/bsoncodec/default_value_decoders_test.go | 11 +++- bson/bsoncodec/default_value_encoders.go | 56 +++++++++++-------- bson/mgocompat/setter_getter.go | 38 +++++++------ 4 files changed, 76 insertions(+), 49 deletions(-) diff --git a/bson/bsoncodec/default_value_decoders.go b/bson/bsoncodec/default_value_decoders.go index b5e22c498a..e479c3585b 100644 --- a/bson/bsoncodec/default_value_decoders.go +++ b/bson/bsoncodec/default_value_decoders.go @@ -1540,12 +1540,12 @@ func (dvd DefaultValueDecoders) ValueUnmarshalerDecodeValue(_ DecodeContext, vr return err } - fn := val.Convert(tValueUnmarshaler).MethodByName("UnmarshalBSONValue") - errVal := fn.Call([]reflect.Value{reflect.ValueOf(t), reflect.ValueOf(src)})[0] - if !errVal.IsNil() { - return errVal.Interface().(error) + m, ok := val.Interface().(ValueUnmarshaler) + if !ok { + // NB: this error should be unreachable due to the above checks + return ValueDecoderError{Name: "ValueUnmarshalerDecodeValue", Types: []reflect.Type{tValueUnmarshaler}, Received: val} } - return nil + return m.UnmarshalBSONValue(t, src) } // UnmarshalerDecodeValue is the ValueDecoderFunc for Unmarshaler implementations. @@ -1588,12 +1588,12 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr bsonr val = val.Addr() // If the type doesn't implement the interface, a pointer to it must. } - fn := val.Convert(tUnmarshaler).MethodByName("UnmarshalBSON") - errVal := fn.Call([]reflect.Value{reflect.ValueOf(src)})[0] - if !errVal.IsNil() { - return errVal.Interface().(error) + m, ok := val.Interface().(Unmarshaler) + if !ok { + // NB: this error should be unreachable due to the above checks + return ValueDecoderError{Name: "UnmarshalerDecodeValue", Types: []reflect.Type{tUnmarshaler}, Received: val} } - return nil + return m.UnmarshalBSON(src) } // EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}. diff --git a/bson/bsoncodec/default_value_decoders_test.go b/bson/bsoncodec/default_value_decoders_test.go index d8c8389654..bac92e04f8 100644 --- a/bson/bsoncodec/default_value_decoders_test.go +++ b/bson/bsoncodec/default_value_decoders_test.go @@ -1530,13 +1530,22 @@ func TestDefaultValueDecoders(t *testing.T) { errors.New("copy error"), }, { - "Unmarshaler", + // Only the pointer form of testUnmarshaler implements Unmarshaler + "value does not implement Unmarshaler", testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)}, nil, &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)}, bsonrwtest.ReadDouble, nil, }, + { + "Unmarshaler", + &testUnmarshaler{Val: bsoncore.AppendDouble(nil, 3.14159)}, + nil, + &bsonrwtest.ValueReaderWriter{BSONType: bsontype.Double, Return: float64(3.14159)}, + bsonrwtest.ReadDouble, + nil, + }, }, }, { diff --git a/bson/bsoncodec/default_value_encoders.go b/bson/bsoncodec/default_value_encoders.go index 7d526c4ef8..4ab14a668c 100644 --- a/bson/bsoncodec/default_value_encoders.go +++ b/bson/bsoncodec/default_value_encoders.go @@ -564,12 +564,14 @@ func (dve DefaultValueEncoders) ValueMarshalerEncodeValue(_ EncodeContext, vw bs return ValueEncoderError{Name: "ValueMarshalerEncodeValue", Types: []reflect.Type{tValueMarshaler}, Received: val} } - fn := val.Convert(tValueMarshaler).MethodByName("MarshalBSONValue") - returns := fn.Call(nil) - if !returns[2].IsNil() { - return returns[2].Interface().(error) + m, ok := val.Interface().(ValueMarshaler) + if !ok { + return vw.WriteNull() + } + t, data, err := m.MarshalBSONValue() + if err != nil { + return err } - t, data := returns[0].Interface().(bsontype.Type), returns[1].Interface().([]byte) return bsonrw.Copier{}.CopyValueFromBytes(vw, t, data) } @@ -593,12 +595,14 @@ func (dve DefaultValueEncoders) MarshalerEncodeValue(_ EncodeContext, vw bsonrw. return ValueEncoderError{Name: "MarshalerEncodeValue", Types: []reflect.Type{tMarshaler}, Received: val} } - fn := val.Convert(tMarshaler).MethodByName("MarshalBSON") - returns := fn.Call(nil) - if !returns[1].IsNil() { - return returns[1].Interface().(error) + m, ok := val.Interface().(Marshaler) + if !ok { + return vw.WriteNull() + } + data, err := m.MarshalBSON() + if err != nil { + return err } - data := returns[0].Interface().([]byte) return bsonrw.Copier{}.CopyValueFromBytes(vw, bsontype.EmbeddedDocument, data) } @@ -622,23 +626,31 @@ func (dve DefaultValueEncoders) ProxyEncodeValue(ec EncodeContext, vw bsonrw.Val return ValueEncoderError{Name: "ProxyEncodeValue", Types: []reflect.Type{tProxy}, Received: val} } - fn := val.Convert(tProxy).MethodByName("ProxyBSON") - returns := fn.Call(nil) - if !returns[1].IsNil() { - return returns[1].Interface().(error) + m, ok := val.Interface().(Proxy) + if !ok { + return vw.WriteNull() + } + v, err := m.ProxyBSON() + if err != nil { + return err + } + if v == nil { + encoder, err := ec.LookupEncoder(nil) + if err != nil { + return err + } + return encoder.EncodeValue(ec, vw, reflect.ValueOf(nil)) } - data := returns[0] - var encoder ValueEncoder - var err error - if data.Elem().IsValid() { - encoder, err = ec.LookupEncoder(data.Elem().Type()) - } else { - encoder, err = ec.LookupEncoder(nil) + vv := reflect.ValueOf(v) + switch vv.Kind() { + case reflect.Ptr, reflect.Interface: + vv = vv.Elem() } + encoder, err := ec.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, data.Elem()) + return encoder.EncodeValue(ec, vw, vv) } // JavaScriptEncodeValue is the ValueEncoderFunc for the primitive.JavaScript type. diff --git a/bson/mgocompat/setter_getter.go b/bson/mgocompat/setter_getter.go index 55af549d40..e9c9cae834 100644 --- a/bson/mgocompat/setter_getter.go +++ b/bson/mgocompat/setter_getter.go @@ -7,6 +7,7 @@ package mgocompat import ( + "errors" "reflect" "go.mongodb.org/mongo-driver/bson" @@ -73,16 +74,15 @@ func SetterDecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val ref return err } - fn := val.Convert(tSetter).MethodByName("SetBSON") - - errVal := fn.Call([]reflect.Value{reflect.ValueOf(bson.RawValue{Type: t, Value: src})})[0] - if !errVal.IsNil() { - err = errVal.Interface().(error) - if err == ErrSetZero { - val.Set(reflect.Zero(val.Type())) - return nil + m, ok := val.Interface().(Setter) + if !ok { + return bsoncodec.ValueDecoderError{Name: "SetterDecodeValue", Types: []reflect.Type{tSetter}, Received: val} + } + if err := m.SetBSON(bson.RawValue{Type: t, Value: src}); err != nil { + if !errors.Is(err, ErrSetZero) { + return err } - return err + val.Set(reflect.Zero(val.Type())) } return nil } @@ -104,17 +104,23 @@ func GetterEncodeValue(ec bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val re return bsoncodec.ValueEncoderError{Name: "GetterEncodeValue", Types: []reflect.Type{tGetter}, Received: val} } - fn := val.Convert(tGetter).MethodByName("GetBSON") - returns := fn.Call(nil) - if !returns[1].IsNil() { - return returns[1].Interface().(error) + m, ok := val.Interface().(Getter) + if !ok { + return vw.WriteNull() + } + x, err := m.GetBSON() + if err != nil { + return err + } + if x == nil { + return vw.WriteNull() } - intermediate := returns[0] - encoder, err := ec.Registry.LookupEncoder(intermediate.Type()) + vv := reflect.ValueOf(x) + encoder, err := ec.Registry.LookupEncoder(vv.Type()) if err != nil { return err } - return encoder.EncodeValue(ec, vw, intermediate) + return encoder.EncodeValue(ec, vw, vv) } // isImplementationNil returns if val is a nil pointer and inter is implemented on a concrete type From afb541910cef7e700c1d6062b1da434f3a3d72c6 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Thu, 20 Jul 2023 09:27:49 -0600 Subject: [PATCH 11/12] GODRIVER-2891 Add documentation for log levels (#1312) Co-authored-by: Steven Silvester Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> --- mongo/options/loggeroptions.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mongo/options/loggeroptions.go b/mongo/options/loggeroptions.go index 4a33e449a5..b837935812 100644 --- a/mongo/options/loggeroptions.go +++ b/mongo/options/loggeroptions.go @@ -51,6 +51,18 @@ type LogSink interface { // Info logs a non-error message with the given key/value pairs. This // method will only be called if the provided level has been defined // for a component in the LoggerOptions. + // + // Here are the following level mappings for V = "Verbosity": + // + // - V(0): off + // - V(1): informational + // - V(2): debugging + // + // This level mapping is taken from the go-logr/logr library + // specifications, specifically: + // + // "Level V(0) is the default, and logger.V(0).Info() has the same + // meaning as logger.Info()." Info(level int, message string, keysAndValues ...interface{}) // Error logs an error message with the given key/value pairs From d219098916140466493fcca942a98b51582e32ca Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Tue, 1 Aug 2023 10:29:15 -0600 Subject: [PATCH 12/12] GODRIVER-2881 Enable logging for ComponentAll (#1340) --- internal/logger/logger.go | 20 +++- internal/logger/logger_test.go | 176 +++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+), 2 deletions(-) diff --git a/internal/logger/logger.go b/internal/logger/logger.go index c4053ea3df..07dcffe66b 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -75,9 +75,25 @@ func (logger *Logger) Close() error { } // LevelComponentEnabled will return true if the given LogLevel is enabled for -// the given LogComponent. +// the given LogComponent. If the ComponentLevels on the logger are enabled for +// "ComponentAll", then this function will return true for any level bound by +// the level assigned to "ComponentAll". +// +// If the level is not enabled (i.e. LevelOff), then false is returned. This is +// to avoid false positives, such as returning "true" for a component that is +// not enabled. For example, without this condition, an empty LevelComponent +// would be considered "enabled" for "LevelOff". func (logger *Logger) LevelComponentEnabled(level Level, component Component) bool { - return logger.ComponentLevels[component] >= level + if level == LevelOff { + return false + } + + if logger.ComponentLevels == nil { + return false + } + + return logger.ComponentLevels[component] >= level || + logger.ComponentLevels[ComponentAll] >= level } // Print will synchronously print the given message to the configured LogSink. diff --git a/internal/logger/logger_test.go b/internal/logger/logger_test.go index eead29c96c..8629a10748 100644 --- a/internal/logger/logger_test.go +++ b/internal/logger/logger_test.go @@ -14,6 +14,8 @@ import ( "reflect" "sync" "testing" + + "go.mongodb.org/mongo-driver/internal/assert" ) type mockLogSink struct{} @@ -334,3 +336,177 @@ func TestTruncate(t *testing.T) { } } + +func TestLogger_LevelComponentEnabled(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + logger Logger + level Level + component Component + want bool + }{ + { + name: "zero", + logger: Logger{}, + level: LevelOff, + component: ComponentCommand, + want: false, + }, + { + name: "empty", + logger: Logger{ + ComponentLevels: map[Component]Level{}, + }, + level: LevelOff, + component: ComponentCommand, + want: false, // LevelOff should never be considered enabled. + }, + { + name: "one level below", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentCommand: LevelDebug, + }, + }, + level: LevelInfo, + component: ComponentCommand, + want: true, + }, + { + name: "equal levels", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentCommand: LevelDebug, + }, + }, + level: LevelDebug, + component: ComponentCommand, + want: true, + }, + { + name: "one level above", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentCommand: LevelInfo, + }, + }, + level: LevelDebug, + component: ComponentCommand, + want: false, + }, + { + name: "component mismatch", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentCommand: LevelDebug, + }, + }, + level: LevelDebug, + component: ComponentTopology, + want: false, + }, + { + name: "component all enables with topology", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentAll: LevelDebug, + }, + }, + level: LevelDebug, + component: ComponentTopology, + want: true, + }, + { + name: "component all enables with server selection", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentAll: LevelDebug, + }, + }, + level: LevelDebug, + component: ComponentServerSelection, + want: true, + }, + { + name: "component all enables with connection", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentAll: LevelDebug, + }, + }, + level: LevelDebug, + component: ComponentConnection, + want: true, + }, + { + name: "component all enables with command", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentAll: LevelDebug, + }, + }, + level: LevelDebug, + component: ComponentCommand, + want: true, + }, + { + name: "component all enables with all", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentAll: LevelDebug, + }, + }, + level: LevelDebug, + component: ComponentAll, + want: true, + }, + { + name: "component all does not enable with lower level", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentAll: LevelInfo, + }, + }, + level: LevelDebug, + component: ComponentCommand, + want: false, + }, + { + name: "component all has a lower log level than command", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentAll: LevelInfo, + ComponentCommand: LevelDebug, + }, + }, + level: LevelDebug, + component: ComponentCommand, + want: true, + }, + { + name: "component all has a higher log level than command", + logger: Logger{ + ComponentLevels: map[Component]Level{ + ComponentAll: LevelDebug, + ComponentCommand: LevelInfo, + }, + }, + level: LevelDebug, + component: ComponentCommand, + want: true, + }, + } + + for _, tcase := range tests { + tcase := tcase // Capture the range variable. + + t.Run(tcase.name, func(t *testing.T) { + t.Parallel() + + got := tcase.logger.LevelComponentEnabled(tcase.level, tcase.component) + assert.Equal(t, tcase.want, got, "unexpected result for LevelComponentEnabled") + }) + } +}