diff --git a/spanner/client.go b/spanner/client.go index d28b95cfaf20..c2e8447c155e 100644 --- a/spanner/client.go +++ b/spanner/client.go @@ -19,6 +19,7 @@ package spanner import ( "context" "fmt" + "io" "log" "os" "regexp" @@ -26,6 +27,8 @@ import ( "cloud.google.com/go/internal/trace" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" + "github.com/googleapis/gax-go/v2" + "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/api/option/internaloption" gtransport "google.golang.org/api/transport/grpc" @@ -97,6 +100,7 @@ type Client struct { ro ReadOptions ao []ApplyOption txo TransactionOptions + bwo BatchWriteOptions ct *commonTags disableRouteToLeader bool } @@ -138,6 +142,9 @@ type ClientConfig struct { // TransactionOptions is the configuration for a transaction. TransactionOptions TransactionOptions + // BatchWriteOptions is the configuration for a BatchWrite request. + BatchWriteOptions BatchWriteOptions + // CallOptions is the configuration for providing custom retry settings that // override the default values. CallOptions *vkit.CallOptions @@ -281,6 +288,7 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf ro: config.ReadOptions, ao: config.ApplyOptions, txo: config.TransactionOptions, + bwo: config.BatchWriteOptions, ct: getCommonTags(sc), disableRouteToLeader: config.DisableRouteToLeader, } @@ -669,6 +677,217 @@ func (c *Client) Apply(ctx context.Context, ms []*Mutation, opts ...ApplyOption) return t.applyAtLeastOnce(ctx, ms...) } +// BatchWriteOptions provides options for a BatchWriteRequest. +type BatchWriteOptions struct { + // Priority is the RPC priority to use for this request. + Priority sppb.RequestOptions_Priority + + // The transaction tag to use for this request. + TransactionTag string +} + +// merge combines two BatchWriteOptions such that the input parameter will have higher +// order of precedence. +func (bwo BatchWriteOptions) merge(opts BatchWriteOptions) BatchWriteOptions { + merged := BatchWriteOptions{ + TransactionTag: bwo.TransactionTag, + Priority: bwo.Priority, + } + if opts.TransactionTag != "" { + merged.TransactionTag = opts.TransactionTag + } + if opts.Priority != sppb.RequestOptions_PRIORITY_UNSPECIFIED { + merged.Priority = opts.Priority + } + return merged +} + +// BatchWriteResponseIterator is an iterator over BatchWriteResponse structures returned from BatchWrite RPC. +type BatchWriteResponseIterator struct { + ctx context.Context + stream sppb.Spanner_BatchWriteClient + err error + dataReceived bool + replaceSession func(ctx context.Context) error + rpc func(ctx context.Context) (sppb.Spanner_BatchWriteClient, error) + release func(error) + cancel func() +} + +// Next returns the next result. Its second return value is iterator.Done if +// there are no more results. Once Next returns Done, all subsequent calls +// will return Done. +func (r *BatchWriteResponseIterator) Next() (*sppb.BatchWriteResponse, error) { + for { + // Stream finished or in error state. + if r.err != nil { + return nil, r.err + } + + // RPC not made yet. + if r.stream == nil { + r.stream, r.err = r.rpc(r.ctx) + continue + } + + // Read from the stream. + var response *sppb.BatchWriteResponse + response, r.err = r.stream.Recv() + + // Return an item. + if r.err == nil { + r.dataReceived = true + return response, nil + } + + // Stream finished. + if r.err == io.EOF { + r.err = iterator.Done + return nil, r.err + } + + // Retry request on session not found error only if no data has been received before. + if !r.dataReceived && r.replaceSession != nil && isSessionNotFoundError(r.err) { + r.err = r.replaceSession(r.ctx) + r.stream = nil + } + } +} + +// Stop terminates the iteration. It should be called after you finish using the +// iterator. +func (r *BatchWriteResponseIterator) Stop() { + if r.stream != nil { + err := r.err + if err == iterator.Done { + err = nil + } + defer trace.EndSpan(r.ctx, err) + } + if r.cancel != nil { + r.cancel() + r.cancel = nil + } + if r.release != nil { + r.release(r.err) + r.release = nil + } + if r.err == nil { + r.err = spannerErrorf(codes.FailedPrecondition, "Next called after Stop") + } +} + +// Do calls the provided function once in sequence for each item in the +// iteration. If the function returns a non-nil error, Do immediately returns +// that error. +// +// If there are no items in the iterator, Do will return nil without calling the +// provided function. +// +// Do always calls Stop on the iterator. +func (r *BatchWriteResponseIterator) Do(f func(r *sppb.BatchWriteResponse) error) error { + defer r.Stop() + for { + row, err := r.Next() + switch err { + case iterator.Done: + return nil + case nil: + if err = f(row); err != nil { + return err + } + default: + return err + } + } +} + +// BatchWrite applies a list of mutation groups in a collection of efficient +// transactions. The mutation groups are applied non-atomically in an +// unspecified order and thus, they must be independent of each other. Partial +// failure is possible, i.e., some mutation groups may have been applied +// successfully, while some may have failed. The results of individual batches +// are streamed into the response as the batches are applied. +// +// BatchWrite requests are not replay protected, meaning that each mutation +// group may be applied more than once. Replays of non-idempotent mutations +// may have undesirable effects. For example, replays of an insert mutation +// may produce an already exists error or if you use generated or commit +// timestamp-based keys, it may result in additional rows being added to the +// mutation's table. We recommend structuring your mutation groups to be +// idempotent to avoid this issue. +func (c *Client) BatchWrite(ctx context.Context, mgs []*MutationGroup) *BatchWriteResponseIterator { + return c.BatchWriteWithOptions(ctx, mgs, BatchWriteOptions{}) +} + +// BatchWriteWithOptions is same as BatchWrite. It accepts additional options to customize the request. +func (c *Client) BatchWriteWithOptions(ctx context.Context, mgs []*MutationGroup, opts BatchWriteOptions) *BatchWriteResponseIterator { + ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchWrite") + + var err error + defer func() { + trace.EndSpan(ctx, err) + }() + + opts = c.bwo.merge(opts) + + mgsPb, err := mutationGroupsProto(mgs) + if err != nil { + return &BatchWriteResponseIterator{err: err} + } + + var sh *sessionHandle + sh, err = c.idleSessions.take(ctx) + if err != nil { + return &BatchWriteResponseIterator{err: err} + } + + rpc := func(ct context.Context) (sppb.Spanner_BatchWriteClient, error) { + var md metadata.MD + stream, rpcErr := sh.getClient().BatchWrite(contextWithOutgoingMetadata(ct, sh.getMetadata(), c.disableRouteToLeader), &sppb.BatchWriteRequest{ + Session: sh.getID(), + MutationGroups: mgsPb, + RequestOptions: createRequestOptions(opts.Priority, "", opts.TransactionTag), + }, gax.WithGRPCOptions(grpc.Header(&md))) + + if getGFELatencyMetricsFlag() && md != nil && c.ct != nil { + if metricErr := createContextAndCaptureGFELatencyMetrics(ct, c.ct, md, "BatchWrite"); metricErr != nil { + trace.TracePrintf(ct, nil, "Error in recording GFE Latency. Try disabling and rerunning. Error: %v", err) + } + } + return stream, rpcErr + } + + replaceSession := func(ct context.Context) error { + if sh != nil { + sh.destroy() + } + var sessionErr error + sh, sessionErr = c.idleSessions.take(ct) + return sessionErr + } + + release := func(err error) { + if sh == nil { + return + } + if isSessionNotFoundError(err) { + sh.destroy() + } + sh.recycle() + } + + ctx, cancel := context.WithCancel(ctx) + ctx = trace.StartSpan(ctx, "cloud.google.com/go/spanner.BatchWriteResponseIterator") + return &BatchWriteResponseIterator{ + ctx: ctx, + rpc: rpc, + replaceSession: replaceSession, + release: release, + cancel: cancel, + } +} + // logf logs the given message to the given logger, or the standard logger if // the given logger is nil. func logf(logger *log.Logger, format string, v ...interface{}) { diff --git a/spanner/client_test.go b/spanner/client_test.go index aa17d7454c3c..da3309adccb5 100644 --- a/spanner/client_test.go +++ b/spanner/client_test.go @@ -31,6 +31,7 @@ import ( itestutil "cloud.google.com/go/internal/testutil" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/googleapis/gax-go/v2" "google.golang.org/api/iterator" "google.golang.org/api/option" @@ -4304,3 +4305,250 @@ func TestClient_CustomRetryAndTimeoutSettings(t *testing.T) { t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) } } + +func TestClient_BatchWrite(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != len(mutationGroups) { + t.Fatalf("Response count mismatch.\nGot: %v\nWant:%v", responseCount, len(mutationGroups)) + } + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchWriteRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_BatchWrite_SessionNotFound(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + server.TestSpanner.PutExecutionTime( + MethodBatchWrite, + SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + if responseCount != len(mutationGroups) { + t.Fatalf("Response count mismatch.\nGot: %v\nWant:%v", responseCount, len(mutationGroups)) + } + + requests := drainRequestsFromServer(server.TestSpanner) + if err := compareRequests([]interface{}{ + &sppb.BatchCreateSessionsRequest{}, + &sppb.BatchWriteRequest{}, + &sppb.BatchWriteRequest{}, + }, requests); err != nil { + t.Fatal(err) + } +} + +func TestClient_BatchWrite_Error(t *testing.T) { + t.Parallel() + + server, client, teardown := setupMockedTestServer(t) + defer teardown() + injectedErr := status.Error(codes.InvalidArgument, "Invalid argument") + server.TestSpanner.PutExecutionTime( + MethodBatchWrite, + SimulatedExecutionTime{Errors: []error{injectedErr}}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups) + responseCount := 0 + doFunc := func(r *sppb.BatchWriteResponse) error { + responseCount++ + return nil + } + if err := iter.Do(doFunc); status.Code(err) != status.Code(injectedErr) { + t.Fatalf("Error mismatch.\nGot:%v\nExpected:%v\n", err, injectedErr) + } + if responseCount != 0 { + t.Fatalf("Do function unexpectedly called %v times", responseCount) + } +} + +func checkBatchWriteForExpectedRequestOptions(t *testing.T, server InMemSpannerServer, want *sppb.RequestOptions) { + reqs := drainRequestsFromServer(server) + var got *sppb.RequestOptions + + for _, req := range reqs { + if request, ok := req.(*sppb.BatchWriteRequest); ok { + got = request.RequestOptions + break + } + } + + if got == nil { + t.Fatalf("Missing BatchWrite RequestOptions") + } + + if diff := itestutil.Diff(got, want, cmpopts.IgnoreUnexported(sppb.RequestOptions{})); diff != "" { + t.Fatalf("RequestOptions mismatch. (+Got, -Want):%v", diff) + } +} + +func TestClient_BatchWrite_Options(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + client BatchWriteOptions + write BatchWriteOptions + wantTransactionTag string + wantPriority sppb.RequestOptions_Priority + }{ + { + name: "Client level", + client: BatchWriteOptions{TransactionTag: "testTransactionTag", Priority: sppb.RequestOptions_PRIORITY_LOW}, + wantTransactionTag: "testTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_LOW, + }, + { + name: "Write level", + write: BatchWriteOptions{TransactionTag: "testTransactionTag", Priority: sppb.RequestOptions_PRIORITY_LOW}, + wantTransactionTag: "testTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_LOW, + }, + { + name: "Write level has precedence over client level", + client: BatchWriteOptions{TransactionTag: "clientTransactionTag", Priority: sppb.RequestOptions_PRIORITY_LOW}, + write: BatchWriteOptions{TransactionTag: "writeTransactionTag", Priority: sppb.RequestOptions_PRIORITY_MEDIUM}, + wantTransactionTag: "writeTransactionTag", + wantPriority: sppb.RequestOptions_PRIORITY_MEDIUM, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{BatchWriteOptions: tt.client}) + defer teardown() + + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWriteWithOptions(context.Background(), mutationGroups, tt.write) + doFunc := func(r *sppb.BatchWriteResponse) error { + return nil + } + if err := iter.Do(doFunc); err != nil { + t.Fatal(err) + } + checkBatchWriteForExpectedRequestOptions(t, server.TestSpanner, &sppb.RequestOptions{Priority: tt.wantPriority, TransactionTag: tt.wantTransactionTag}) + }) + } +} + +func checkBatchWriteSpan(t *testing.T, errors []error, code codes.Code) { + // This test cannot be parallel, as the TestExporter does not support that. + te := itestutil.NewTestExporter() + defer te.Unregister() + minOpened := uint64(1) + server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ + SessionPoolConfig: SessionPoolConfig{ + MinOpened: minOpened, + WriteSessions: 0, + }, + }) + defer teardown() + + // Wait until all sessions have been created, so we know that those requests will not interfere with the test. + sp := client.idleSessions + waitFor(t, func() error { + sp.mu.Lock() + defer sp.mu.Unlock() + if uint64(sp.idleList.Len()) != minOpened { + return fmt.Errorf("num open sessions mismatch\nWant: %d\nGot: %d", sp.MinOpened, sp.numOpened) + } + return nil + }) + + server.TestSpanner.PutExecutionTime( + MethodBatchWrite, + SimulatedExecutionTime{Errors: errors}, + ) + mutationGroups := []*MutationGroup{ + {[]*Mutation{ + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + } + iter := client.BatchWrite(context.Background(), mutationGroups) + iter.Do(func(r *sppb.BatchWriteResponse) error { + return nil + }) + select { + case <-te.Stats: + case <-time.After(1 * time.Second): + t.Fatal("No stats were exported before timeout") + } + // Preferably we would want to lock the TestExporter here, but the mutex TestExporter.mu is not exported, so we + // cannot do that. + if len(te.Spans) == 0 { + t.Fatal("No spans were exported") + } + s := te.Spans[len(te.Spans)-1].Status + if s.Code != int32(code) { + t.Errorf("Span status mismatch\nGot: %v\nWant: %v", s.Code, code) + } +} +func TestClient_BatchWrite_SpanExported(t *testing.T) { + testcases := []struct { + name string + code codes.Code + errors []error + }{ + { + name: "Success", + code: codes.OK, + errors: []error{}, + }, + { + name: "Error", + code: codes.InvalidArgument, + errors: []error{status.Error(codes.InvalidArgument, "Invalid argument")}, + }, + } + + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + checkBatchWriteSpan(t, tt.errors, tt.code) + }) + } +} diff --git a/spanner/integration_test.go b/spanner/integration_test.go index 5d016ede945d..7ad2b82268f0 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -897,6 +897,92 @@ func deleteTestSession(ctx context.Context, t *testing.T, sessionName string) { } } +func TestIntegration_BatchWrite(t *testing.T) { + skipEmulatorTest(t) + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + // Set up testing environment. + client, _, cleanup := prepareIntegrationTest(ctx, t, DefaultSessionPoolConfig, statements[testDialect][singerDDLStatements]) + defer cleanup() + + writes := []struct { + row []interface{} + ts time.Time + }{ + {row: []interface{}{1, "Marc", "Foo"}}, + {row: []interface{}{2, "Tars", "Bar"}}, + {row: []interface{}{3, "Alpha", "Beta"}}, + {row: []interface{}{4, "Last", "End"}}, + } + mgs := make([]*MutationGroup, len(writes)) + // Try to write four rows through the BatchWrite API. + for i, w := range writes { + m := InsertOrUpdate("Singers", + []string{"SingerId", "FirstName", "LastName"}, + w.row) + ms := make([]*Mutation, 1) + ms[0] = m + mgs[i] = &MutationGroup{mutations: ms} + } + // Records the mutation group indexes received in the response. + seen := make(map[int32]int32) + numMutationGroups := len(mgs) + validate := func(res *sppb.BatchWriteResponse) error { + if status := status.ErrorProto(res.GetStatus()); status != nil { + t.Fatalf("Invalid status: %v", status) + } + if ts := res.GetCommitTimestamp(); ts == nil { + t.Fatal("Invalid commit timestamp") + } + for _, idx := range res.GetIndexes() { + if idx >= 0 && idx < int32(numMutationGroups) { + seen[idx]++ + } else { + t.Fatalf("Index %v out of range. Expected range [%v,%v]", idx, 0, numMutationGroups-1) + } + } + return nil + } + iter := client.BatchWrite(ctx, mgs) + if err := iter.Do(validate); err != nil { + t.Fatal(err) + } + // Validate that each mutation group index is seen exactly once. + if numMutationGroups != len(seen) { + t.Fatalf("Expected %v indexes, got %v indexes", numMutationGroups, len(seen)) + } + for idx, ct := range seen { + if ct != 1 { + t.Fatalf("Index %v seen %v times instead of exactly once", idx, ct) + } + } + + // Verify the writes by reading the database. + singersQuery := "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId IN (@p1, @p2, @p3)" + if testDialect == adminpb.DatabaseDialect_POSTGRESQL { + singersQuery = "SELECT SingerId, FirstName, LastName FROM Singers WHERE SingerId = $1 OR SingerId = $2 OR SingerId = $3" + } + qo := QueryOptions{Options: &sppb.ExecuteSqlRequest_QueryOptions{ + OptimizerVersion: "1", + OptimizerStatisticsPackage: "latest", + }} + got, err := readAll(client.Single().QueryWithOptions(ctx, Statement{ + singersQuery, + map[string]interface{}{"p1": int64(1), "p2": int64(3), "p3": int64(4)}, + }, qo)) + + if err != nil { + t.Errorf("ReadOnlyTransaction.QueryWithOptions returns error %v, want nil", err) + } + + want := [][]interface{}{{int64(1), "Marc", "Foo"}, {int64(3), "Alpha", "Beta"}, {int64(4), "Last", "End"}} + if !testEqual(got, want) { + t.Errorf("got unexpected result from ReadOnlyTransaction.QueryWithOptions: %v, want %v", got, want) + } +} + func TestIntegration_SingleUse_ReadingWithLimit(t *testing.T) { t.Parallel() diff --git a/spanner/internal/testutil/inmem_spanner_server.go b/spanner/internal/testutil/inmem_spanner_server.go index 6f48a54646b3..905e77c0cc84 100644 --- a/spanner/internal/testutil/inmem_spanner_server.go +++ b/spanner/internal/testutil/inmem_spanner_server.go @@ -89,6 +89,7 @@ const ( MethodExecuteStreamingSql string = "EXECUTE_STREAMING_SQL" MethodExecuteBatchDml string = "EXECUTE_BATCH_DML" MethodStreamingRead string = "EXECUTE_STREAMING_READ" + MethodBatchWrite string = "BATCH_WRITE" ) // StatementResult represents a mocked result on the test server. The result is @@ -1161,3 +1162,36 @@ func DecodeResumeToken(t []byte) (uint64, error) { } return s, nil } + +func (s *inMemSpannerServer) BatchWrite(req *spannerpb.BatchWriteRequest, stream spannerpb.Spanner_BatchWriteServer) error { + if err := s.simulateExecutionTime(MethodBatchWrite, req); err != nil { + return err + } + return s.batchWrite(req, stream) +} + +func (s *inMemSpannerServer) batchWrite(req *spannerpb.BatchWriteRequest, stream spannerpb.Spanner_BatchWriteServer) error { + if req.Session == "" { + return gstatus.Error(codes.InvalidArgument, "Missing session name") + } + session, err := s.findSession(req.Session) + if err != nil { + return err + } + s.updateSessionLastUseTime(session.Name) + if len(req.GetMutationGroups()) == 0 { + return gstatus.Error(codes.InvalidArgument, "No mutations in Batch Write") + } + // For each MutationGroup, write a BatchWriteResponse to the response stream + for idx := range req.GetMutationGroups() { + res := &spannerpb.BatchWriteResponse{ + Indexes: []int32{int32(idx)}, + CommitTimestamp: getCurrentTimestamp(), + Status: &status.Status{}, + } + if err = stream.Send(res); err != nil { + return err + } + } + return nil +} diff --git a/spanner/internal/testutil/inmem_spanner_server_test.go b/spanner/internal/testutil/inmem_spanner_server_test.go index 18f667835881..7ee83ac9c4ac 100644 --- a/spanner/internal/testutil/inmem_spanner_server_test.go +++ b/spanner/internal/testutil/inmem_spanner_server_test.go @@ -15,6 +15,7 @@ package testutil_test import ( + "io" "strconv" . "cloud.google.com/go/spanner/internal/testutil" @@ -603,3 +604,103 @@ func TestRollbackTransaction(t *testing.T) { t.Fatal(err) } } + +func TestSpannerBatchWrite(t *testing.T) { + setup() + c, err := apiv1.NewClient(context.Background(), clientOpt) + if err != nil { + t.Fatal(err) + } + + var formattedDatabase = fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]") + var createRequest = &spannerpb.CreateSessionRequest{ + Database: formattedDatabase, + } + session, err := c.CreateSession(context.Background(), createRequest) + if err != nil { + t.Fatal(err) + } + + batchWriteRequest := &spannerpb.BatchWriteRequest{ + Session: session.Name, + MutationGroups: []*spannerpb.BatchWriteRequest_MutationGroup{ + {Mutations: []*spannerpb.Mutation{ + { + Operation: &spannerpb.Mutation_Delete_{ + Delete: &spannerpb.Mutation_Delete{ + Table: "t_test", + KeySet: &spannerpb.KeySet{ + Keys: []*structpb.ListValue{ + { + Values: []*structpb.Value{ + {Kind: &structpb.Value_StringValue{StringValue: "k"}}, + }, + }, + }, + }, + }, + }, + }, + }}, + {Mutations: []*spannerpb.Mutation{ + { + Operation: &spannerpb.Mutation_Insert{ + Insert: &spannerpb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*structpb.ListValue{ + { + Values: []*structpb.Value{ + {Kind: &structpb.Value_StringValue{StringValue: "k"}}, + {Kind: &structpb.Value_StringValue{StringValue: "v"}}, + }, + }, + }, + }, + }, + }, + }}, + }, + } + stream, err := c.BatchWrite(context.Background(), batchWriteRequest) + if err != nil { + t.Fatal(err) + } + // Records the mutation group indexes received in the response. + seen := make(map[int32]int32) + numMutationGroups := len(batchWriteRequest.GetMutationGroups()) + validate := func(res *spannerpb.BatchWriteResponse) { + if status := res.GetStatus().GetCode(); status != int32(codes.OK) { + t.Fatalf("Invalid status: %v", status) + } + if ts := res.GetCommitTimestamp(); ts == nil { + t.Fatal("Invalid commit timestamp") + } + for _, idx := range res.GetIndexes() { + if idx >= 0 && idx < int32(numMutationGroups) { + seen[idx]++ + } else { + t.Fatalf("Index %v out of range. Expected range [%v,%v]", idx, 0, numMutationGroups-1) + } + } + } + for { + response, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + validate(response) + } + // Validate that each mutation group index is seen exactly once. + if numMutationGroups != len(seen) { + t.Fatalf("Expected %v indexes, got %v indexes", numMutationGroups, len(seen)) + } + for idx, ct := range seen { + if ct != 1 { + t.Fatalf("Index %v seen %v times instead of exactly once", idx, ct) + } + } +} diff --git a/spanner/mutation.go b/spanner/mutation.go index 19ced2be66fb..80a992e6c2cf 100644 --- a/spanner/mutation.go +++ b/spanner/mutation.go @@ -141,6 +141,12 @@ type Mutation struct { values []interface{} } +// A MutationGroup is a list of Mutation to be committed atomically. +type MutationGroup struct { + // The mutations in this group + mutations []*Mutation +} + // mapToMutationParams converts Go map into mutation parameters. func mapToMutationParams(in map[string]interface{}) ([]string, []interface{}) { cols := []string{} @@ -432,3 +438,17 @@ func mutationsProto(ms []*Mutation) ([]*sppb.Mutation, error) { } return l, nil } + +// mutationGroupsProto turns a spanner.MutationGroup array into a +// sppb.BatchWriteRequest_MutationGroup array, in preparation to send RPCs. +func mutationGroupsProto(mgs []*MutationGroup) ([]*sppb.BatchWriteRequest_MutationGroup, error) { + gs := make([]*sppb.BatchWriteRequest_MutationGroup, 0, len(mgs)) + for _, mg := range mgs { + ms, err := mutationsProto(mg.mutations) + if err != nil { + return nil, err + } + gs = append(gs, &sppb.BatchWriteRequest_MutationGroup{Mutations: ms}) + } + return gs, nil +} diff --git a/spanner/mutation_test.go b/spanner/mutation_test.go index c5dacf65f985..8047a6b9f9c4 100644 --- a/spanner/mutation_test.go +++ b/spanner/mutation_test.go @@ -618,3 +618,110 @@ func TestEncodeMutationArray(t *testing.T) { } } } + +func TestEncodeMutationGroupArray(t *testing.T) { + for _, test := range []struct { + name string + mgs []*MutationGroup + want []*sppb.BatchWriteRequest_MutationGroup + wantErr error + }{ + { + "Multiple Mutations", + []*MutationGroup{ + {[]*Mutation{ + {opDelete, "t_test", Key{"bar"}, nil, nil}, + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", 1}}, + }}, + {[]*Mutation{ + {opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{"foo2", 1}}, + {opUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo3", 1}}, + }}, + {[]*Mutation{ + {opReplace, "t_test", nil, []string{"key", "val"}, []interface{}{"foo4", 1}}, + }}, + }, + []*sppb.BatchWriteRequest_MutationGroup{ + {Mutations: []*sppb.Mutation{ + { + Operation: &sppb.Mutation_Delete_{ + Delete: &sppb.Mutation_Delete{ + Table: "t_test", + KeySet: &sppb.KeySet{ + Keys: []*proto3.ListValue{listValueProto(stringProto("bar"))}, + }, + }, + }, + }, + { + Operation: &sppb.Mutation_InsertOrUpdate{ + InsertOrUpdate: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo1"), intProto(1))}, + }, + }, + }, + }}, + {Mutations: []*sppb.Mutation{ + { + Operation: &sppb.Mutation_Insert{ + Insert: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo2"), intProto(1))}, + }, + }, + }, + { + Operation: &sppb.Mutation_Update{ + Update: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo3"), intProto(1))}, + }, + }, + }, + }}, + {Mutations: []*sppb.Mutation{ + { + Operation: &sppb.Mutation_Replace{ + Replace: &sppb.Mutation_Write{ + Table: "t_test", + Columns: []string{"key", "val"}, + Values: []*proto3.ListValue{listValueProto(stringProto("foo4"), intProto(1))}, + }, + }, + }, + }}, + }, + nil, + }, + { + "Multiple Mutations - Bad Mutation", + []*MutationGroup{ + {[]*Mutation{ + {opDelete, "t_test", Key{"bar"}, nil, nil}, + {opInsertOrUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo1", struct{}{}}}, + }}, + {[]*Mutation{ + {opInsert, "t_test", nil, []string{"key", "val"}, []interface{}{"foo2", 1}}, + {opUpdate, "t_test", nil, []string{"key", "val"}, []interface{}{"foo3", 1}}, + }}, + }, + []*sppb.BatchWriteRequest_MutationGroup{}, + errEncoderUnsupportedType(struct{}{}), + }, + } { + gotProto, gotErr := mutationGroupsProto(test.mgs) + if gotErr != nil { + if !testEqual(gotErr, test.wantErr) { + t.Errorf("%v: mutationGroupsProto(%v) returns error %v, want %v", test.name, test.mgs, gotErr, test.wantErr) + } + continue + } + if !testEqual(gotProto, test.want) { + t.Errorf("%v: mutationGroupsProto(%v) = (%v, nil), want (%v, nil)", test.name, test.mgs, gotProto, test.want) + } + } +}