diff --git a/spanner/client.go b/spanner/client.go index 1ce356266108..5eb8bb164ff9 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -107,20 +107,19 @@ func parseDatabaseName(db string) (project, instance, database string, err error // Client is a client for reading and writing data to a Cloud Spanner database. // A client is safe to use concurrently, except for its Close method. type Client struct { - sc *sessionClient - idleSessions *sessionPool - logger *log.Logger - qo QueryOptions - ro ReadOptions - ao []ApplyOption - txo TransactionOptions - bwo BatchWriteOptions - ct *commonTags - disableRouteToLeader bool - enableMultiplexedSessionForRW bool - dro *sppb.DirectedReadOptions - otConfig *openTelemetryConfig - metricsTracerFactory *builtinMetricsTracerFactory + sc *sessionClient + idleSessions *sessionPool + logger *log.Logger + qo QueryOptions + ro ReadOptions + ao []ApplyOption + txo TransactionOptions + bwo BatchWriteOptions + ct *commonTags + disableRouteToLeader bool + dro *sppb.DirectedReadOptions + otConfig *openTelemetryConfig + metricsTracerFactory *builtinMetricsTracerFactory } // DatabaseName returns the full name of a database, e.g., @@ -548,20 +547,19 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf } c = &Client{ - sc: sc, - idleSessions: sp, - logger: config.Logger, - qo: getQueryOptions(config.QueryOptions), - ro: config.ReadOptions, - ao: config.ApplyOptions, - txo: config.TransactionOptions, - bwo: config.BatchWriteOptions, - ct: getCommonTags(sc), - disableRouteToLeader: config.DisableRouteToLeader, - dro: config.DirectedReadOptions, - otConfig: otConfig, - metricsTracerFactory: metricsTracerFactory, - enableMultiplexedSessionForRW: config.enableMultiplexedSessionForRW, + sc: sc, + idleSessions: sp, + logger: config.Logger, + qo: getQueryOptions(config.QueryOptions), + ro: config.ReadOptions, + ao: config.ApplyOptions, + txo: config.TransactionOptions, + bwo: config.BatchWriteOptions, + ct: getCommonTags(sc), + disableRouteToLeader: config.DisableRouteToLeader, + dro: config.DirectedReadOptions, + otConfig: otConfig, + metricsTracerFactory: metricsTracerFactory, } return c, nil } @@ -1025,7 +1023,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea err error ) if sh == nil || sh.getID() == "" || sh.getClient() == nil { - if c.enableMultiplexedSessionForRW { + if c.idleSessions.isMultiplexedSessionForRWEnabled() { sh, err = c.idleSessions.takeMultiplexed(ctx) } else { // Session handle hasn't been allocated or has been destroyed. @@ -1044,7 +1042,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea // Note that the t.begin(ctx) call could change the session that is being used by the transaction, as the // BeginTransaction RPC invocation will be retried on a new session if it returns SessionNotFound. t.txReadOnly.sh = sh - if err = t.begin(ctx); err != nil { + if err = t.begin(ctx, nil); err != nil { trace.TracePrintf(ctx, nil, "Error while BeginTransaction during retrying a ReadWrite transaction: %v", ToSpannerError(err)) return ToSpannerError(err) } @@ -1072,7 +1070,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea return err }) if isUnimplementedErrorForMultiplexedRW(err) { - c.enableMultiplexedSessionForRW = false + c.idleSessions.disableMultiplexedSessionForRW() } return resp, err } diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 86770e4d2948..2955b182a767 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -1078,6 +1078,9 @@ func (s *inMemSpannerServer) BeginTransaction(ctx context.Context, req *spannerp } s.updateSessionLastUseTime(session.Name) tx := s.beginTransaction(session, req.Options) + if session.Multiplexed && req.MutationKey != nil { + tx.PrecommitToken = s.getPreCommitToken(string(tx.Id), "TransactionPrecommitToken") + } return tx, nil } diff --git a/spanner/mutation.go b/spanner/mutation.go index b9909742d010..7ff9079ed057 100644 --- a/spanner/mutation.go +++ b/spanner/mutation.go @@ -17,7 +17,9 @@ limitations under the License. package spanner import ( + "math/rand" "reflect" + "time" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "google.golang.org/grpc/codes" @@ -427,16 +429,42 @@ func (m Mutation) proto() (*sppb.Mutation, error) { // mutationsProto turns a spanner.Mutation array into a sppb.Mutation array, // it is convenient for sending batch mutations to Cloud Spanner. -func mutationsProto(ms []*Mutation) ([]*sppb.Mutation, error) { +func mutationsProto(ms []*Mutation) ([]*sppb.Mutation, *sppb.Mutation, error) { + var selectedMutation *Mutation + var nonInsertMutations []*Mutation + l := make([]*sppb.Mutation, 0, len(ms)) for _, m := range ms { + if m.op != opInsert { + nonInsertMutations = append(nonInsertMutations, m) + } + if selectedMutation == nil { + selectedMutation = m + } + // Track the INSERT mutation with the highest number of values if only INSERT mutation were found + if selectedMutation.op == opInsert && m.op == opInsert && len(m.values) > len(selectedMutation.values) { + selectedMutation = m + } + + // Convert the mutation to sppb.Mutation and add to the list pb, err := m.proto() if err != nil { - return nil, err + return nil, nil, err } l = append(l, pb) } - return l, nil + if len(nonInsertMutations) > 0 { + selectedMutation = nonInsertMutations[rand.New(rand.NewSource(time.Now().UnixNano())).Intn(len(nonInsertMutations))] + } + if selectedMutation != nil { + m, err := selectedMutation.proto() + if err != nil { + return nil, nil, err + } + return l, m, nil + } + + return l, nil, nil } // mutationGroupsProto turns a spanner.MutationGroup array into a @@ -444,7 +472,7 @@ func mutationsProto(ms []*Mutation) ([]*sppb.Mutation, error) { func mutationGroupsProto(mgs []*MutationGroup) ([]*sppb.BatchWriteRequest_MutationGroup, error) { gs := make([]*sppb.BatchWriteRequest_MutationGroup, 0, len(mgs)) for _, mg := range mgs { - ms, err := mutationsProto(mg.Mutations) + ms, _, err := mutationsProto(mg.Mutations) if err != nil { return nil, err } diff --git a/spanner/mutation_test.go b/spanner/mutation_test.go index 566b26f892d9..90a9e608de5d 100644 --- a/spanner/mutation_test.go +++ b/spanner/mutation_test.go @@ -18,6 +18,7 @@ package spanner import ( "math/big" + "reflect" "sort" "strings" "testing" @@ -561,61 +562,198 @@ func TestEncodeMutation(t *testing.T) { // Test Encoding an array of mutations. func TestEncodeMutationArray(t *testing.T) { - for _, test := range []struct { - name string - ms []*Mutation - want []*sppb.Mutation - wantErr error + tests := []struct { + name string + ms []*Mutation + want []*sppb.Mutation + wantMutationKey *sppb.Mutation + wantErr error }{ + // Test case for empty mutation list { - "Multiple Mutations", - []*Mutation{ - {opDelete, "t_test", Key{"bar"}, nil, nil}, - {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", 1}}, + name: "Empty Mutation List", + ms: []*Mutation{}, + want: []*sppb.Mutation{}, + wantMutationKey: nil, + wantErr: nil, + }, + // Test case for only insert mutations + { + name: "Only Inserts", + ms: []*Mutation{ + {opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", 1}}, + {opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{"bar", 2}}, + {opInsert, "t_test", nil, []string{"key", "val", "col3"}, []interface{}{"bar2", 3, 4}}, }, - []*sppb.Mutation{ + want: []*sppb.Mutation{ { - Operation: &sppb.Mutation_Delete_{ - Delete: &sppb.Mutation_Delete{ - Table: "t_test", - KeySet: &sppb.KeySet{ - Keys: []*proto3.ListValue{listValueProto(stringProto("bar"))}, + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{ + listValueProto(stringProto("foo"), intProto(1)), }, }, }, }, { - Operation: &sppb.Mutation_InsertOrUpdate{ - InsertOrUpdate: &sppb.Mutation_Write{ + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ Table: "t_test", Columns: []string{"key", "val"}, - Values: []*proto3.ListValue{listValueProto(stringProto("foo"), intProto(1))}, + Values: []*proto3.ListValue{ + listValueProto(stringProto("bar"), intProto(2)), + }, + }, + }, + }, + { + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val", "col3"}, + Values: []*proto3.ListValue{ + listValueProto(stringProto("bar2"), intProto(3), intProto(4)), + }, }, }, }, }, - nil, + wantMutationKey: &sppb.Mutation{ + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val", "col3"}, + Values: []*proto3.ListValue{ + listValueProto(stringProto("bar2"), intProto(3), intProto(4)), + }, + }, + }, + }, + wantErr: nil, }, + // Test case for mixed operations { - "Multiple Mutations - Bad Mutation", - []*Mutation{ + name: "Mixed Operations", + ms: []*Mutation{ + {opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", 1}}, + {opUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"bar", 2}}, + }, + want: []*sppb.Mutation{ + { + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{ + listValueProto(stringProto("foo"), intProto(1)), + }, + }, + }, + }, + { + Operation: &sppb.Mutation_Update{ + Update: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{ + listValueProto(stringProto("bar"), intProto(2)), + }, + }, + }, + }, + }, + wantMutationKey: &sppb.Mutation{ + Operation: &sppb.Mutation_Update{ + Update: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{ + listValueProto(stringProto("bar"), intProto(2)), + }, + }, + }, + }, + wantErr: nil, + }, + // Test case for error in mutation + { + name: "Error in Mutation", + ms: []*Mutation{ + {opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{struct{}{}, 1}}, + }, + want: []*sppb.Mutation{}, + wantMutationKey: nil, + wantErr: errEncoderUnsupportedType(struct{}{}), + }, + // Test case for only delete mutations + { + name: "Only Deletes", + ms: []*Mutation{ + {opDelete, "t_test", Key{"foo"}, nil, nil}, {opDelete, "t_test", Key{"bar"}, nil, nil}, - {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo", struct{}{}}}, }, - []*sppb.Mutation{}, - errEncoderUnsupportedType(struct{}{}), + want: []*sppb.Mutation{ + { + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_test", + KeySet: &sppb.KeySet{ + Keys: []*proto3.ListValue{ + listValueProto(stringProto("foo")), + }, + }, + }, + }, + }, + { + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_test", + KeySet: &sppb.KeySet{ + Keys: []*proto3.ListValue{ + listValueProto(stringProto("bar")), + }, + }, + }, + }, + }, + }, + wantMutationKey: &sppb.Mutation{ + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_test", + KeySet: &sppb.KeySet{ + Keys: []*proto3.ListValue{ + listValueProto(stringProto("bar")), + }, + }, + }, + }, + }, + wantErr: nil, }, - } { - gotProto, gotErr := mutationsProto(test.ms) - if gotErr != nil { - if !testEqual(gotErr, test.wantErr) { - t.Errorf("%v: mutationsProto(%v) returns error %v, want %v", test.name, test.ms, gotErr, test.wantErr) + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotProto, gotMutationKey, gotErr := mutationsProto(test.ms) + if gotErr != nil { + if !testEqual(gotErr, test.wantErr) { + t.Errorf("mutationsProto(%v) returns error %v, want %v", test.ms, gotErr, test.wantErr) + } + return } - continue - } - if !testEqual(gotProto, test.want) { - t.Errorf("%v: mutationsProto(%v) = (%v, nil), want (%v, nil)", test.name, test.ms, gotProto, test.want) - } + if !testEqual(gotProto, test.want) { + t.Errorf("mutationsProto(%v) = (%v, nil), want (%v, nil)", test.ms, gotProto, test.want) + } + if test.wantMutationKey != nil { + if reflect.TypeOf(gotMutationKey.Operation) != reflect.TypeOf(test.wantMutationKey.Operation) { + t.Errorf("mutationsProto(%v) returns mutation key %v, want %v", test.ms, gotMutationKey, test.wantMutationKey) + } + } + }) } } diff --git a/spanner/session.go b/spanner/session.go index 2165fdee04d8..2396bc0ab9ee 100644 --- a/spanner/session.go +++ b/spanner/session.go @@ -508,7 +508,7 @@ type SessionPoolConfig struct { enableMultiplexSession bool - // enableMultiplexedSessionForRW is a flag to enable multiplexed session for read/write transactions, is used in testing + // enableMultiplexedSessionForRW is a flag to enable multiplexed session for read/write transactions enableMultiplexedSessionForRW bool // healthCheckSampleInterval is how often the health checker samples live @@ -810,6 +810,18 @@ func (p *sessionPool) getRatioOfSessionsInUseLocked() float64 { return float64(p.numInUse) / float64(maxSessions) } +func (p *sessionPool) isMultiplexedSessionForRWEnabled() bool { + p.mu.Lock() + defer p.mu.Unlock() + return p.enableMultiplexedSessionForRW +} + +func (p *sessionPool) disableMultiplexedSessionForRW() { + p.mu.Lock() + defer p.mu.Unlock() + p.enableMultiplexedSessionForRW = false +} + // gets sessions which are unexpectedly long-running. func (p *sessionPool) getLongRunningSessionsLocked() []*sessionHandle { usedSessionsRatio := p.getRatioOfSessionsInUseLocked() diff --git a/spanner/transaction.go b/spanner/transaction.go index dbc1c5e969ef..94fb770ead53 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -1535,7 +1535,7 @@ func (t *ReadWriteTransaction) setSessionEligibilityForLongRunning(sh *sessionHa } } -func beginTransaction(ctx context.Context, sid string, client spannerClient, opts TransactionOptions) (transactionID, error) { +func beginTransaction(ctx context.Context, sid string, client spannerClient, opts TransactionOptions, mutationKey *sppb.Mutation) (transactionID, *sppb.MultiplexedSessionPrecommitToken, error) { res, err := client.BeginTransaction(ctx, &sppb.BeginTransactionRequest{ Session: sid, Options: &sppb.TransactionOptions{ @@ -1546,14 +1546,16 @@ func beginTransaction(ctx context.Context, sid string, client spannerClient, opt }, ExcludeTxnFromChangeStreams: opts.ExcludeTxnFromChangeStreams, }, + MutationKey: mutationKey, }) if err != nil { - return nil, err + return nil, nil, err } if res.Id == nil { - return nil, spannerErrorf(codes.Unknown, "BeginTransaction returned a transaction with a nil ID.") + return nil, nil, spannerErrorf(codes.Unknown, "BeginTransaction returned a transaction with a nil ID.") } - return res.Id, nil + + return res.Id, res.GetPrecommitToken(), nil } // shouldExplicitBegin checks if ReadWriteTransaction should do an explicit BeginTransaction @@ -1572,7 +1574,7 @@ func (t *ReadWriteTransaction) shouldExplicitBegin(attempt int) bool { } // begin starts a read-write transaction on Cloud Spanner. -func (t *ReadWriteTransaction) begin(ctx context.Context) error { +func (t *ReadWriteTransaction) begin(ctx context.Context, mutation *sppb.Mutation) error { t.mu.Lock() if t.tx != nil { t.state = txActive @@ -1582,8 +1584,9 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { t.mu.Unlock() var ( - tx transactionID - err error + tx transactionID + precommitToken *sppb.MultiplexedSessionPrecommitToken + err error ) defer func() { if err != nil && sh != nil { @@ -1601,9 +1604,10 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { if sh != nil { sh.updateLastUseTime() } - tx, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts) + tx, precommitToken, err = beginTransaction(contextWithOutgoingMetadata(ctx, sh.getMetadata(), t.disableRouteToLeader), sh.getID(), sh.getClient(), t.txOpts, mutation) if isSessionNotFoundError(err) { sh.destroy() + // this should not happen with multiplexed session, but if it does, we should not retry with multiplexed session sh, err = t.sp.take(ctx) if err != nil { return err @@ -1614,6 +1618,7 @@ func (t *ReadWriteTransaction) begin(ctx context.Context) error { } else { err = ToSpannerError(err) } + t.updatePrecommitToken(precommitToken) break } if err == nil { @@ -1660,6 +1665,7 @@ func (co CommitOptions) merge(opts CommitOptions) CommitOptions { func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions) (CommitResponse, error) { resp := CommitResponse{} t.mu.Lock() + mutationProtos, selectedMutationProto, err := mutationsProto(t.wb) if t.tx == nil { if t.state == txClosed { // inline begin transaction failed @@ -1667,16 +1673,18 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions return resp, errInlineBeginTransactionFailed() } t.mu.Unlock() + if !t.sp.isMultiplexedSessionForRWEnabled() { + selectedMutationProto = nil + } // mutations or empty transaction body only - if err := t.begin(ctx); err != nil { + if err := t.begin(ctx, selectedMutationProto); err != nil { return resp, err } t.mu.Lock() } t.state = txClosed // No further operations after commit. close(t.txReadyOrClosed) - mPb, err := mutationsProto(t.wb) - + precommitToken := t.precommitToken t.mu.Unlock() if err != nil { return resp, err @@ -1700,9 +1708,9 @@ func (t *ReadWriteTransaction) commit(ctx context.Context, options CommitOptions Transaction: &sppb.CommitRequest_TransactionId{ TransactionId: t.tx, }, - PrecommitToken: t.precommitToken, + PrecommitToken: precommitToken, RequestOptions: createRequestOptions(t.txOpts.CommitPriority, "", t.txOpts.TransactionTag), - Mutations: mPb, + Mutations: mutationProtos, ReturnCommitStats: options.ReturnCommitStats, MaxCommitDelay: maxCommitDelay, }, gax.WithGRPCOptions(grpc.Header(&md))) @@ -1883,7 +1891,7 @@ func newReadWriteStmtBasedTransactionWithSessionHandle(ctx context.Context, c *C t.otConfig = c.otConfig // always explicit begin the transactions - if err = t.begin(ctx); err != nil { + if err = t.begin(ctx, nil); err != nil { if sh != nil { sh.recycle() } @@ -1975,7 +1983,7 @@ func (t *writeOnlyTransaction) applyAtLeastOnce(ctx context.Context, ms ...*Muta sh.recycle() } }() - mPb, err := mutationsProto(ms) + mPb, _, err := mutationsProto(ms) if err != nil { // Malformed mutation found, just return the error. return ts, err diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index 1acd7f72f390..c17d40879cbe 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -413,19 +413,31 @@ func TestReadWriteTransaction_PrecommitToken(t *testing.T) { query bool update bool batchUpdate bool + mutationsOnly bool expectedPrecommitToken string expectedSequenceNumber int32 } testCases := []testCase{ - {"Only Query", true, false, false, "PartialResultSetPrecommitToken", 3}, //since mock server is returning 3 rows - {"Query and Update", true, true, false, "ResultSetPrecommitToken", 4}, - {"Query, Update, and Batch Update", true, true, true, "ExecuteBatchDmlResponsePrecommitToken", 5}, + {"Only Query", true, false, false, false, "PartialResultSetPrecommitToken", 3}, + {"Query and Update", true, true, false, false, "ResultSetPrecommitToken", 4}, + {"Query, Update, and Batch Update", true, true, true, false, "ExecuteBatchDmlResponsePrecommitToken", 5}, + {"Only Mutations", false, false, false, true, "TransactionPrecommitToken", 1}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + if tc.mutationsOnly { + ms := []*Mutation{ + Insert("t_foo", []string{"col1", "col2"}, []interface{}{int64(1), int64(2)}), + Update("t_foo", []string{"col1", "col2"}, []interface{}{"one", []byte(nil)}), + } + if err := tx.BufferWrite(ms); err != nil { + return err + } + } + if tc.query { iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums)) defer iter.Stop() @@ -463,8 +475,19 @@ func TestReadWriteTransaction_PrecommitToken(t *testing.T) { for _, req := range requests { if c, ok := req.(*sppb.CommitRequest); ok { commitReq = c - break } + if b, ok := req.(*sppb.BeginTransactionRequest); ok { + if !strings.Contains(b.GetSession(), "multiplexed") { + t.Errorf("Expected session to be multiplexed") + } + if b.MutationKey == nil { + t.Fatalf("Expected BeginTransaction request to contain a mutation key") + } + } + + } + if !strings.Contains(commitReq.GetSession(), "multiplexed") { + t.Errorf("Expected session to be multiplexed") } if commitReq.PrecommitToken == nil || len(commitReq.PrecommitToken.GetPrecommitToken()) == 0 { t.Fatalf("Expected commit request to contain a valid precommitToken, got: %v", commitReq.PrecommitToken) @@ -481,6 +504,88 @@ func TestReadWriteTransaction_PrecommitToken(t *testing.T) { } } +func TestMutationOnlyCaseAborted(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Define mutations to apply + mutations := []*Mutation{ + Insert("FOO", []string{"ID", "NAME"}, []interface{}{int64(1), "Bar"}), + } + + // Define a function to verify requests + verifyRequests := func(server *MockedSpannerInMemTestServer) { + var numBeginReq, numCommitReq int + // Verify that for mutation-only case, a mutation key is set in BeginTransactionRequest + requests := drainRequestsFromServer(server.TestSpanner) + for _, req := range requests { + if beginReq, ok := req.(*sppb.BeginTransactionRequest); ok { + if beginReq.GetMutationKey() == nil { + t.Fatalf("Expected mutation key with insert operation") + } + if !strings.Contains(beginReq.GetSession(), "multiplexed") { + t.Errorf("Expected session to be multiplexed") + } + numBeginReq++ + } + if commitReq, ok := req.(*sppb.CommitRequest); ok { + if commitReq.GetPrecommitToken() == nil || !strings.Contains(string(commitReq.GetPrecommitToken().PrecommitToken), "TransactionPrecommitToken") { + t.Errorf("Expected precommit token 'TransactionPrecommitToken', got %v", commitReq.GetPrecommitToken()) + } + if !strings.Contains(commitReq.GetSession(), "multiplexed") { + t.Errorf("Expected session to be multiplexed") + } + numCommitReq++ + } + } + if numBeginReq != 2 || numCommitReq != 2 { + t.Fatalf("Expected 2 BeginTransactionRequests and 2 CommitRequests, got %d and %d", numBeginReq, numCommitReq) + } + } + + // Test both ReadWriteTransaction and client.Apply + for _, method := range []string{"ReadWriteTransaction", "Apply"} { + t.Run(method, func(t *testing.T) { + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + DisableNativeMetrics: true, + SessionPoolConfig: SessionPoolConfig{ + MinOpened: 1, + MaxOpened: 1, + enableMultiplexSession: true, + enableMultiplexedSessionForRW: true, + }, + }) + defer teardown() + + // Simulate an aborted transaction on the first commit attempt + server.TestSpanner.PutExecutionTime(MethodCommitTransaction, + SimulatedExecutionTime{ + Errors: []error{status.Errorf(codes.Aborted, "Transaction aborted")}, + }) + switch method { + case "ReadWriteTransaction": + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + if err := tx.BufferWrite(mutations); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatalf("ReadWriteTransaction failed: %v", err) + } + case "Apply": + _, err := client.Apply(ctx, mutations) + if err != nil { + t.Fatalf("Apply failed: %v", err) + } + } + + // Verify requests for the current method + verifyRequests(server) + }) + } +} + func TestBatchDML_WithMultipleDML(t *testing.T) { t.Parallel() ctx := context.Background()