diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 78d1bee892..ec23858817 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -350,6 +350,23 @@ functions: chmod +x $i done + assume-ec2-role: + - command: ec2.assume_role + params: + role_arn: ${aws_test_secrets_role} + + run-oidc-auth-test-with-test-credentials: + - command: shell.exec + type: test + params: + working_dir: src/go.mongodb.org/mongo-driver + shell: bash + include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + script: | + ${PREPARE_SHELL} + export OIDC="oidc" + bash ${PROJECT_DIRECTORY}/etc/run-oidc-test.sh + run-make: - command: shell.exec type: test @@ -1954,6 +1971,10 @@ tasks: popd ./.evergreen/run-deployed-lambda-aws-tests.sh + - name: "oidc-auth-test-latest" + commands: + - func: "run-oidc-auth-test-with-test-credentials" + - name: "test-search-index" commands: - func: "bootstrap-mongo-orchestration" @@ -2247,6 +2268,31 @@ task_groups: tasks: - testazurekms-task + - name: testoidc_task_group + setup_group: + - func: fetch-source + - func: prepare-resources + - func: fix-absolute-paths + - func: make-files-executable + - func: assume-ec2-role + - command: shell.exec + params: + shell: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + script: | + ${PREPARE_SHELL} + ${DRIVERS_TOOLS}/.evergreen/auth_oidc/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-latest + - name: test-aws-lambda-task-group setup_group: - func: fetch-source @@ -2586,3 +2632,13 @@ buildvariants: - name: testazurekms_task_group batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README - testazurekms-fail-task + + - name: testoidc-variant + display_name: "OIDC" + run_on: + - ubuntu2204-large + expansions: + GO_DIST: "/opt/golang/go1.20" + tasks: + - name: testoidc_task_group + batchtime: 20160 # Use a batchtime of 14 days as suggested by the CSFLE test README diff --git a/Makefile b/Makefile index 88bc756390..b38bb4b6f0 100644 --- a/Makefile +++ b/Makefile @@ -132,6 +132,11 @@ evg-test-atlas-data-lake: evg-test-enterprise-auth: go run -tags gssapi ./cmd/testentauth/main.go +.PHONY: evg-test-oidc-auth +evg-test-oidc-auth: + go run ./cmd/testoidcauth/main.go + go run -race ./cmd/testoidcauth/main.go + .PHONY: evg-test-kmip evg-test-kmip: go test -exec "env PKG_CONFIG_PATH=$(PKG_CONFIG_PATH) LD_LIBRARY_PATH=$(LD_LIBRARY_PATH) DYLD_LIBRARY_PATH=$(MACOS_LIBRARY_PATH)" $(BUILD_TAGS) -v -timeout $(TEST_TIMEOUT)s ./mongo/integration -run TestClientSideEncryptionSpec/kmipKMS >> test.suite diff --git a/cmd/testoidcauth/main.go b/cmd/testoidcauth/main.go new file mode 100644 index 0000000000..82e95f1db1 --- /dev/null +++ b/cmd/testoidcauth/main.go @@ -0,0 +1,688 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package main + +import ( + "context" + "fmt" + "log" + "os" + "path" + "reflect" + "sync" + "time" + "unsafe" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/mongo/driver/auth" +) + +var uriAdmin = os.Getenv("MONGODB_URI") +var uriSingle = os.Getenv("MONGODB_URI_SINGLE") + +// var uriMulti = os.Getenv("MONGODB_URI_MULTI") +var oidcTokenDir = os.Getenv("OIDC_TOKEN_DIR") + +//var oidcDomain = os.Getenv("OIDC_DOMAIN") + +//func explicitUser(user string) string { +// return fmt.Sprintf("%s@%s", user, oidcDomain) +//} + +func tokenFile(user string) string { + return path.Join(oidcTokenDir, user) +} + +func connectAdminClinet() (*mongo.Client, error) { + return mongo.Connect(context.Background(), options.Client().ApplyURI(uriAdmin)) +} + +func connectWithMachineCB(uri string, cb options.OIDCCallback) (*mongo.Client, error) { + opts := options.Client().ApplyURI(uri) + + opts.Auth.OIDCMachineCallback = cb + return mongo.Connect(context.Background(), opts) +} + +func connectWithMachineCBAndProperties(uri string, cb options.OIDCCallback, props map[string]string) (*mongo.Client, error) { + opts := options.Client().ApplyURI(uri) + + opts.Auth.OIDCMachineCallback = cb + opts.Auth.AuthMechanismProperties = props + return mongo.Connect(context.Background(), opts) +} + +func main() { + // be quiet linter + _ = tokenFile("test_user2") + + hasError := false + aux := func(test_name string, f func() error) { + fmt.Printf("%s...", test_name) + err := f() + if err != nil { + fmt.Println("Test Error: ", err) + fmt.Println("...Failed") + hasError = true + } else { + fmt.Println("...Ok") + } + } + aux("machine_1_1_callbackIsCalled", machine11callbackIsCalled) + aux("machine_1_2_callbackIsCalledOnlyOneForMultipleConnections", machine12callbackIsCalledOnlyOneForMultipleConnections) + aux("machine_2_1_validCallbackInputs", machine21validCallbackInputs) + aux("machine_2_3_oidcCallbackReturnMissingData", machine23oidcCallbackReturnMissingData) + aux("machine_2_4_invalidClientConfigurationWithCallback", machine24invalidClientConfigurationWithCallback) + aux("machine_3_1_failureWithCachedTokensFetchANewTokenAndRetryAuth", machine31failureWithCachedTokensFetchANewTokenAndRetryAuth) + aux("machine_3_2_authFailuresWithoutCachedTokensReturnsAnError", machine32authFailuresWithoutCachedTokensReturnsAnError) + aux("machine_3_3_UnexpectedErrorCodeDoesNotClearTheCache", machine33UnexpectedErrorCodeDoesNotClearTheCache) + aux("machine_4_1_reauthenticationSucceeds", machine41ReauthenticationSucceeds) + aux("machine_4_2_readCommandsFailIfReauthenticationFails", machine42ReadCommandsFailIfReauthenticationFails) + aux("machine_4_3_writeCommandsFailIfReauthenticationFails", machine43WriteCommandsFailIfReauthenticationFails) + if hasError { + log.Fatal("One or more tests failed") + } +} + +func machine11callbackIsCalled() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_1_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_1_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_1_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_1_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine12callbackIsCalledOnlyOneForMultipleConnections() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_1_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_1_2: failed connecting client: %v", err) + } + + var wg sync.WaitGroup + + var findFailed error + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + coll := client.Database("test").Collection("test") + _, err := coll.Find(context.Background(), bson.D{}) + if err != nil { + findFailed = fmt.Errorf("machine_1_2: failed executing Find: %v", err) + } + }() + } + + wg.Wait() + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_1_2: expected callback count to be 1, got %d", callbackCount) + } + if callbackFailed != nil { + return callbackFailed + } + return findFailed +} + +func machine21validCallbackInputs() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + if args.RefreshToken != nil { + callbackFailed = fmt.Errorf("machine_2_1: expected RefreshToken to be nil, got %v", args.RefreshToken) + } + timeout, ok := ctx.Deadline() + if !ok { + callbackFailed = fmt.Errorf("machine_2_1: expected context to have deadline, got %v", ctx) + } + if timeout.Before(time.Now()) { + callbackFailed = fmt.Errorf("machine_2_1: expected timeout to be in the future, got %v", timeout) + } + if args.Version < 1 { + callbackFailed = fmt.Errorf("machine_2_1: expected Version to be at least 1, got %d", args.Version) + } + if args.IDPInfo != nil { + callbackFailed = fmt.Errorf("machine_2_1: expected IdpID to be nil for Machine flow, got %v", args.IDPInfo) + } + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + fmt.Printf("machine_2_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_2_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_2_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_2_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine23oidcCallbackReturnMissingData() error { + callbackCount := 0 + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_2_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_2_3: should have failed to executed Find, but succeeded") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_2_3: expected callback count to be 1, got %d", callbackCount) + } + return nil +} + +func machine24invalidClientConfigurationWithCallback() error { + _, err := connectWithMachineCBAndProperties(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + t := time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }, + map[string]string{"ENVIRONMENT": "test"}, + ) + if err == nil { + return fmt.Errorf("machine_2_4: succeeded building client when it should fail") + } + return nil +} + +func machine31failureWithCachedTokensFetchANewTokenAndRetryAuth() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_1: failed connecting client: %v", err) + } + + // Poison the cache with a random token + clientElem := reflect.ValueOf(client).Elem() + authenticatorField := clientElem.FieldByName("authenticator") + authenticatorField = reflect.NewAt( + authenticatorField.Type(), + unsafe.Pointer(authenticatorField.UnsafeAddr())).Elem() + // this is the only usage of the x packages in the test, showing the the public interface is + // correct. + authenticatorField.Interface().(*auth.OIDCAuthenticator).SetAccessToken("some random happy sunshine string") + + coll := client.Database("test").Collection("test") + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_1: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine32authFailuresWithoutCachedTokensReturnsAnError() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_2: Find ucceeded when it should fail") + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_2: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine33UnexpectedErrorCodeDoesNotClearTheCache() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer adminClient.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_3: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_3_3: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_3_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "saslStart", + }}, + {Key: "errorCode", Value: 20}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_3_3: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_3_3: Find succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 1 { + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_3_3: failed executing Find: %v", err) + } + if callbackCount != 1 { + return fmt.Errorf("machine_3_3: expected callback count to be 1, got %d", callbackCount) + } + return callbackFailed +} + +func machine41ReauthenticationSucceeds() error { + callbackCount := 0 + var callbackFailed error + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer adminClient.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_1: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_1: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_1: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_1: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_1: failed executing Find: %v", err) + } + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_1: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func machine42ReadCommandsFailIfReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + firstCall := true + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer adminClient.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_2: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if firstCall { + firstCall = false + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_2: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_2: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.Find(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_2: failed executing Find: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "find", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_2: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.Find(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_4_2: Find succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_2: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} + +func machine43WriteCommandsFailIfReauthenticationFails() error { + callbackCount := 0 + var callbackFailed error + firstCall := true + countMutex := sync.Mutex{} + + adminClient, err := connectAdminClinet() + defer adminClient.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_3: failed connecting admin client: %v", err) + } + + client, err := connectWithMachineCB(uriSingle, func(ctx context.Context, args *options.OIDCArgs) (*options.OIDCCredential, error) { + countMutex.Lock() + defer countMutex.Unlock() + callbackCount++ + t := time.Now().Add(time.Hour) + if firstCall { + firstCall = false + tokenFile := tokenFile("test_user1") + accessToken, err := os.ReadFile(tokenFile) + if err != nil { + callbackFailed = fmt.Errorf("machine_4_3: failed reading token file: %v", err) + } + return &options.OIDCCredential{ + AccessToken: string(accessToken), + ExpiresAt: &t, + RefreshToken: nil, + }, nil + } + return &options.OIDCCredential{ + AccessToken: "this is a bad, bad token", + ExpiresAt: &t, + RefreshToken: nil, + }, nil + }) + + defer client.Disconnect(context.Background()) + + if err != nil { + return fmt.Errorf("machine_4_3: failed connecting client: %v", err) + } + + coll := client.Database("test").Collection("test") + _, err = coll.InsertOne(context.Background(), bson.D{}) + if err != nil { + return fmt.Errorf("machine_4_3: failed executing Insert: %v", err) + } + + res := adminClient.Database("admin").RunCommand(context.Background(), bson.D{ + {Key: "configureFailPoint", Value: "failCommand"}, + {Key: "mode", Value: bson.D{ + {Key: "times", Value: 1}, + }}, + {Key: "data", Value: bson.D{ + {Key: "failCommands", Value: bson.A{ + "insert", + }}, + {Key: "errorCode", Value: 391}, + }}, + }) + + if res.Err() != nil { + return fmt.Errorf("machine_4_3: failed setting failpoint: %v", res.Err()) + } + + _, err = coll.InsertOne(context.Background(), bson.D{}) + if err == nil { + return fmt.Errorf("machine_4_3: Insert succeeded when it should fail") + } + + countMutex.Lock() + defer countMutex.Unlock() + if callbackCount != 2 { + return fmt.Errorf("machine_4_3: expected callback count to be 2, got %d", callbackCount) + } + return callbackFailed +} diff --git a/etc/run-oidc-test.sh b/etc/run-oidc-test.sh new file mode 100644 index 0000000000..bc5eb99758 --- /dev/null +++ b/etc/run-oidc-test.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# run-oidc-test +# Runs oidc auth tests. +set -eu + +echo "Running MONGODB-OIDC authentication tests" + +OIDC_ENV="${OIDC_ENV:-"test"}" + +if [ $OIDC_ENV == "test" ]; then + # Make sure DRIVERS_TOOLS is set. + if [ -z "$DRIVERS_TOOLS" ]; then + echo "Must specify DRIVERS_TOOLS" + exit 1 + fi + source ${DRIVERS_TOOLS}/.evergreen/auth_oidc/secrets-export.sh + +elif [ $OIDC_ENV == "azure" ]; then + source ./env.sh + +elif [ $OIDC_ENV == "gcp" ]; then + source ./secrets-export.sh + +else + echo "Unrecognized OIDC_ENV $OIDC_ENV" + exit 1 +fi + +export TEST_AUTH_OIDC=1 +export COVERAGE=1 +export AUTH="auth" + +make -s evg-test-oidc-auth diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index 3fdb67b9a2..40f1181e0e 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -186,7 +186,7 @@ func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE). ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). - Logger(bw.collection.client.logger) + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -256,7 +256,7 @@ func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout). - Logger(bw.collection.client.logger) + Logger(bw.collection.client.logger).Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { @@ -387,7 +387,8 @@ func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (opera Database(bw.collection.db.name).Collection(bw.collection.name). Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint). ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI). - Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger) + Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger). + Authenticator(bw.collection.client.authenticator) if bw.comment != nil { comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry) if err != nil { diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 8d0a2031de..3ea8baf1f2 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -137,7 +137,8 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in ReadPreference(config.readPreference).ReadConcern(config.readConcern). Deployment(cs.client.deployment).ClusterClock(cs.client.clock). CommandMonitor(cs.client.monitor).Session(cs.sess).ServerSelector(cs.selector).Retry(driver.RetryNone). - ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout) + ServerAPI(cs.client.serverAPI).Crypt(config.crypt).Timeout(cs.client.timeout). + Authenticator(cs.client.authenticator) if cs.options.Collation != nil { cs.aggregate.Collation(bsoncore.Document(cs.options.Collation.ToDocument())) diff --git a/mongo/client.go b/mongo/client.go index 4266412aab..00f4f363ae 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -26,6 +26,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" @@ -79,6 +80,7 @@ type Client struct { metadataClientFLE *Client internalClientFLE *Client encryptedFieldsMap map[string]interface{} + authenticator driver.Authenticator } // Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling @@ -209,7 +211,40 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { clientOpt.SetMaxPoolSize(defaultMaxPoolSize) } - cfg, err := topology.NewConfig(clientOpt, client.clock) + if clientOpt.Auth != nil { + var oidcMachineCallback auth.OIDCCallback + if clientOpt.Auth.OIDCMachineCallback != nil { + oidcMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := clientOpt.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err + } + } + + var oidcHumanCallback auth.OIDCCallback + if clientOpt.Auth.OIDCHumanCallback != nil { + oidcHumanCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := clientOpt.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err + } + } + + // Create an authenticator for the client + client.authenticator, err = auth.CreateAuthenticator(clientOpt.Auth.AuthMechanism, &auth.Cred{ + Source: clientOpt.Auth.AuthSource, + Username: clientOpt.Auth.Username, + Password: clientOpt.Auth.Password, + PasswordSet: clientOpt.Auth.PasswordSet, + Props: clientOpt.Auth.AuthMechanismProperties, + OIDCMachineCallback: oidcMachineCallback, + OIDCHumanCallback: oidcHumanCallback, + }, clientOpt.HTTPClient) + if err != nil { + return nil, err + } + } + + cfg, err := topology.NewConfigWithAuthenticator(clientOpt, client.clock, client.authenticator) + if err != nil { return nil, err } @@ -231,6 +266,19 @@ func NewClient(opts ...*options.ClientOptions) (*Client, error) { return client, nil } +// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent +// public type *options.OIDCArgs. +func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs { + if args == nil { + return nil + } + return &options.OIDCArgs{ + Version: args.Version, + IDPInfo: (*options.IDPInfo)(args.IDPInfo), + RefreshToken: args.RefreshToken, + } +} + // Connect initializes the Client by starting background monitoring goroutines. // If the Client was created using the NewClient function, this method must be called before a Client can be used. // @@ -690,7 +738,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... op := operation.NewListDatabases(filterDoc). Session(sess).ReadPreference(c.readPreference).CommandMonitor(c.monitor). ServerSelector(selector).ClusterClock(c.clock).Database("admin").Deployment(c.deployment).Crypt(c.cryptFLE). - ServerAPI(c.serverAPI).Timeout(c.timeout) + ServerAPI(c.serverAPI).Timeout(c.timeout).Authenticator(c.authenticator) if ldo.NameOnly != nil { op = op.NameOnly(*ldo.NameOnly) diff --git a/mongo/client_test.go b/mongo/client_test.go index 013c1ae6bb..0a96e54501 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -11,6 +11,7 @@ import ( "errors" "math" "os" + "reflect" "testing" "time" @@ -18,11 +19,13 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/integtest" + "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/tag" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" @@ -502,3 +505,76 @@ func TestClient(t *testing.T) { } }) } + +// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs +// into an options.OIDCArgs. +func TestConvertOIDCArgs(t *testing.T) { + refreshToken := "test refresh token" + + testCases := []struct { + desc string + args *driver.OIDCArgs + }{ + { + desc: "populated args", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: &driver.IDPInfo{ + Issuer: "test issuer", + ClientID: "test client ID", + RequestScopes: []string{"test scope 1", "test scope 2"}, + }, + RefreshToken: &refreshToken, + }, + }, + { + desc: "nil", + args: nil, + }, + { + desc: "nil IDPInfo and RefreshToken", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: nil, + RefreshToken: nil, + }, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + got := convertOIDCArgs(tc.args) + + if tc.args == nil { + assert.Nil(t, got, "expected nil when input is nil") + return + } + + require.Equal(t, + 3, + reflect.ValueOf(*tc.args).NumField(), + "expected the driver.OIDCArgs struct to have exactly 3 fields") + require.Equal(t, + 3, + reflect.ValueOf(*got).NumField(), + "expected the options.OIDCArgs struct to have exactly 3 fields") + + assert.Equal(t, + tc.args.Version, + got.Version, + "expected Version field to be equal") + assert.EqualValues(t, + tc.args.IDPInfo, + got.IDPInfo, + "expected IDPInfo field to be convertible to equal values") + assert.Equal(t, + tc.args.RefreshToken, + got.RefreshToken, + "expected RefreshToken field to be equal") + }) + } +} diff --git a/mongo/collection.go b/mongo/collection.go index 4cf6fd1a1a..8a0a054d5e 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -291,7 +291,8 @@ func (coll *Collection) insert(ctx context.Context, documents []interface{}, ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger). + Authenticator(coll.client.authenticator) imo := options.MergeInsertManyOptions(opts...) if imo.BypassDocumentValidation != nil && *imo.BypassDocumentValidation { op = op.BypassDocumentValidation(*imo.BypassDocumentValidation) @@ -471,7 +472,8 @@ func (coll *Collection) delete(ctx context.Context, filter interface{}, deleteOn ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Ordered(true). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).Logger(coll.client.logger). + Authenticator(coll.client.authenticator) if do.Comment != nil { comment, err := marshalValue(do.Comment, coll.bsonOpts, coll.registry) if err != nil { @@ -588,7 +590,7 @@ func (coll *Collection) updateOrReplace(ctx context.Context, filter bsoncore.Doc Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).Hint(uo.Hint != nil). ArrayFilters(uo.ArrayFilters != nil).Ordered(true).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).Logger(coll.client.logger) + Timeout(coll.client.timeout).Logger(coll.client.logger).Authenticator(coll.client.authenticator) if uo.Let != nil { let, err := marshal(uo.Let, coll.bsonOpts, coll.registry) if err != nil { @@ -861,7 +863,8 @@ func aggregate(a aggregateParams) (cur *Cursor, err error) { ServerAPI(a.client.serverAPI). HasOutputStage(hasOutputStage). Timeout(a.client.timeout). - MaxTime(ao.MaxTime) + MaxTime(ao.MaxTime). + Authenticator(a.client.authenticator) // Omit "maxTimeMS" from operations that return a user-managed cursor to // prevent confusing "cursor not found" errors. To maintain existing @@ -992,7 +995,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{}, op := operation.NewAggregate(pipelineArr).Session(sess).ReadConcern(rc).ReadPreference(coll.readPreference). CommandMonitor(coll.client.monitor).ServerSelector(selector).ClusterClock(coll.client.clock).Database(coll.db.name). Collection(coll.name).Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime) + Timeout(coll.client.timeout).MaxTime(countOpts.MaxTime).Authenticator(coll.client.authenticator) if countOpts.Collation != nil { op.Collation(bsoncore.Document(countOpts.Collation.ToDocument())) } @@ -1077,7 +1080,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context, Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(co.MaxTime) + Timeout(coll.client.timeout).MaxTime(co.MaxTime).Authenticator(coll.client.authenticator) if co.Comment != nil { comment, err := marshalValue(co.Comment, coll.bsonOpts, coll.registry) @@ -1144,7 +1147,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i Database(coll.db.name).Collection(coll.name).CommandMonitor(coll.client.monitor). Deployment(coll.client.deployment).ReadConcern(rc).ReadPreference(coll.readPreference). ServerSelector(selector).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). - Timeout(coll.client.timeout).MaxTime(option.MaxTime) + Timeout(coll.client.timeout).MaxTime(option.MaxTime).Authenticator(coll.client.authenticator) if option.Collation != nil { op.Collation(bsoncore.Document(option.Collation.ToDocument())) @@ -1224,6 +1227,7 @@ func (coll *Collection) find( f, err := marshal(filter, coll.bsonOpts, coll.registry) if err != nil { + fmt.Println(err) return nil, err } @@ -1257,7 +1261,7 @@ func (coll *Collection) find( ClusterClock(coll.client.clock).Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE).ServerAPI(coll.client.serverAPI). Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Logger(coll.client.logger). - OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS) + OmitCSOTMaxTimeMS(omitCSOTMaxTimeMS).Authenticator(coll.client.authenticator) cursorOpts := coll.client.createBaseCursorOptions() @@ -1521,7 +1525,7 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{} } fod := options.MergeFindOneAndDeleteOptions(opts...) op := operation.NewFindAndModify(f).Remove(true).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). - MaxTime(fod.MaxTime) + MaxTime(fod.MaxTime).Authenticator(coll.client.authenticator) if fod.Collation != nil { op = op.Collation(bsoncore.Document(fod.Collation.ToDocument())) } @@ -1601,7 +1605,8 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{ fo := options.MergeFindOneAndReplaceOptions(opts...) op := operation.NewFindAndModify(f).Update(bsoncore.Value{Type: bsontype.EmbeddedDocument, Data: r}). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout).MaxTime(fo.MaxTime).Authenticator(coll.client.authenticator) + if fo.BypassDocumentValidation != nil && *fo.BypassDocumentValidation { op = op.BypassDocumentValidation(*fo.BypassDocumentValidation) } @@ -1688,7 +1693,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{} fo := options.MergeFindOneAndUpdateOptions(opts...) op := operation.NewFindAndModify(f).ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). - MaxTime(fo.MaxTime) + MaxTime(fo.MaxTime).Authenticator(coll.client.authenticator) u, err := marshalUpdateValue(update, coll.bsonOpts, coll.registry, true) if err != nil { @@ -1894,7 +1899,8 @@ func (coll *Collection) drop(ctx context.Context) error { ServerSelector(selector).ClusterClock(coll.client.clock). Database(coll.db.name).Collection(coll.name). Deployment(coll.client.deployment).Crypt(coll.client.cryptFLE). - ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout) + ServerAPI(coll.client.serverAPI).Timeout(coll.client.timeout). + Authenticator(coll.client.authenticator) err = op.Execute(ctx) // ignore namespace not found errors diff --git a/mongo/database.go b/mongo/database.go index 57c0186eca..5344c9641e 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -189,7 +189,7 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, ServerSelector(readSelect).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment). Crypt(db.client.cryptFLE).ReadPreference(ro.ReadPreference).ServerAPI(db.client.serverAPI). - Timeout(db.client.timeout).Logger(db.client.logger), sess, nil + Timeout(db.client.timeout).Logger(db.client.logger).Authenticator(db.client.authenticator), sess, nil } // RunCommand executes the given command against the database. @@ -308,7 +308,7 @@ func (db *Database) Drop(ctx context.Context) error { Session(sess).WriteConcern(wc).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) err = op.Execute(ctx) @@ -402,7 +402,7 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt Session(sess).ReadPreference(db.readPreference).CommandMonitor(db.client.monitor). ServerSelector(selector).ClusterClock(db.client.clock). Database(db.name).Deployment(db.client.deployment).Crypt(db.client.cryptFLE). - ServerAPI(db.client.serverAPI).Timeout(db.client.timeout) + ServerAPI(db.client.serverAPI).Timeout(db.client.timeout).Authenticator(db.client.authenticator) cursorOpts := db.client.createBaseCursorOptions() @@ -679,7 +679,7 @@ func (db *Database) createCollection(ctx context.Context, name string, opts ...* func (db *Database) createCollectionOperation(name string, opts ...*options.CreateCollectionOptions) (*operation.Create, error) { cco := options.MergeCreateCollectionOptions(opts...) - op := operation.NewCreate(name).ServerAPI(db.client.serverAPI) + op := operation.NewCreate(name).ServerAPI(db.client.serverAPI).Authenticator(db.client.authenticator) if cco.Capped != nil { op.Capped(*cco.Capped) @@ -805,7 +805,8 @@ func (db *Database) CreateView(ctx context.Context, viewName, viewOn string, pip op := operation.NewCreate(viewName). ViewOn(viewOn). Pipeline(pipelineArray). - ServerAPI(db.client.serverAPI) + ServerAPI(db.client.serverAPI). + Authenticator(db.client.authenticator) cvo := options.MergeCreateViewOptions(opts...) if cvo.Collation != nil { op.Collation(bsoncore.Document(cvo.Collation.ToDocument())) diff --git a/mongo/index_view.go b/mongo/index_view.go index 8d3555d0b0..b7e7234339 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -94,7 +94,7 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout) + Timeout(iv.coll.client.timeout).Authenticator(iv.coll.client.authenticator) cursorOpts := iv.coll.client.createBaseCursorOptions() @@ -262,7 +262,7 @@ func (iv IndexView) CreateMany(ctx context.Context, models []IndexModel, opts .. Session(sess).WriteConcern(wc).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name).CommandMonitor(iv.coll.client.monitor). Deployment(iv.coll.client.deployment).ServerSelector(selector).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime) + Timeout(iv.coll.client.timeout).MaxTime(option.MaxTime).Authenticator(iv.coll.client.authenticator) if option.CommitQuorum != nil { commitQuorum, err := marshalValue(option.CommitQuorum, iv.coll.bsonOpts, iv.coll.registry) if err != nil { @@ -402,7 +402,8 @@ func (iv IndexView) drop(ctx context.Context, name string, opts ...*options.Drop ServerSelector(selector).ClusterClock(iv.coll.client.clock). Database(iv.coll.db.name).Collection(iv.coll.name). Deployment(iv.coll.client.deployment).ServerAPI(iv.coll.client.serverAPI). - Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime) + Timeout(iv.coll.client.timeout).MaxTime(dio.MaxTime). + Authenticator(iv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { diff --git a/mongo/integration/mtest/opmsg_deployment.go b/mongo/integration/mtest/opmsg_deployment.go index 2215f84b38..2ddc23c413 100644 --- a/mongo/integration/mtest/opmsg_deployment.go +++ b/mongo/integration/mtest/opmsg_deployment.go @@ -61,6 +61,13 @@ func (c *connection) WriteWireMessage(context.Context, []byte) error { return nil } +func (c *connection) OIDCTokenGenID() uint64 { + return 0 +} + +func (c *connection) SetOIDCTokenGenID(uint64) { +} + // ReadWireMessage returns the next response in the connection's list of responses. func (c *connection) ReadWireMessage(_ context.Context) ([]byte, error) { var dst []byte diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 17b3731301..180d039969 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -111,6 +111,34 @@ type Credential struct { Username string Password string PasswordSet bool + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback +} + +// OIDCCallback is the type for both Human and Machine Callback flows. +// RefreshToken will always be nil in the OIDCArgs for the Machine flow. +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} + +// IDPInfo contains the information needed to perform OIDC authentication with +// an Identity Provider. +type IDPInfo struct { + Issuer string + ClientID string + RequestScopes []string } // BSONOptions are optional BSON marshaling and unmarshaling behaviors. diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index 73fe8534ed..3253a73a2b 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -143,7 +143,7 @@ func (siv SearchIndexView) CreateMany( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) if err != nil { @@ -198,7 +198,7 @@ func (siv SearchIndexView) DropOne( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) err = op.Execute(ctx) if de, ok := err.(driver.Error); ok && de.NamespaceNotFound() { @@ -252,7 +252,7 @@ func (siv SearchIndexView) UpdateOne( ServerSelector(selector).ClusterClock(siv.coll.client.clock). Collection(siv.coll.name).Database(siv.coll.db.name). Deployment(siv.coll.client.deployment).ServerAPI(siv.coll.client.serverAPI). - Timeout(siv.coll.client.timeout) + Timeout(siv.coll.client.timeout).Authenticator(siv.coll.client.authenticator) return op.Execute(ctx) } diff --git a/mongo/session.go b/mongo/session.go index 8f1e029b95..77be4ab6db 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -296,7 +296,8 @@ func (s *sessionImpl) AbortTransaction(ctx context.Context) error { _ = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin"). Deployment(s.deployment).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector). Retry(driver.RetryOncePerCommand).CommandMonitor(s.client.monitor). - RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI).Execute(ctx) + RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI). + Authenticator(s.client.authenticator).Execute(ctx) s.clientSession.Aborting = false _ = s.clientSession.AbortTransaction() @@ -328,7 +329,7 @@ func (s *sessionImpl) CommitTransaction(ctx context.Context) error { Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment). WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand). CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)). - ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct) + ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct).Authenticator(s.client.authenticator) err = op.Execute(ctx) // Return error without updating transaction state if it is a timeout, as the transaction has not diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index 6eeaf0ee01..f6471cea26 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -19,8 +19,11 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) +// Config contains the configuration for an Authenticator. +type Config = driver.AuthConfig + // AuthenticatorFactory constructs an authenticator. -type AuthenticatorFactory func(cred *Cred) (Authenticator, error) +type AuthenticatorFactory func(*Cred, *http.Client) (Authenticator, error) var authFactories = make(map[string]AuthenticatorFactory) @@ -33,12 +36,13 @@ func init() { RegisterAuthenticatorFactory(GSSAPI, newGSSAPIAuthenticator) RegisterAuthenticatorFactory(MongoDBX509, newMongoDBX509Authenticator) RegisterAuthenticatorFactory(MongoDBAWS, newMongoDBAWSAuthenticator) + RegisterAuthenticatorFactory(MongoDBOIDC, newOIDCAuthenticator) } // CreateAuthenticator creates an authenticator. -func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) { +func CreateAuthenticator(name string, cred *Cred, httpClient *http.Client) (Authenticator, error) { if f, ok := authFactories[name]; ok { - return f(cred) + return f(cred, httpClient) } return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil) @@ -61,7 +65,6 @@ type HandshakeOptions struct { ClusterClock *session.ClusterClock ServerAPI *driver.ServerAPIOptions LoadBalanced bool - HTTPClient *http.Client } type authHandshaker struct { @@ -97,12 +100,17 @@ func (ah *authHandshaker) GetHandshakeInformation(ctx context.Context, addr addr return driver.HandshakeInformation{}, newAuthError("failed to create conversation", err) } - firstMsg, err := ah.conversation.FirstMessage() - if err != nil { - return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) - } + // It is possible for the speculative conversation to be nil even without error if the authenticator + // cannot perform speculative authentication. An example of this is MONGODB-OIDC when there is + // no AccessToken in the cache. + if ah.conversation != nil { + firstMsg, err := ah.conversation.FirstMessage() + if err != nil { + return driver.HandshakeInformation{}, newAuthError("failed to create speculative authentication message", err) + } - op = op.SpeculativeAuthenticate(firstMsg) + op = op.SpeculativeAuthenticate(firstMsg) + } } } @@ -132,7 +140,6 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Conne ClusterClock: ah.options.ClusterClock, HandshakeInfo: ah.handshakeInfo, ServerAPI: ah.options.ServerAPI, - HTTPClient: ah.options.HTTPClient, } if err := ah.authenticate(ctx, cfg); err != nil { @@ -170,21 +177,8 @@ func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshake } } -// Config holds the information necessary to perform an authentication attempt. -type Config struct { - Description description.Server - Connection driver.Connection - ClusterClock *session.ClusterClock - HandshakeInfo driver.HandshakeInformation - ServerAPI *driver.ServerAPIOptions - HTTPClient *http.Client -} - // Authenticator handles authenticating a connection. -type Authenticator interface { - // Auth authenticates the connection. - Auth(context.Context, *Config) error -} +type Authenticator = driver.Authenticator func newAuthError(msg string, inner error) error { return &Error{ diff --git a/x/mongo/driver/auth/auth_test.go b/x/mongo/driver/auth/auth_test.go index 9145a21595..3c07ed2cd8 100644 --- a/x/mongo/driver/auth/auth_test.go +++ b/x/mongo/driver/auth/auth_test.go @@ -7,6 +7,7 @@ package auth_test import ( + "net/http" "testing" "github.com/google/go-cmp/cmp" @@ -39,7 +40,7 @@ func TestCreateAuthenticator(t *testing.T) { PasswordSet: true, } - a, err := CreateAuthenticator(test.name, cred) + a, err := CreateAuthenticator(test.name, cred, &http.Client{}) require.NoError(t, err) require.IsType(t, test.auth, a) }) diff --git a/x/mongo/driver/auth/cred.go b/x/mongo/driver/auth/cred.go index 7b2b8f17d0..a9685f6ed8 100644 --- a/x/mongo/driver/auth/cred.go +++ b/x/mongo/driver/auth/cred.go @@ -6,11 +6,9 @@ package auth -// Cred is a user's credential. -type Cred struct { - Source string - Username string - Password string - PasswordSet bool - Props map[string]string -} +import ( + "go.mongodb.org/mongo-driver/x/mongo/driver" +) + +// Cred is the type of user credential +type Cred = driver.Cred diff --git a/x/mongo/driver/auth/default.go b/x/mongo/driver/auth/default.go index 6f2ca5224a..785a41951d 100644 --- a/x/mongo/driver/auth/default.go +++ b/x/mongo/driver/auth/default.go @@ -9,10 +9,13 @@ package auth import ( "context" "fmt" + "net/http" + + "go.mongodb.org/mongo-driver/x/mongo/driver" ) -func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { - scram, err := newScramSHA256Authenticator(cred) +func newDefaultAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + scram, err := newScramSHA256Authenticator(cred, httpClient) if err != nil { return nil, newAuthError("failed to create internal authenticator", err) } @@ -25,6 +28,7 @@ func newDefaultAuthenticator(cred *Cred) (Authenticator, error) { return &DefaultAuthenticator{ Cred: cred, speculativeAuthenticator: speculative, + httpClient: httpClient, }, nil } @@ -36,6 +40,8 @@ type DefaultAuthenticator struct { // The authenticator to use for speculative authentication. Because the correct auth mechanism is unknown when doing // the initial hello, SCRAM-SHA-256 is used for the speculative attempt. speculativeAuthenticator SpeculativeAuthenticator + + httpClient *http.Client } var _ SpeculativeAuthenticator = (*DefaultAuthenticator)(nil) @@ -52,11 +58,11 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { switch chooseAuthMechanism(cfg) { case SCRAMSHA256: - actual, err = newScramSHA256Authenticator(a.Cred) + actual, err = newScramSHA256Authenticator(a.Cred, a.httpClient) case SCRAMSHA1: - actual, err = newScramSHA1Authenticator(a.Cred) + actual, err = newScramSHA1Authenticator(a.Cred, a.httpClient) default: - actual, err = newMongoDBCRAuthenticator(a.Cred) + actual, err = newMongoDBCRAuthenticator(a.Cred, a.httpClient) } if err != nil { @@ -66,6 +72,11 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, cfg *Config) error { return actual.Auth(ctx, cfg) } +// Reauth reauthenticates the connection. +func (a *DefaultAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("DefaultAuthenticator does not support reauthentication", nil) +} + // If a server provides a list of supported mechanisms, we choose // SCRAM-SHA-256 if it exists or else MUST use SCRAM-SHA-1. // Otherwise, we decide based on what is supported. diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 4b860ba63f..037c944eb7 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -14,14 +14,16 @@ import ( "context" "fmt" "net" + "net/http" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth/internal/gssapi" ) // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { if cred.Source != "" && cred.Source != "$external" { return nil, newAuthError("GSSAPI source must be empty or $external", nil) } @@ -57,3 +59,8 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *Config) error { } return ConductSaslConversation(ctx, cfg, "$external", client) } + +// Reauth reauthenticates the connection. +func (a *GSSAPIAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("GSSAPI does not support reauthentication", nil) +} diff --git a/x/mongo/driver/auth/gssapi_not_enabled.go b/x/mongo/driver/auth/gssapi_not_enabled.go index 7ba5fe860c..e50553c7a1 100644 --- a/x/mongo/driver/auth/gssapi_not_enabled.go +++ b/x/mongo/driver/auth/gssapi_not_enabled.go @@ -9,9 +9,11 @@ package auth +import "net/http" + // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(*Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) { return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil) } diff --git a/x/mongo/driver/auth/gssapi_not_supported.go b/x/mongo/driver/auth/gssapi_not_supported.go index 10312c228e..12046ff67c 100644 --- a/x/mongo/driver/auth/gssapi_not_supported.go +++ b/x/mongo/driver/auth/gssapi_not_supported.go @@ -11,12 +11,13 @@ package auth import ( "fmt" + "net/http" "runtime" ) // GSSAPI is the mechanism name for GSSAPI. const GSSAPI = "GSSAPI" -func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) { +func newGSSAPIAuthenticator(*Cred, *http.Client) (Authenticator, error) { return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil) } diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index 7ae4b08998..2245bdb6fe 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -9,19 +9,24 @@ package auth import ( "context" "errors" + "net/http" "go.mongodb.org/mongo-driver/internal/aws/credentials" "go.mongodb.org/mongo-driver/internal/credproviders" + "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth/creds" ) // MongoDBAWS is the mechanism name for MongoDBAWS. const MongoDBAWS = "MONGODB-AWS" -func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { +func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { if cred.Source != "" && cred.Source != "$external" { return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil) } + if httpClient == nil { + return nil, errors.New("httpClient must not be nil") + } return &MongoDBAWSAuthenticator{ source: cred.Source, credentials: &credproviders.StaticProvider{ @@ -32,6 +37,7 @@ func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { SessionToken: cred.Props["AWS_SESSION_TOKEN"], }, }, + httpClient: httpClient, }, nil } @@ -39,15 +45,12 @@ func newMongoDBAWSAuthenticator(cred *Cred) (Authenticator, error) { type MongoDBAWSAuthenticator struct { source string credentials *credproviders.StaticProvider + httpClient *http.Client } // Auth authenticates the connection. func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { - httpClient := cfg.HTTPClient - if httpClient == nil { - return errors.New("cfg.HTTPClient must not be nil") - } - providers := creds.NewAWSCredentialProvider(httpClient, a.credentials) + providers := creds.NewAWSCredentialProvider(a.httpClient, a.credentials) adapter := &awsSaslAdapter{ conversation: &awsConversation{ credentials: providers.Cred, @@ -60,6 +63,11 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *MongoDBAWSAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("AWS authentication does not support reauthentication", nil) +} + type awsSaslAdapter struct { conversation *awsConversation } diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 6e2c2f4dcb..a988011b36 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "io" + "net/http" // Ignore gosec warning "Blocklisted import crypto/md5: weak cryptographic primitive". We need // to use MD5 here to implement the MONGODB-CR specification. @@ -28,7 +29,7 @@ import ( // MongoDB 4.0. const MONGODBCR = "MONGODB-CR" -func newMongoDBCRAuthenticator(cred *Cred) (Authenticator, error) { +func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &MongoDBCRAuthenticator{ DB: cred.Source, Username: cred.Username, @@ -97,6 +98,11 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *MongoDBCRAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("MONGODB-CR does not support reauthentication", nil) +} + func (a *MongoDBCRAuthenticator) createKey(nonce string) string { // Ignore gosec warning "Use of weak cryptographic primitive". We need to use MD5 here to // implement the MONGODB-CR specification. diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go new file mode 100644 index 0000000000..91748598d3 --- /dev/null +++ b/x/mongo/driver/auth/oidc.go @@ -0,0 +1,343 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" +) + +// MongoDBOIDC is the string constant for the MONGODB-OIDC authentication mechanism. +const MongoDBOIDC = "MONGODB-OIDC" + +// TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider +// const tokenResourceProp = "TOKEN_RESOURCE" +const environmentProp = "ENVIRONMENT" + +const resourceProp = "TOKEN_RESOURCE" + +// GODRIVER-3249 OIDC: Handle all possible OIDC configuration errors +//const allowedHostsProp = "ALLOWED_HOSTS" + +const azureEnvironmentValue = "azure" +const gcpEnvironmentValue = "gcp" +const testEnvironmentValue = "test" + +const apiVersion = 1 +const invalidateSleepTimeout = 100 * time.Millisecond + +// The CSOT specification says to apply a 1-minute timeout if "CSOT is not applied". That's +// ambiguous for the v1.x Go Driver because it could mean either "no timeout provided" or "CSOT not +// enabled". Always use a maximum timeout duration of 1 minute, allowing us to ignore the ambiguity. +// Contexts with a shorter timeout are unaffected. +const machineCallbackTimeout = 60 * time.Second + +//GODRIVER-3246 OIDC: Implement Human Callback Mechanism +//var defaultAllowedHosts = []string{ +// "*.mongodb.net", +// "*.mongodb-qa.net", +// "*.mongodb-dev.net", +// "*.mongodbgov.net", +// "localhost", +// "127.0.0.1", +// "::1", +//} + +// OIDCCallback is a function that takes a context and OIDCArgs and returns an OIDCCredential. +type OIDCCallback = driver.OIDCCallback + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs = driver.OIDCArgs + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential = driver.OIDCCredential + +// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider. +type IDPInfo = driver.IDPInfo + +var _ driver.Authenticator = (*OIDCAuthenticator)(nil) +var _ SpeculativeAuthenticator = (*OIDCAuthenticator)(nil) +var _ SaslClient = (*oidcOneStep)(nil) + +// OIDCAuthenticator is synchronized and handles caching of the access token, refreshToken, +// and IDPInfo. It also provides a mechanism to refresh the access token, but this functionality +// is only for the OIDC Human flow. +type OIDCAuthenticator struct { + mu sync.Mutex // Guards all of the info in the OIDCAuthenticator struct. + + AuthMechanismProperties map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback + + userName string + httpClient *http.Client + accessToken string + refreshToken *string + idpInfo *IDPInfo + tokenGenID uint64 +} + +// SetAccessToken allows for manually setting the access token for the OIDCAuthenticator, this is +// only for testing purposes. +func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) { + oa.mu.Lock() + defer oa.mu.Unlock() + oa.accessToken = accessToken +} + +func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + if cred.Password != "" { + return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) + } + if cred.Props != nil { + if env, ok := cred.Props[environmentProp]; ok { + switch strings.ToLower(env) { + case azureEnvironmentValue: + fallthrough + case gcpEnvironmentValue: + if _, ok := cred.Props[resourceProp]; !ok { + return nil, fmt.Errorf("%q must be specified for %q %q", resourceProp, env, environmentProp) + } + fallthrough + case testEnvironmentValue: + if cred.OIDCMachineCallback != nil || cred.OIDCHumanCallback != nil { + return nil, fmt.Errorf("OIDC callbacks are not allowed for %q %q", env, environmentProp) + } + } + } + } + oa := &OIDCAuthenticator{ + userName: cred.Username, + httpClient: httpClient, + AuthMechanismProperties: cred.Props, + OIDCMachineCallback: cred.OIDCMachineCallback, + OIDCHumanCallback: cred.OIDCHumanCallback, + } + return oa, nil +} + +type oidcOneStep struct { + userName string + accessToken string +} + +func jwtStepRequest(accessToken string) []byte { + return bsoncore.NewDocumentBuilder(). + AppendString("jwt", accessToken). + Build() +} + +// TODO GODRIVER-3246: Implement OIDC human flow +//func principalStepRequest(principal string) []byte { +// doc := bsoncore.NewDocumentBuilder() +// if principal != "" { +// doc.AppendString("n", principal) +// } +// return doc.Build() +//} + +func (oos *oidcOneStep) Start() (string, []byte, error) { + return MongoDBOIDC, jwtStepRequest(oos.accessToken), nil +} + +func (oos *oidcOneStep) Next([]byte) ([]byte, error) { + return nil, newAuthError("unexpected step in OIDC authentication", nil) +} + +func (*oidcOneStep) Completed() bool { + return true +} + +func (oa *OIDCAuthenticator) providerCallback() (OIDCCallback, error) { + env, ok := oa.AuthMechanismProperties[environmentProp] + if !ok { + return nil, nil + } + + switch env { + // TODO GODRIVER-2728: Automatic token acquisition for Azure Identity Provider + // TODO GODRIVER-2806: Automatic token acquisition for GCP Identity Provider + // This is here just to pass the linter, it will be fixed in one of the above tickets. + case azureEnvironmentValue, gcpEnvironmentValue: + return func(ctx context.Context, args *OIDCArgs) (*OIDCCredential, error) { + return nil, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) + }, fmt.Errorf("automatic token acquisition for %q not implemented yet", env) + } + + return nil, fmt.Errorf("%q %q not supported for MONGODB-OIDC", environmentProp, env) +} + +func (oa *OIDCAuthenticator) getAccessToken( + ctx context.Context, + conn driver.Connection, + args *OIDCArgs, + callback OIDCCallback, +) (string, error) { + oa.mu.Lock() + defer oa.mu.Unlock() + + if oa.accessToken != "" { + return oa.accessToken, nil + } + + cred, err := callback(ctx, args) + if err != nil { + return "", err + } + + oa.accessToken = cred.AccessToken + oa.tokenGenID++ + conn.SetOIDCTokenGenID(oa.tokenGenID) + if cred.RefreshToken != nil { + oa.refreshToken = cred.RefreshToken + } + return cred.AccessToken, nil +} + +// TODO GODRIVER-3246: Implement OIDC human flow +// This should only be called with the Mutex held. +//func (oa *OIDCAuthenticator) getAccessTokenWithRefresh( +// ctx context.Context, +// callback OIDCCallback, +// refreshToken string, +//) (string, error) { +// +// cred, err := callback(ctx, &OIDCArgs{ +// Version: apiVersion, +// IDPInfo: oa.idpInfo, +// RefreshToken: &refreshToken, +// }) +// if err != nil { +// return "", err +// } +// +// oa.accessToken = cred.AccessToken +// oa.tokenGenID++ +// oa.cfg.Connection.SetOIDCTokenGenID(oa.tokenGenID) +// return cred.AccessToken, nil +//} + +// invalidateAccessToken invalidates the access token, if the force flag is set to true (which is +// only on a Reauth call) or if the tokenGenID of the connection is greater than or equal to the +// tokenGenID of the OIDCAuthenticator. It should never actually be greater than, but only equal, +// but this is a safety check, since extra invalidation is only a performance impact, not a +// correctness impact. +func (oa *OIDCAuthenticator) invalidateAccessToken(conn driver.Connection) { + oa.mu.Lock() + defer oa.mu.Unlock() + tokenGenID := conn.OIDCTokenGenID() + // If the connection used in a Reauth is a new connection it will not have a correct tokenGenID, + // it will instead be set to 0. In the absence of information, the only safe thing to do is to + // invalidate the cached accessToken. + if tokenGenID == 0 || tokenGenID >= oa.tokenGenID { + oa.accessToken = "" + conn.SetOIDCTokenGenID(0) + } +} + +// Reauth reauthenticates the connection when the server returns a 391 code. Reauth is part of the +// driver.Authenticator interface. +func (oa *OIDCAuthenticator) Reauth(ctx context.Context, cfg *Config) error { + oa.invalidateAccessToken(cfg.Connection) + return oa.Auth(ctx, cfg) +} + +// Auth authenticates the connection. +func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *Config) error { + var err error + + if cfg == nil { + return newAuthError(fmt.Sprintf("config must be set for %q authentication", MongoDBOIDC), nil) + } + conn := cfg.Connection + + oa.mu.Lock() + cachedAccessToken := oa.accessToken + oa.mu.Unlock() + + if cachedAccessToken != "" { + err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + userName: oa.userName, + accessToken: cachedAccessToken, + }) + if err == nil { + return nil + } + // this seems like it could be incorrect since we could be inavlidating an access token that + // has already been replaced by a different auth attempt, but the TokenGenID will prevernt + // that from happening. + oa.invalidateAccessToken(conn) + time.Sleep(invalidateSleepTimeout) + } + + if oa.OIDCHumanCallback != nil { + return oa.doAuthHuman(ctx, cfg, oa.OIDCHumanCallback) + } + + // Handle user provided or automatic provider machine callback. + var machineCallback OIDCCallback + if oa.OIDCMachineCallback != nil { + machineCallback = oa.OIDCMachineCallback + } else { + machineCallback, err = oa.providerCallback() + if err != nil { + return fmt.Errorf("error getting built-in OIDC provider: %w", err) + } + } + + if machineCallback != nil { + return oa.doAuthMachine(ctx, cfg, machineCallback) + } + return newAuthError("no OIDC callback provided", nil) +} + +func (oa *OIDCAuthenticator) doAuthHuman(_ context.Context, _ *Config, _ OIDCCallback) error { + // TODO GODRIVER-3246: Implement OIDC human flow + return newAuthError("OIDC", fmt.Errorf("human flow not implemented yet, %v", oa.idpInfo)) +} + +func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *Config, machineCallback OIDCCallback) error { + subCtx, cancel := context.WithTimeout(ctx, machineCallbackTimeout) + accessToken, err := oa.getAccessToken(subCtx, + cfg.Connection, + &OIDCArgs{ + Version: apiVersion, + // idpInfo is nil for machine callbacks in the current spec. + IDPInfo: nil, + RefreshToken: nil, + }, + machineCallback) + cancel() + if err != nil { + return err + } + return ConductSaslConversation( + ctx, + cfg, + "$external", + &oidcOneStep{accessToken: accessToken}, + ) +} + +// CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. +func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { + oa.mu.Lock() + defer oa.mu.Unlock() + accessToken := oa.accessToken + if accessToken == "" { + return nil, nil // Skip speculative auth. + } + + return newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", true), nil +} diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 532d43e39f..3e4c5b4eb3 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -8,12 +8,15 @@ package auth import ( "context" + "net/http" + + "go.mongodb.org/mongo-driver/x/mongo/driver" ) // PLAIN is the mechanism name for PLAIN. const PLAIN = "PLAIN" -func newPlainAuthenticator(cred *Cred) (Authenticator, error) { +func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &PlainAuthenticator{ Username: cred.Username, Password: cred.Password, @@ -34,6 +37,11 @@ func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *Config) error { }) } +// Reauth reauthenticates the connection. +func (a *PlainAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("Plain authentication does not support reauthentication", nil) +} + type plainSaslClient struct { username string password string diff --git a/x/mongo/driver/auth/sasl.go b/x/mongo/driver/auth/sasl.go index 2a84b53a64..75f0c411bf 100644 --- a/x/mongo/driver/auth/sasl.go +++ b/x/mongo/driver/auth/sasl.go @@ -156,7 +156,6 @@ func (sc *saslConversation) Finish(ctx context.Context, cfg *Config, firstRespon func ConductSaslConversation(ctx context.Context, cfg *Config, authSource string, client SaslClient) error { // Create a non-speculative SASL conversation. conversation := newSaslConversation(client, authSource, false) - saslStartDoc, err := conversation.FirstMessage() if err != nil { return newError(err, conversation.mechanism) diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index c1238cd6a9..291492e6ff 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -14,10 +14,12 @@ package auth import ( "context" + "net/http" "github.com/xdg-go/scram" "github.com/xdg-go/stringprep" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver" ) const ( @@ -35,7 +37,7 @@ var ( ) ) -func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) { +func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { passdigest := mongoPasswordDigest(cred.Username, cred.Password) client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "") if err != nil { @@ -49,7 +51,7 @@ func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) { }, nil } -func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) { +func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { passprep, err := stringprep.SASLprep.Prepare(cred.Password) if err != nil { return nil, newAuthError("error SASLprepping password", err) @@ -84,6 +86,11 @@ func (a *ScramAuthenticator) Auth(ctx context.Context, cfg *Config) error { return nil } +// Reauth reauthenticates the connection. +func (a *ScramAuthenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("SCRAM does not support reauthentication", nil) +} + // CreateSpeculativeConversation creates a speculative conversation for SCRAM authentication. func (a *ScramAuthenticator) CreateSpeculativeConversation() (SpeculativeConversation, error) { return newSaslConversation(a.createSaslClient(), a.source, true), nil diff --git a/x/mongo/driver/auth/scram_test.go b/x/mongo/driver/auth/scram_test.go index ef30a07364..0a745885ee 100644 --- a/x/mongo/driver/auth/scram_test.go +++ b/x/mongo/driver/auth/scram_test.go @@ -8,6 +8,7 @@ package auth import ( "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/internal/assert" @@ -38,7 +39,7 @@ func TestSCRAM(t *testing.T) { t.Run("conversation", func(t *testing.T) { testCases := []struct { name string - createAuthenticatorFn func(*Cred) (Authenticator, error) + createAuthenticatorFn func(*Cred, *http.Client) (Authenticator, error) payloads [][]byte nonce string }{ @@ -49,11 +50,13 @@ func TestSCRAM(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - authenticator, err := tc.createAuthenticatorFn(&Cred{ - Username: "user", - Password: "pencil", - Source: "admin", - }) + authenticator, err := tc.createAuthenticatorFn( + &Cred{ + Username: "user", + Password: "pencil", + Source: "admin", + }, + &http.Client{}) assert.Nil(t, err, "error creating authenticator: %v", err) sa, _ := authenticator.(*ScramAuthenticator) sa.client = sa.client.WithNonceGenerator(func() string { diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index a159891adc..9108fe1d21 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -9,6 +9,7 @@ package auth import ( "bytes" "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/bson" @@ -63,7 +64,7 @@ func TestSpeculativeSCRAM(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Create a SCRAM authenticator and overwrite the nonce generator to make the conversation // deterministic. - authenticator, err := CreateAuthenticator(tc.mechanism, cred) + authenticator, err := CreateAuthenticator(tc.mechanism, cred, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) setNonce(t, authenticator, tc.nonce) @@ -148,7 +149,7 @@ func TestSpeculativeSCRAM(t *testing.T) { for _, tc := range testCases { t.Run(tc.mechanism, func(t *testing.T) { - authenticator, err := CreateAuthenticator(tc.mechanism, cred) + authenticator, err := CreateAuthenticator(tc.mechanism, cred, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) setNonce(t, authenticator, tc.nonce) diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index 85bd93191b..e26b448e79 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -9,6 +9,7 @@ package auth import ( "bytes" "context" + "net/http" "testing" "go.mongodb.org/mongo-driver/bson" @@ -32,7 +33,7 @@ func TestSpeculativeX509(t *testing.T) { // Tests for X509 when the hello response contains a reply to the speculative authentication attempt. The // driver should not send any more commands after the hello. - authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}) + authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) handshaker := Handshaker(nil, &HandshakeOptions{ Authenticator: authenticator, @@ -76,7 +77,7 @@ func TestSpeculativeX509(t *testing.T) { // Tests for X509 when the hello response does not contain a reply to the speculative authentication attempt. // The driver should send an authenticate command after the hello. - authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}) + authenticator, err := CreateAuthenticator("MONGODB-X509", &Cred{}, &http.Client{}) assert.Nil(t, err, "CreateAuthenticator error: %v", err) handshaker := Handshaker(nil, &HandshakeOptions{ Authenticator: authenticator, diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index 03a9d750e2..3e84f516f8 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -8,6 +8,7 @@ package auth import ( "context" + "net/http" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -17,7 +18,7 @@ import ( // MongoDBX509 is the mechanism name for MongoDBX509. const MongoDBX509 = "MONGODB-X509" -func newMongoDBX509Authenticator(cred *Cred) (Authenticator, error) { +func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { return &MongoDBX509Authenticator{User: cred.Username}, nil } @@ -76,3 +77,8 @@ func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *Config) error return nil } + +// Reauth reauthenticates the connection. +func (a *MongoDBX509Authenticator) Reauth(_ context.Context, _ *driver.AuthConfig) error { + return newAuthError("X509 does not support reauthentication", nil) +} diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index 686458e292..a8adafb8f8 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -302,6 +302,13 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error { u.AuthSource = "admin" } } + case "mongodb-oidc": + if u.AuthSource == "" { + u.AuthSource = dbName + if u.AuthSource == "" { + u.AuthSource = "$external" + } + } case "": // Only set auth source if there is a request for authentication via non-empty credentials. if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) { @@ -781,6 +788,10 @@ func (u *ConnString) validateAuth() error { if u.AuthMechanismProperties != nil { return fmt.Errorf("SCRAM-SHA-256 cannot have mechanism properties") } + case "mongodb-oidc": + if u.Password != "" { + return fmt.Errorf("password cannot be specified for MONGODB-OIDC") + } case "": if u.UsernameSet && u.Username == "" { return fmt.Errorf("username required if URI contains user info") diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 900729bf87..363f4d6be3 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -24,6 +24,63 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) +// AuthConfig holds the information necessary to perform an authentication attempt. +// this was moved from the auth package to avoid a circular dependency. The auth package +// reexports this under the old name to avoid breaking the public api. +type AuthConfig struct { + Description description.Server + Connection Connection + ClusterClock *session.ClusterClock + HandshakeInfo HandshakeInformation + ServerAPI *ServerAPIOptions +} + +// OIDCCallback is the type for both Human and Machine Callback flows. RefreshToken will always be +// nil in the OIDCArgs for the Machine flow. +type OIDCCallback func(context.Context, *OIDCArgs) (*OIDCCredential, error) + +// OIDCArgs contains the arguments for the OIDC callback. +type OIDCArgs struct { + Version int + IDPInfo *IDPInfo + RefreshToken *string +} + +// OIDCCredential contains the access token and refresh token. +type OIDCCredential struct { + AccessToken string + ExpiresAt *time.Time + RefreshToken *string +} + +// IDPInfo contains the information needed to perform OIDC authentication with an Identity Provider. +type IDPInfo struct { + Issuer string `bson:"issuer"` + ClientID string `bson:"clientId"` + RequestScopes []string `bson:"requestScopes"` +} + +// Authenticator handles authenticating a connection. The implementers of this interface +// are all in the auth package. Most authentication mechanisms do not allow for Reauth, +// but this is included in the interface so that whenever a new mechanism is added, it +// must be explicitly considered. +type Authenticator interface { + // Auth authenticates the connection. + Auth(context.Context, *AuthConfig) error + Reauth(context.Context, *AuthConfig) error +} + +// Cred is a user's credential. +type Cred struct { + Source string + Username string + Password string + PasswordSet bool + Props map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback +} + // Deployment is implemented by types that can select a server from a deployment. type Deployment interface { SelectServer(context.Context, description.ServerSelector) (Server, error) @@ -79,6 +136,8 @@ type Connection interface { DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64. Address() address.Address Stale() bool + OIDCTokenGenID() uint64 + SetOIDCTokenGenID(uint64) } // RTTMonitor represents a round-trip-time monitor. diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index 27be4c264d..d002398a5b 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -26,6 +26,16 @@ type ChannelConn struct { Desc description.Server } +// OIDCTokenGenID implements the driver.Connection interface by returning the OIDCToken generation +// (which is always 0) +func (c *ChannelConn) OIDCTokenGenID() uint64 { + return 0 +} + +// SetOIDCTokenGenID implements the driver.Connection interface by setting the OIDCToken generation +// (which is always 0) +func (c *ChannelConn) SetOIDCTokenGenID(uint64) {} + // WriteWireMessage implements the driver.Connection interface. func (c *ChannelConn) WriteWireMessage(ctx context.Context, wm []byte) error { // Copy wm in case it came from a buffer pool. diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index db5367bed5..cea3543d14 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -315,6 +315,10 @@ type Operation struct { // [Operation.MaxTime]. OmitCSOTMaxTimeMS bool + // Authenticator is the authenticator to use for this operation when a reauthentication is + // required. + Authenticator Authenticator + // omitReadPreference is a boolean that indicates whether to omit the // read preference from the command. This omition includes the case // where a default read preference is used when the operation @@ -912,6 +916,28 @@ func (op Operation) Execute(ctx context.Context) error { operationErr.Labels = tt.Labels operationErr.Raw = tt.Raw case Error: + // 391 is the reauthentication required error code, so we will attempt a reauth and + // retry the operation, if it is successful. + if tt.Code == 391 { + if op.Authenticator != nil { + cfg := AuthConfig{ + Description: conn.Description(), + Connection: conn, + ClusterClock: op.Clock, + ServerAPI: op.ServerAPI, + } + if err := op.Authenticator.Reauth(ctx, &cfg); err != nil { + return fmt.Errorf("error reauthenticating: %w", err) + } + if op.Client != nil && op.Client.Committing { + // Apply majority write concern for retries + op.Client.UpdateCommitTransactionWriteConcern() + op.WriteConcern = op.Client.CurrentWc + } + resetForRetry(tt) + continue + } + } if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) { if err := op.Client.ClearPinnedResources(); err != nil { return err diff --git a/x/mongo/driver/operation/abort_transaction.go b/x/mongo/driver/operation/abort_transaction.go index 9413727130..aeee533533 100644 --- a/x/mongo/driver/operation/abort_transaction.go +++ b/x/mongo/driver/operation/abort_transaction.go @@ -21,6 +21,7 @@ import ( // AbortTransaction performs an abortTransaction operation. type AbortTransaction struct { + authenticator driver.Authenticator recoveryToken bsoncore.Document session *session.Client clock *session.ClusterClock @@ -66,6 +67,7 @@ func (at *AbortTransaction) Execute(ctx context.Context) error { WriteConcern: at.writeConcern, ServerAPI: at.serverAPI, Name: driverutil.AbortTransactionOp, + Authenticator: at.authenticator, }.Execute(ctx) } @@ -199,3 +201,13 @@ func (at *AbortTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Abort at.serverAPI = serverAPI return at } + +// Authenticator sets the authenticator to use for this operation. +func (at *AbortTransaction) Authenticator(authenticator driver.Authenticator) *AbortTransaction { + if at == nil { + at = new(AbortTransaction) + } + + at.authenticator = authenticator + return at +} diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index 44467df8fd..df6b8fa9dd 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -25,6 +25,7 @@ import ( // Aggregate represents an aggregate operation. type Aggregate struct { + authenticator driver.Authenticator allowDiskUse *bool batchSize *int32 bypassDocumentValidation *bool @@ -115,6 +116,7 @@ func (a *Aggregate) Execute(ctx context.Context) error { Timeout: a.timeout, Name: driverutil.AggregateOp, OmitCSOTMaxTimeMS: a.omitCSOTMaxTimeMS, + Authenticator: a.authenticator, }.Execute(ctx) } @@ -433,3 +435,13 @@ func (a *Aggregate) OmitCSOTMaxTimeMS(omit bool) *Aggregate { a.omitCSOTMaxTimeMS = omit return a } + +// Authenticator sets the authenticator to use for this operation. +func (a *Aggregate) Authenticator(authenticator driver.Authenticator) *Aggregate { + if a == nil { + a = new(Aggregate) + } + + a.authenticator = authenticator + return a +} diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 35283794a3..9dd10f3cb0 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -22,6 +22,7 @@ import ( // Command is used to run a generic operation. type Command struct { + authenticator driver.Authenticator command bsoncore.Document database string deployment driver.Deployment @@ -107,6 +108,7 @@ func (c *Command) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Logger: c.logger, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -219,3 +221,13 @@ func (c *Command) Logger(logger *logger.Logger) *Command { c.logger = logger return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Command) Authenticator(authenticator driver.Authenticator) *Command { + if c == nil { + c = new(Command) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index 11c6f69ddf..6b402bdf63 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -22,6 +22,7 @@ import ( // CommitTransaction attempts to commit a transaction. type CommitTransaction struct { + authenticator driver.Authenticator maxTime *time.Duration recoveryToken bsoncore.Document session *session.Client @@ -68,6 +69,7 @@ func (ct *CommitTransaction) Execute(ctx context.Context) error { WriteConcern: ct.writeConcern, ServerAPI: ct.serverAPI, Name: driverutil.CommitTransactionOp, + Authenticator: ct.authenticator, }.Execute(ctx) } @@ -201,3 +203,13 @@ func (ct *CommitTransaction) ServerAPI(serverAPI *driver.ServerAPIOptions) *Comm ct.serverAPI = serverAPI return ct } + +// Authenticator sets the authenticator to use for this operation. +func (ct *CommitTransaction) Authenticator(authenticator driver.Authenticator) *CommitTransaction { + if ct == nil { + ct = new(CommitTransaction) + } + + ct.authenticator = authenticator + return ct +} diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index 8de1e9f8d9..eaafc9a244 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -25,6 +25,7 @@ import ( // Count represents a count operation. type Count struct { + authenticator driver.Authenticator maxTime *time.Duration query bsoncore.Document session *session.Client @@ -128,6 +129,7 @@ func (c *Count) Execute(ctx context.Context) error { ServerAPI: c.serverAPI, Timeout: c.timeout, Name: driverutil.CountOp, + Authenticator: c.authenticator, }.Execute(ctx) // Swallow error if NamespaceNotFound(26) is returned from aggregate on non-existent namespace @@ -311,3 +313,13 @@ func (c *Count) Timeout(timeout *time.Duration) *Count { c.timeout = timeout return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Count) Authenticator(authenticator driver.Authenticator) *Count { + if c == nil { + c = new(Count) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/create.go b/x/mongo/driver/operation/create.go index 45b26cb707..4878e2c777 100644 --- a/x/mongo/driver/operation/create.go +++ b/x/mongo/driver/operation/create.go @@ -20,6 +20,7 @@ import ( // Create represents a create operation. type Create struct { + authenticator driver.Authenticator capped *bool collation bsoncore.Document changeStreamPreAndPostImages bsoncore.Document @@ -77,6 +78,7 @@ func (c *Create) Execute(ctx context.Context) error { Selector: c.selector, WriteConcern: c.writeConcern, ServerAPI: c.serverAPI, + Authenticator: c.authenticator, }.Execute(ctx) } @@ -399,3 +401,13 @@ func (c *Create) ClusteredIndex(ci bsoncore.Document) *Create { c.clusteredIndex = ci return c } + +// Authenticator sets the authenticator to use for this operation. +func (c *Create) Authenticator(authenticator driver.Authenticator) *Create { + if c == nil { + c = new(Create) + } + + c.authenticator = authenticator + return c +} diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index 77daf676a4..464c1762de 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -24,21 +24,22 @@ import ( // CreateIndexes performs a createIndexes operation. type CreateIndexes struct { - commitQuorum bsoncore.Value - indexes bsoncore.Document - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result CreateIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + commitQuorum bsoncore.Value + indexes bsoncore.Document + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result CreateIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // CreateIndexesResult represents a createIndexes result returned by the server. @@ -119,6 +120,7 @@ func (ci *CreateIndexes) Execute(ctx context.Context) error { ServerAPI: ci.serverAPI, Timeout: ci.timeout, Name: driverutil.CreateIndexesOp, + Authenticator: ci.authenticator, }.Execute(ctx) } @@ -278,3 +280,13 @@ func (ci *CreateIndexes) Timeout(timeout *time.Duration) *CreateIndexes { ci.timeout = timeout return ci } + +// Authenticator sets the authenticator to use for this operation. +func (ci *CreateIndexes) Authenticator(authenticator driver.Authenticator) *CreateIndexes { + if ci == nil { + ci = new(CreateIndexes) + } + + ci.authenticator = authenticator + return ci +} diff --git a/x/mongo/driver/operation/create_search_indexes.go b/x/mongo/driver/operation/create_search_indexes.go index cb0d807952..8185d27fe1 100644 --- a/x/mongo/driver/operation/create_search_indexes.go +++ b/x/mongo/driver/operation/create_search_indexes.go @@ -22,18 +22,19 @@ import ( // CreateSearchIndexes performs a createSearchIndexes operation. type CreateSearchIndexes struct { - indexes bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result CreateSearchIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + indexes bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result CreateSearchIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // CreateSearchIndexResult represents a single search index result in CreateSearchIndexesResult. @@ -116,6 +117,7 @@ func (csi *CreateSearchIndexes) Execute(ctx context.Context) error { Selector: csi.selector, ServerAPI: csi.serverAPI, Timeout: csi.timeout, + Authenticator: csi.authenticator, }.Execute(ctx) } @@ -237,3 +239,13 @@ func (csi *CreateSearchIndexes) Timeout(timeout *time.Duration) *CreateSearchInd csi.timeout = timeout return csi } + +// Authenticator sets the authenticator to use for this operation. +func (csi *CreateSearchIndexes) Authenticator(authenticator driver.Authenticator) *CreateSearchIndexes { + if csi == nil { + csi = new(CreateSearchIndexes) + } + + csi.authenticator = authenticator + return csi +} diff --git a/x/mongo/driver/operation/delete.go b/x/mongo/driver/operation/delete.go index bf95cf496d..298ec44196 100644 --- a/x/mongo/driver/operation/delete.go +++ b/x/mongo/driver/operation/delete.go @@ -25,25 +25,26 @@ import ( // Delete performs a delete operation type Delete struct { - comment bsoncore.Value - deletes []bsoncore.Document - ordered *bool - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - retry *driver.RetryMode - hint *bool - result DeleteResult - serverAPI *driver.ServerAPIOptions - let bsoncore.Document - timeout *time.Duration - logger *logger.Logger + authenticator driver.Authenticator + comment bsoncore.Value + deletes []bsoncore.Document + ordered *bool + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + retry *driver.RetryMode + hint *bool + result DeleteResult + serverAPI *driver.ServerAPIOptions + let bsoncore.Document + timeout *time.Duration + logger *logger.Logger } // DeleteResult represents a delete result returned by the server. @@ -116,6 +117,7 @@ func (d *Delete) Execute(ctx context.Context) error { Timeout: d.timeout, Logger: d.logger, Name: driverutil.DeleteOp, + Authenticator: d.authenticator, }.Execute(ctx) } @@ -328,3 +330,13 @@ func (d *Delete) Logger(logger *logger.Logger) *Delete { return d } + +// Authenticator sets the authenticator to use for this operation. +func (d *Delete) Authenticator(authenticator driver.Authenticator) *Delete { + if d == nil { + d = new(Delete) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index b7e675ce42..484d96b66b 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -24,6 +24,7 @@ import ( // Distinct performs a distinct operation. type Distinct struct { + authenticator driver.Authenticator collation bsoncore.Document key *string maxTime *time.Duration @@ -107,6 +108,7 @@ func (d *Distinct) Execute(ctx context.Context) error { ServerAPI: d.serverAPI, Timeout: d.timeout, Name: driverutil.DistinctOp, + Authenticator: d.authenticator, }.Execute(ctx) } @@ -311,3 +313,13 @@ func (d *Distinct) Timeout(timeout *time.Duration) *Distinct { d.timeout = timeout return d } + +// Authenticator sets the authenticator to use for this operation. +func (d *Distinct) Authenticator(authenticator driver.Authenticator) *Distinct { + if d == nil { + d = new(Distinct) + } + + d.authenticator = authenticator + return d +} diff --git a/x/mongo/driver/operation/drop_collection.go b/x/mongo/driver/operation/drop_collection.go index 8c65967564..5a32c2f8d4 100644 --- a/x/mongo/driver/operation/drop_collection.go +++ b/x/mongo/driver/operation/drop_collection.go @@ -23,18 +23,19 @@ import ( // DropCollection performs a drop operation. type DropCollection struct { - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result DropCollectionResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result DropCollectionResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropCollectionResult represents a dropCollection result returned by the server. @@ -104,6 +105,7 @@ func (dc *DropCollection) Execute(ctx context.Context) error { ServerAPI: dc.serverAPI, Timeout: dc.timeout, Name: driverutil.DropOp, + Authenticator: dc.authenticator, }.Execute(ctx) } @@ -222,3 +224,13 @@ func (dc *DropCollection) Timeout(timeout *time.Duration) *DropCollection { dc.timeout = timeout return dc } + +// Authenticator sets the authenticator to use for this operation. +func (dc *DropCollection) Authenticator(authenticator driver.Authenticator) *DropCollection { + if dc == nil { + dc = new(DropCollection) + } + + dc.authenticator = authenticator + return dc +} diff --git a/x/mongo/driver/operation/drop_database.go b/x/mongo/driver/operation/drop_database.go index a8f9b45ba4..19956210d1 100644 --- a/x/mongo/driver/operation/drop_database.go +++ b/x/mongo/driver/operation/drop_database.go @@ -21,15 +21,16 @@ import ( // DropDatabase performs a dropDatabase operation type DropDatabase struct { - session *session.Client - clock *session.ClusterClock - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - serverAPI *driver.ServerAPIOptions + authenticator driver.Authenticator + session *session.Client + clock *session.ClusterClock + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + serverAPI *driver.ServerAPIOptions } // NewDropDatabase constructs and returns a new DropDatabase. @@ -55,6 +56,7 @@ func (dd *DropDatabase) Execute(ctx context.Context) error { WriteConcern: dd.writeConcern, ServerAPI: dd.serverAPI, Name: driverutil.DropDatabaseOp, + Authenticator: dd.authenticator, }.Execute(ctx) } @@ -154,3 +156,13 @@ func (dd *DropDatabase) ServerAPI(serverAPI *driver.ServerAPIOptions) *DropDatab dd.serverAPI = serverAPI return dd } + +// Authenticator sets the authenticator to use for this operation. +func (dd *DropDatabase) Authenticator(authenticator driver.Authenticator) *DropDatabase { + if dd == nil { + dd = new(DropDatabase) + } + + dd.authenticator = authenticator + return dd +} diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 0c3d459707..e4f924e4e1 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -23,20 +23,21 @@ import ( // DropIndexes performs an dropIndexes operation. type DropIndexes struct { - index *string - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - writeConcern *writeconcern.WriteConcern - result DropIndexesResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index *string + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + writeConcern *writeconcern.WriteConcern + result DropIndexesResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropIndexesResult represents a dropIndexes result returned by the server. @@ -101,6 +102,7 @@ func (di *DropIndexes) Execute(ctx context.Context) error { ServerAPI: di.serverAPI, Timeout: di.timeout, Name: driverutil.DropIndexesOp, + Authenticator: di.authenticator, }.Execute(ctx) } @@ -242,3 +244,13 @@ func (di *DropIndexes) Timeout(timeout *time.Duration) *DropIndexes { di.timeout = timeout return di } + +// Authenticator sets the authenticator to use for this operation. +func (di *DropIndexes) Authenticator(authenticator driver.Authenticator) *DropIndexes { + if di == nil { + di = new(DropIndexes) + } + + di.authenticator = authenticator + return di +} diff --git a/x/mongo/driver/operation/drop_search_index.go b/x/mongo/driver/operation/drop_search_index.go index 3992c83165..3d273434d5 100644 --- a/x/mongo/driver/operation/drop_search_index.go +++ b/x/mongo/driver/operation/drop_search_index.go @@ -21,18 +21,19 @@ import ( // DropSearchIndex performs an dropSearchIndex operation. type DropSearchIndex struct { - index string - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result DropSearchIndexResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index string + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result DropSearchIndexResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // DropSearchIndexResult represents a dropSearchIndex result returned by the server. @@ -93,6 +94,7 @@ func (dsi *DropSearchIndex) Execute(ctx context.Context) error { Selector: dsi.selector, ServerAPI: dsi.serverAPI, Timeout: dsi.timeout, + Authenticator: dsi.authenticator, }.Execute(ctx) } @@ -212,3 +214,13 @@ func (dsi *DropSearchIndex) Timeout(timeout *time.Duration) *DropSearchIndex { dsi.timeout = timeout return dsi } + +// Authenticator sets the authenticator to use for this operation. +func (dsi *DropSearchIndex) Authenticator(authenticator driver.Authenticator) *DropSearchIndex { + if dsi == nil { + dsi = new(DropSearchIndex) + } + + dsi.authenticator = authenticator + return dsi +} diff --git a/x/mongo/driver/operation/end_sessions.go b/x/mongo/driver/operation/end_sessions.go index 52f300bb7f..8b24b3d8c2 100644 --- a/x/mongo/driver/operation/end_sessions.go +++ b/x/mongo/driver/operation/end_sessions.go @@ -20,15 +20,16 @@ import ( // EndSessions performs an endSessions operation. type EndSessions struct { - sessionIDs bsoncore.Document - session *session.Client - clock *session.ClusterClock - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - serverAPI *driver.ServerAPIOptions + authenticator driver.Authenticator + sessionIDs bsoncore.Document + session *session.Client + clock *session.ClusterClock + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + serverAPI *driver.ServerAPIOptions } // NewEndSessions constructs and returns a new EndSessions. @@ -61,6 +62,7 @@ func (es *EndSessions) Execute(ctx context.Context) error { Selector: es.selector, ServerAPI: es.serverAPI, Name: driverutil.EndSessionsOp, + Authenticator: es.authenticator, }.Execute(ctx) } @@ -161,3 +163,13 @@ func (es *EndSessions) ServerAPI(serverAPI *driver.ServerAPIOptions) *EndSession es.serverAPI = serverAPI return es } + +// Authenticator sets the authenticator to use for this operation. +func (es *EndSessions) Authenticator(authenticator driver.Authenticator) *EndSessions { + if es == nil { + es = new(EndSessions) + } + + es.authenticator = authenticator + return es +} diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index 8950fde86d..c71b7d755e 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -25,6 +25,7 @@ import ( // Find performs a find operation. type Find struct { + authenticator driver.Authenticator allowDiskUse *bool allowPartialResults *bool awaitData *bool @@ -112,6 +113,7 @@ func (f *Find) Execute(ctx context.Context) error { Logger: f.logger, Name: driverutil.FindOp, OmitCSOTMaxTimeMS: f.omitCSOTMaxTimeMS, + Authenticator: f.authenticator, }.Execute(ctx) } @@ -575,3 +577,13 @@ func (f *Find) Logger(logger *logger.Logger) *Find { f.logger = logger return f } + +// Authenticator sets the authenticator to use for this operation. +func (f *Find) Authenticator(authenticator driver.Authenticator) *Find { + if f == nil { + f = new(Find) + } + + f.authenticator = authenticator + return f +} diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 7faf561135..ea365ccb23 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -25,6 +25,7 @@ import ( // FindAndModify performs a findAndModify operation. type FindAndModify struct { + authenticator driver.Authenticator arrayFilters bsoncore.Array bypassDocumentValidation *bool collation bsoncore.Document @@ -145,6 +146,7 @@ func (fam *FindAndModify) Execute(ctx context.Context) error { ServerAPI: fam.serverAPI, Timeout: fam.timeout, Name: driverutil.FindAndModifyOp, + Authenticator: fam.authenticator, }.Execute(ctx) } @@ -477,3 +479,13 @@ func (fam *FindAndModify) Timeout(timeout *time.Duration) *FindAndModify { fam.timeout = timeout return fam } + +// Authenticator sets the authenticator to use for this operation. +func (fam *FindAndModify) Authenticator(authenticator driver.Authenticator) *FindAndModify { + if fam == nil { + fam = new(FindAndModify) + } + + fam.authenticator = authenticator + return fam +} diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 16f2ebf6c0..60c99f063d 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -36,6 +36,7 @@ const driverName = "mongo-go-driver" // Hello is used to run the handshake operation. type Hello struct { + authenticator driver.Authenticator appname string compressors []string saslSupportedMechs string @@ -649,3 +650,13 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, func (h *Hello) FinishHandshake(context.Context, driver.Connection) error { return nil } + +// Authenticator sets the authenticator to use for this operation. +func (h *Hello) Authenticator(authenticator driver.Authenticator) *Hello { + if h == nil { + h = new(Hello) + } + + h.authenticator = authenticator + return h +} diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go index 7da4b8b0fb..f5afe31169 100644 --- a/x/mongo/driver/operation/insert.go +++ b/x/mongo/driver/operation/insert.go @@ -25,6 +25,7 @@ import ( // Insert performs an insert operation. type Insert struct { + authenticator driver.Authenticator bypassDocumentValidation *bool comment bsoncore.Value documents []bsoncore.Document @@ -115,6 +116,7 @@ func (i *Insert) Execute(ctx context.Context) error { Timeout: i.timeout, Logger: i.logger, Name: driverutil.InsertOp, + Authenticator: i.authenticator, }.Execute(ctx) } @@ -306,3 +308,13 @@ func (i *Insert) Logger(logger *logger.Logger) *Insert { i.logger = logger return i } + +// Authenticator sets the authenticator to use for this operation. +func (i *Insert) Authenticator(authenticator driver.Authenticator) *Insert { + if i == nil { + i = new(Insert) + } + + i.authenticator = authenticator + return i +} diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/listDatabases.go index c70248e2a9..3df171e37a 100644 --- a/x/mongo/driver/operation/listDatabases.go +++ b/x/mongo/driver/operation/listDatabases.go @@ -24,6 +24,7 @@ import ( // ListDatabases performs a listDatabases operation. type ListDatabases struct { + authenticator driver.Authenticator filter bsoncore.Document authorizedDatabases *bool nameOnly *bool @@ -165,6 +166,7 @@ func (ld *ListDatabases) Execute(ctx context.Context) error { ServerAPI: ld.serverAPI, Timeout: ld.timeout, Name: driverutil.ListDatabasesOp, + Authenticator: ld.authenticator, }.Execute(ctx) } @@ -327,3 +329,13 @@ func (ld *ListDatabases) Timeout(timeout *time.Duration) *ListDatabases { ld.timeout = timeout return ld } + +// Authenticator sets the authenticator to use for this operation. +func (ld *ListDatabases) Authenticator(authenticator driver.Authenticator) *ListDatabases { + if ld == nil { + ld = new(ListDatabases) + } + + ld.authenticator = authenticator + return ld +} diff --git a/x/mongo/driver/operation/list_collections.go b/x/mongo/driver/operation/list_collections.go index 6fe68fa033..1e39f5bfbe 100644 --- a/x/mongo/driver/operation/list_collections.go +++ b/x/mongo/driver/operation/list_collections.go @@ -22,6 +22,7 @@ import ( // ListCollections performs a listCollections operation. type ListCollections struct { + authenticator driver.Authenticator filter bsoncore.Document nameOnly *bool authorizedCollections *bool @@ -83,6 +84,7 @@ func (lc *ListCollections) Execute(ctx context.Context) error { ServerAPI: lc.serverAPI, Timeout: lc.timeout, Name: driverutil.ListCollectionsOp, + Authenticator: lc.authenticator, }.Execute(ctx) } @@ -259,3 +261,13 @@ func (lc *ListCollections) Timeout(timeout *time.Duration) *ListCollections { lc.timeout = timeout return lc } + +// Authenticator sets the authenticator to use for this operation. +func (lc *ListCollections) Authenticator(authenticator driver.Authenticator) *ListCollections { + if lc == nil { + lc = new(ListCollections) + } + + lc.authenticator = authenticator + return lc +} diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index 79d50eca95..433344f307 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -21,19 +21,20 @@ import ( // ListIndexes performs a listIndexes operation. type ListIndexes struct { - batchSize *int32 - maxTime *time.Duration - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - database string - deployment driver.Deployment - selector description.ServerSelector - retry *driver.RetryMode - crypt driver.Crypt - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + batchSize *int32 + maxTime *time.Duration + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + database string + deployment driver.Deployment + selector description.ServerSelector + retry *driver.RetryMode + crypt driver.Crypt + serverAPI *driver.ServerAPIOptions + timeout *time.Duration result driver.CursorResponse } @@ -85,6 +86,7 @@ func (li *ListIndexes) Execute(ctx context.Context) error { ServerAPI: li.serverAPI, Timeout: li.timeout, Name: driverutil.ListIndexesOp, + Authenticator: li.authenticator, }.Execute(ctx) } @@ -233,3 +235,13 @@ func (li *ListIndexes) Timeout(timeout *time.Duration) *ListIndexes { li.timeout = timeout return li } + +// Authenticator sets the authenticator to use for this operation. +func (li *ListIndexes) Authenticator(authenticator driver.Authenticator) *ListIndexes { + if li == nil { + li = new(ListIndexes) + } + + li.authenticator = authenticator + return li +} diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go index 881b1bcf7b..1070e7ca70 100644 --- a/x/mongo/driver/operation/update.go +++ b/x/mongo/driver/operation/update.go @@ -26,6 +26,7 @@ import ( // Update performs an update operation. type Update struct { + authenticator driver.Authenticator bypassDocumentValidation *bool comment bsoncore.Value ordered *bool @@ -167,6 +168,7 @@ func (u *Update) Execute(ctx context.Context) error { Timeout: u.timeout, Logger: u.logger, Name: driverutil.UpdateOp, + Authenticator: u.authenticator, }.Execute(ctx) } @@ -414,3 +416,13 @@ func (u *Update) Logger(logger *logger.Logger) *Update { u.logger = logger return u } + +// Authenticator sets the authenticator to use for this operation. +func (u *Update) Authenticator(authenticator driver.Authenticator) *Update { + if u == nil { + u = new(Update) + } + + u.authenticator = authenticator + return u +} diff --git a/x/mongo/driver/operation/update_search_index.go b/x/mongo/driver/operation/update_search_index.go index 64f2da7f6f..4ed9946c69 100644 --- a/x/mongo/driver/operation/update_search_index.go +++ b/x/mongo/driver/operation/update_search_index.go @@ -21,19 +21,20 @@ import ( // UpdateSearchIndex performs a updateSearchIndex operation. type UpdateSearchIndex struct { - index string - definition bsoncore.Document - session *session.Client - clock *session.ClusterClock - collection string - monitor *event.CommandMonitor - crypt driver.Crypt - database string - deployment driver.Deployment - selector description.ServerSelector - result UpdateSearchIndexResult - serverAPI *driver.ServerAPIOptions - timeout *time.Duration + authenticator driver.Authenticator + index string + definition bsoncore.Document + session *session.Client + clock *session.ClusterClock + collection string + monitor *event.CommandMonitor + crypt driver.Crypt + database string + deployment driver.Deployment + selector description.ServerSelector + result UpdateSearchIndexResult + serverAPI *driver.ServerAPIOptions + timeout *time.Duration } // UpdateSearchIndexResult represents a single index in the updateSearchIndexResult result. @@ -95,6 +96,7 @@ func (usi *UpdateSearchIndex) Execute(ctx context.Context) error { Selector: usi.selector, ServerAPI: usi.serverAPI, Timeout: usi.timeout, + Authenticator: usi.authenticator, }.Execute(ctx) } @@ -225,3 +227,13 @@ func (usi *UpdateSearchIndex) Timeout(timeout *time.Duration) *UpdateSearchIndex usi.timeout = timeout return usi } + +// Authenticator sets the authenticator to use for this operation. +func (usi *UpdateSearchIndex) Authenticator(authenticator driver.Authenticator) *UpdateSearchIndex { + if usi == nil { + usi = new(UpdateSearchIndex) + } + + usi.authenticator = authenticator + return usi +} diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 6445c9d0f6..27ef3a090d 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -789,6 +789,8 @@ func (m *mockConnection) SupportsStreaming() bool { return m.rCanStream func (m *mockConnection) CurrentlyStreaming() bool { return m.rStreaming } func (m *mockConnection) SetStreaming(streaming bool) { m.rStreaming = streaming } func (m *mockConnection) Stale() bool { return false } +func (m *mockConnection) OIDCTokenGenID() uint64 { return 0 } +func (m *mockConnection) SetOIDCTokenGenID(uint64) {} // TODO:(GODRIVER-2824) replace return type with int64. func (m *mockConnection) DriverConnectionID() uint64 { return 0 } diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index 8dac0932de..4a6be9c5e4 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -90,6 +90,8 @@ type LoadBalancedTransactionConnection interface { DriverConnectionID() uint64 // TODO(GODRIVER-2824): change type to int64. Address() address.Address Stale() bool + OIDCTokenGenID() uint64 + SetOIDCTokenGenID(uint64) // Functions copied over from driver.PinnedConnection that are not part of Connection or Expirable. PinToCursor() error diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 649e87b3d1..49a613aef8 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -82,6 +82,10 @@ type connection struct { // awaitingResponse indicates that the server response was not completely // read before returning the connection to the pool. awaitingResponse bool + + // oidcTokenGenID is the monotonic generation ID for OIDC tokens, used to invalidate + // accessTokens in the OIDC authenticator cache. + oidcTokenGenID uint64 } // newConnection handles the creation of a connection. It does not connect the connection. @@ -606,6 +610,8 @@ type Connection struct { refCount int cleanupPoolFn func() + oidcTokenGenID uint64 + // cleanupServerFn resets the server state when a connection is returned to the connection pool // via Close() or expired via Expire(). cleanupServerFn func() @@ -860,6 +866,16 @@ func configureTLS(ctx context.Context, return client, nil } +// OIDCTokenGenID returns the OIDC token generation ID. +func (c *Connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +// SetOIDCTokenGenID sets the OIDC token generation ID. +func (c *Connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID +} + // TODO: Naming? // cancellListener listens for context cancellation and notifies listeners via a @@ -903,3 +919,11 @@ func (c *cancellListener) StopListening() bool { c.done <- struct{}{} return c.aborted } + +func (c *connection) OIDCTokenGenID() uint64 { + return c.oidcTokenGenID +} + +func (c *connection) SetOIDCTokenGenID(genID uint64) { + c.oidcTokenGenID = genID +} diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index b5eb4a9729..0563e5524e 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -72,8 +72,30 @@ func newLogger(opts *options.LoggerOptions) (*logger.Logger, error) { } // NewConfig will translate data from client options into a topology config for building non-default deployments. -// Server and topology options are not honored if a custom deployment is used. func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, error) { + // Auth & Database & Password & Username + if co.Auth != nil { + cred := &auth.Cred{ + Username: co.Auth.Username, + Password: co.Auth.Password, + PasswordSet: co.Auth.PasswordSet, + Props: co.Auth.AuthMechanismProperties, + Source: co.Auth.AuthSource, + } + mechanism := co.Auth.AuthMechanism + authenticator, err := auth.CreateAuthenticator(mechanism, cred, co.HTTPClient) + if err != nil { + return nil, err + } + return NewConfigWithAuthenticator(co, clock, authenticator) + } + return NewConfigWithAuthenticator(co, clock, nil) +} + +// NewConfigWithAuthenticator will translate data from client options into a topology config for building non-default deployments. +// Server and topology options are not honored if a custom deployment is used. It uses a passed in +// authenticator to authenticate the connection. +func NewConfigWithAuthenticator(co *options.ClientOptions, clock *session.ClusterClock, authenticator driver.Authenticator) (*Config, error) { var serverAPI *driver.ServerAPIOptions if err := co.Validate(); err != nil { @@ -180,11 +202,6 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, } } - authenticator, err := auth.CreateAuthenticator(mechanism, cred) - if err != nil { - return nil, err - } - handshakeOpts := &auth.HandshakeOptions{ AppName: appName, Authenticator: authenticator, @@ -192,7 +209,6 @@ func NewConfig(co *options.ClientOptions, clock *session.ClusterClock) (*Config, ServerAPI: serverAPI, LoadBalanced: loadBalanced, ClusterClock: clock, - HTTPClient: co.HTTPClient, } if mechanism == "" {