diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerBatchUpdateException.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerBatchUpdateException.java new file mode 100644 index 000000000000..0e51c5f91f31 --- /dev/null +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerBatchUpdateException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner; + +public class SpannerBatchUpdateException extends SpannerException { + private long[] updateCounts; + /** Private constructor. Use {@link SpannerExceptionFactory} to create instances. */ + SpannerBatchUpdateException( + DoNotConstructDirectly token, ErrorCode code, String message, long[] counts) { + super(token, code, false, message, null); + updateCounts = counts; + } + + /** Returns the number of rows affected by each statement that is successfully run. */ + public long[] getUpdateCounts() { + return updateCounts; + } +} diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java index f6f6210d85e7..3ff2d6749778 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerExceptionFactory.java @@ -82,6 +82,12 @@ public static SpannerException newSpannerException(Throwable cause) { return newSpannerException(null, cause); } + public static SpannerBatchUpdateException newSpannerBatchUpdateException( + ErrorCode code, String message, long[] updateCounts) { + DoNotConstructDirectly token = DoNotConstructDirectly.ALLOWED; + return new SpannerBatchUpdateException(token, code, message, updateCounts); + } + /** * Creates a new exception based on {@code cause}. If {@code cause} indicates cancellation, {@code * context} will be inspected to establish the type of cancellation. diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 2424cf321e16..c8e3506ca596 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -16,6 +16,7 @@ package com.google.cloud.spanner; +import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerBatchUpdateException; import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerException; import static com.google.cloud.spanner.SpannerExceptionFactory.newSpannerExceptionForCancellation; import static com.google.common.base.Preconditions.checkArgument; @@ -69,6 +70,7 @@ import com.google.spanner.v1.BeginTransactionRequest; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteBatchDmlRequest; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.ExecuteSqlRequest.QueryMode; import com.google.spanner.v1.PartialResultSet; @@ -1080,6 +1082,36 @@ ExecuteSqlRequest.Builder getExecuteSqlRequestBuilder( return builder; } + ExecuteBatchDmlRequest.Builder getExecuteBatchDmlRequestBuilder( + Iterable statements) { + ExecuteBatchDmlRequest.Builder builder = + ExecuteBatchDmlRequest.newBuilder().setSession(session.name); + int idx = 0; + for (Statement stmt : statements) { + builder.addStatementsBuilder(); + builder.getStatementsBuilder(idx).setSql(stmt.getSql()); + Map stmtParameters = stmt.getParameters(); + if (!stmtParameters.isEmpty()) { + com.google.protobuf.Struct.Builder paramsBuilder = + builder.getStatementsBuilder(idx).getParamsBuilder(); + for (Map.Entry param : stmtParameters.entrySet()) { + paramsBuilder.putFields(param.getKey(), param.getValue().toProto()); + builder + .getStatementsBuilder(idx) + .putParamTypes(param.getKey(), param.getValue().getType().toProto()); + } + } + idx++; + } + + TransactionSelector selector = getTransactionSelector(); + if (selector != null) { + builder.setTransaction(selector); + } + builder.setSeqno(getSeqNo()); + return builder; + } + ResultSet executeQueryInternalWithOptions( Statement statement, com.google.spanner.v1.ExecuteSqlRequest.QueryMode queryMode, @@ -1660,6 +1692,32 @@ public com.google.spanner.v1.ResultSet call() throws Exception { // For standard DML, using the exact row count. return resultSet.getStats().getRowCountExact(); } + + @Override + public long[] batchUpdate(Iterable statements) { + beforeReadOrQuery(); + final ExecuteBatchDmlRequest.Builder builder = getExecuteBatchDmlRequestBuilder(statements); + com.google.spanner.v1.ExecuteBatchDmlResponse response = + runWithRetries( + new Callable() { + @Override + public com.google.spanner.v1.ExecuteBatchDmlResponse call() throws Exception { + return rpc.executeBatchDml(builder.build(), session.options); + } + }); + long[] results = new long[response.getResultSetsCount()]; + for (int i = 0; i < response.getResultSetsCount(); ++i) { + results[i] = response.getResultSets(i).getStats().getRowCountExact(); + } + + if (response.getStatus().getCode() != 0) { + throw newSpannerBatchUpdateException( + ErrorCode.fromRpcStatus(response.getStatus()), + response.getStatus().getMessage(), + results); + } + return results; + } } /** diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java index 59e4c52c28b7..a529c4c492be 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/TransactionContext.java @@ -101,4 +101,21 @@ public interface TransactionContext extends ReadContext { * visible to subsequent operations in the transaction. */ long executeUpdate(Statement statement); + + /** + * Executes a list of DML statements in a single request. The statements will be executed in order + * and the semantics is the same as if each statement is executed by {@code executeUpdate} in a + * loop. This method returns an array of long integers, each representing the number of rows + * modified by each statement. + * + *

If an individual statement fails, execution stops and a {@code SpannerBatchUpdateException} + * is returned, which includes the error and the number of rows affected by the statements that + * are run prior to the error. + * + *

For example, if statements contains 3 statements, and the 2nd one is not a valid DML. This + * method throws a {@code SpannerBatchUpdateException} that contains the error message from the + * 2nd statement, and an array of length 1 that contains the number of rows modified by the 1st + * statement. The 3rd statement will not run. + */ + long[] batchUpdate(Iterable statements); } diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 5469699abd63..262808392537 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -90,6 +90,8 @@ import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.CreateSessionRequest; import com.google.spanner.v1.DeleteSessionRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.PartitionQueryRequest; @@ -514,6 +516,14 @@ public void cancel(String message) { }; } + @Override + public ExecuteBatchDmlResponse executeBatchDml( + ExecuteBatchDmlRequest request, @Nullable Map options) { + + GrpcCallContext context = newCallContext(options, request.getSession()); + return get(spannerStub.executeBatchDmlCallable().futureCall(request, context)); + } + @Override public Transaction beginTransaction( BeginTransactionRequest request, @Nullable Map options) throws SpannerException { diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java index 107b92fa24e8..500f369f67a8 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java @@ -34,6 +34,8 @@ import com.google.spanner.v1.BeginTransactionRequest; import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.PartitionQueryRequest; @@ -214,6 +216,8 @@ StreamingCall read( StreamingCall executeQuery( ExecuteSqlRequest request, ResultStreamConsumer consumer, @Nullable Map options); + ExecuteBatchDmlResponse executeBatchDml(ExecuteBatchDmlRequest build, Map options); + Transaction beginTransaction(BeginTransactionRequest request, @Nullable Map options) throws SpannerException; diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITBatchDmlTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITBatchDmlTest.java new file mode 100644 index 000000000000..c5090d10e4e5 --- /dev/null +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITBatchDmlTest.java @@ -0,0 +1,208 @@ +/* + * Copyright 2019 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.spanner.it; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.api.gax.longrunning.OperationFuture; +import com.google.cloud.spanner.Database; +import com.google.cloud.spanner.DatabaseClient; +import com.google.cloud.spanner.ErrorCode; +import com.google.cloud.spanner.IntegrationTest; +import com.google.cloud.spanner.IntegrationTestEnv; +import com.google.cloud.spanner.SpannerBatchUpdateException; +import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.Statement; +import com.google.cloud.spanner.TransactionContext; +import com.google.cloud.spanner.TransactionRunner; +import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import com.google.spanner.admin.database.v1.UpdateDatabaseDdlMetadata; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Integration tests for DML. */ +@Category(IntegrationTest.class) +@RunWith(JUnit4.class) +public final class ITBatchDmlTest { + + private static Database db; + @ClassRule public static IntegrationTestEnv env = new IntegrationTestEnv(); + + private static final String INSERT_DML = + "INSERT INTO T (k, v) VALUES ('boo1', 1), ('boo2', 2), ('boo3', 3), ('boo4', 4);"; + private static final String UPDATE_DML = "UPDATE T SET T.V = 100 WHERE T.K LIKE 'boo%';"; + private static final String DELETE_DML = "DELETE FROM T WHERE T.K like 'boo%';"; + private static DatabaseClient client; + + @BeforeClass + public static void createDatabase() { + db = env.getTestHelper().createTestDatabase(); + client = env.getTestHelper().getDatabaseClient(db); + } + + @Before + public void createTable() throws Exception { + String ddl = + "CREATE TABLE T (" + " K STRING(MAX) NOT NULL," + " V INT64," + ") PRIMARY KEY (K)"; + OperationFuture op = db.updateDdl(Arrays.asList(ddl), null); + op.get(); + } + + @After + public void dropTable() throws Exception { + String ddl = "DROP TABLE T"; + OperationFuture op = db.updateDdl(Arrays.asList(ddl), null); + op.get(); + } + + @Test + public void noStatementsInRequest() { + final TransactionCallable callable = + new TransactionCallable() { + @Override + public long[] run(TransactionContext transaction) { + List stmts = new ArrayList<>(); + long[] rowCounts; + try { + rowCounts = transaction.batchUpdate(stmts); + Assert.fail("Expecting an exception."); + } catch (SpannerException e) { + assertThat(e instanceof SpannerBatchUpdateException).isFalse(); + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(e.getMessage()).contains("No statements in batch DML request."); + rowCounts = new long[0]; + } + return rowCounts; + } + }; + TransactionRunner runner = client.readWriteTransaction(); + long[] rowCounts = runner.run(callable); + assertThat(rowCounts.length).isEqualTo(0); + } + + @Test + public void batchDml() { + final TransactionCallable callable = + new TransactionCallable() { + @Override + public long[] run(TransactionContext transaction) throws Exception { + List stmts = new ArrayList<>(); + stmts.add(Statement.of(INSERT_DML)); + stmts.add(Statement.of(UPDATE_DML)); + stmts.add(Statement.of(DELETE_DML)); + return transaction.batchUpdate(stmts); + } + }; + TransactionRunner runner = client.readWriteTransaction(); + long[] rowCounts = runner.run(callable); + assertThat(rowCounts.length).isEqualTo(3); + for (long rc : rowCounts) { + assertThat(rc).isEqualTo(4); + } + } + + @Test + public void mixedBatchDmlAndDml() { + final TransactionCallable callable = + new TransactionCallable() { + @Override + public long[] run(TransactionContext transaction) throws Exception { + long rowCount = transaction.executeUpdate(Statement.of(INSERT_DML)); + List stmts = new ArrayList<>(); + stmts.add(Statement.of(UPDATE_DML)); + stmts.add(Statement.of(DELETE_DML)); + long[] batchRowCounts = transaction.batchUpdate(stmts); + long[] rowCounts = new long[batchRowCounts.length + 1]; + System.arraycopy(batchRowCounts, 0, rowCounts, 0, batchRowCounts.length); + rowCounts[batchRowCounts.length] = rowCount; + return rowCounts; + } + }; + TransactionRunner runner = client.readWriteTransaction(); + long[] rowCounts = runner.run(callable); + assertThat(rowCounts.length).isEqualTo(3); + for (long rc : rowCounts) { + assertThat(rc).isEqualTo(4); + } + } + + @Test + public void errorBatchDmlIllegalStatement() { + final TransactionCallable callable = + new TransactionCallable() { + @Override + public long[] run(TransactionContext transaction) { + List stmts = new ArrayList<>(); + stmts.add(Statement.of(INSERT_DML)); + stmts.add(Statement.of("some illegal statement")); + stmts.add(Statement.of(UPDATE_DML)); + return transaction.batchUpdate(stmts); + } + }; + TransactionRunner runner = client.readWriteTransaction(); + try { + runner.run(callable); + Assert.fail("Expecting an exception."); + } catch (SpannerBatchUpdateException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.INVALID_ARGUMENT); + assertThat(e.getMessage()).contains("is not valid DML."); + long[] rowCounts = e.getUpdateCounts(); + assertThat(rowCounts.length).isEqualTo(1); + for (long rc : rowCounts) { + assertThat(rc).isEqualTo(4); + } + } + } + + @Test + public void errorBatchDmlAlreadyExist() { + final TransactionCallable callable = + new TransactionCallable() { + @Override + public long[] run(TransactionContext transaction) { + List stmts = new ArrayList<>(); + stmts.add(Statement.of(INSERT_DML)); + stmts.add(Statement.of(INSERT_DML)); // should fail + stmts.add(Statement.of(UPDATE_DML)); + return transaction.batchUpdate(stmts); + } + }; + TransactionRunner runner = client.readWriteTransaction(); + try { + runner.run(callable); + Assert.fail("Expecting an exception."); + } catch (SpannerBatchUpdateException e) { + assertThat(e.getErrorCode()).isEqualTo(ErrorCode.ALREADY_EXISTS); + assertThat(e.getMessage()).contains("already exists"); + long[] rowCounts = e.getUpdateCounts(); + assertThat(rowCounts.length).isEqualTo(1); + for (long rc : rowCounts) { + assertThat(rc).isEqualTo(4); + } + } + } +}