diff --git a/.ci/e2e-expected/backslash-dn.txt b/.ci/e2e-expected/backslash-dn.txt index b7baca6ce..9dc3f9902 100644 --- a/.ci/e2e-expected/backslash-dn.txt +++ b/.ci/e2e-expected/backslash-dn.txt @@ -2,7 +2,8 @@ catalog_name | schema_name | schema_owner | default_character_set_catalog | default_character_set_schema | default_character_set_name | sql_path | effective_timestamp +++++++ testdb_e2e_psql_v?? | public | | | | | | + testdb_e2e_psql_v?? | pg_catalog | | | | | | testdb_e2e_psql_v?? | information_schema | | | | | | testdb_e2e_psql_v?? | spanner_sys | | | | | | -(3 rows) +(4 rows) diff --git a/.ci/e2e-expected/backslash-dt.txt b/.ci/e2e-expected/backslash-dt.txt index dc550365c..c0be7f635 100644 --- a/.ci/e2e-expected/backslash-dt.txt +++ b/.ci/e2e-expected/backslash-dt.txt @@ -2,9 +2,27 @@ table_catalog | table_schema | table_name | table_type | self_referencing_column_name | reference_generation | user_defined_type_catalog | user_defined_type_schema | user_defined_type_name | is_insertable_into | is_typed | commit_action | parent_table_name | on_delete_action | spanner_state | interleave_type | row_deletion_policy_expression ++++++++++++++++ testdb_e2e_psql_v?? | public | users | BASE TABLE | | | | | | | | | | | COMMITTED | | - testdb_e2e_psql_v?? | information_schema | check_constraints | VIEW | | | | | | | | | | | | | - testdb_e2e_psql_v?? | information_schema | column_column_usage | VIEW | | | | | | | | | | | | | - testdb_e2e_psql_v?? | information_schema | column_options | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_am | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_attrdef | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_attribute | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_class | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_collation | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_database | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_index | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_indexes | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_inherits | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_namespace | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_policy | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_publication | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_publication_rel | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_roles | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_sequence | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_statistic_ext | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_tables | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | pg_catalog | pg_type | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | information_schema | check_constraints | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | information_schema | column_column_usage | VIEW | | | | | | | | | | | | | + testdb_e2e_psql_v?? | information_schema | column_options | VIEW | | | | | | | | | | | | | testdb_e2e_psql_v?? | information_schema | columns | VIEW | | | | | | | | | | | | | testdb_e2e_psql_v?? | information_schema | constraint_column_usage | VIEW | | | | | | | | | | | | | testdb_e2e_psql_v?? | information_schema | constraint_table_usage | VIEW | | | | | | | | | | | | | @@ -47,5 +65,5 @@ testdb_e2e_psql_v?? | spanner_sys | txn_stats_total_10minute | VIEW | | | | | | | | | | | | | testdb_e2e_psql_v?? | spanner_sys | txn_stats_total_hour | VIEW | | | | | | | | | | | | | testdb_e2e_psql_v?? | spanner_sys | txn_stats_total_minute | VIEW | | | | | | | | | | | | | -(46 rows) +(64 rows) diff --git a/.ci/e2e-expected/copy-from-stdin.txt b/.ci/e2e-expected/copy-from-stdin.txt new file mode 100644 index 000000000..2ce9e9910 --- /dev/null +++ b/.ci/e2e-expected/copy-from-stdin.txt @@ -0,0 +1,18 @@ + id | age | name +----+-----+------ + 1 | 1 | John + 2 | 20 | Joe + 3 | 23 | Jack + 7 | 7 | 7 + 8 | 8 | 8 + 9 | 9 | 9 + 10 | 10 | 10 + 11 | 11 | 11 + 12 | 12 | 12 + 13 | 13 | 13 + 14 | 14 | 14 + 15 | 15 | 15 + 16 | 16 | 16 + 17 | 17 | 17 +(14 rows) + diff --git a/.ci/evaluate-with-psql.sh b/.ci/evaluate-with-psql.sh index 3a726ec1e..18439ba9f 100644 --- a/.ci/evaluate-with-psql.sh +++ b/.ci/evaluate-with-psql.sh @@ -110,3 +110,17 @@ echo "------Test \"-c option invalid begin/commit batching\"------" /usr/lib/postgresql/"${PSQL_VERSION}"/bin/psql -h localhost -p 4242 -d "${GOOGLE_CLOUD_DATABASE_WITH_VERSION}" -c "$(cat .ci/e2e-batching/invalid-commit-batch.txt)" &> .ci/e2e-result/invalid-commit-batching.txt diff -i -w -s .ci/e2e-result/invalid-commit-batching.txt .ci/e2e-expected/invalid-commit-batching.txt RETURN_CODE=$((${RETURN_CODE}||$?)) + +echo "------Test \"COPY FROM STDIN\"------" +/usr/lib/postgresql/"${PSQL_VERSION}"/bin/psql -h localhost -p 4242 -d "${GOOGLE_CLOUD_DATABASE_WITH_VERSION}" -c "COPY users FROM STDIN;" < .ci/e2e-result/copy-from-stdin.txt +diff -i -w -s .ci/e2e-result/copy-from-stdin.txt .ci/e2e-expected/copy-from-stdin.txt +RETURN_CODE=$((${RETURN_CODE}||$?)) \ No newline at end of file diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java index e47ab9249..7c7d92d80 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/ConnectionHandler.java @@ -319,12 +319,16 @@ public synchronized IntermediateStatement getActiveStatement() { return activeStatementsMap.get(this.connectionId); } + public synchronized ConnectionStatus getStatus() { + return status; + } + public synchronized void setStatus(ConnectionStatus status) { this.status = status; } /** Status of a {@link ConnectionHandler} */ - private enum ConnectionStatus { + public enum ConnectionStatus { UNAUTHENTICATED, IDLE, COPY_IN, diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java index d40e685c5..1db150f7f 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java @@ -167,13 +167,16 @@ private static TypeCode parseSpannerDataType(String columnType) { private void queryInformationSchema() throws SQLException { Map tableColumns = new LinkedHashMap<>(); + // The mutation API requires GoogleSQL type names instead of PostgreSQL type names for the table + // column types. Issue the information_schema table query as a GoogleSQL query so that it will + // return GoogleSQL type names. PreparedStatement statement = this.connection.prepareStatement( - "SELECT " + "/*GSQL*/SELECT " + COLUMN_NAME + ", " + SPANNER_TYPE - + " FROM information_schema.columns WHERE table_name = \"?\""); + + " FROM information_schema.columns WHERE table_name = ?"); statement.setString(1, getTableName()); ResultSet result = statement.executeQuery(); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java index e817fa789..9531cb185 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java @@ -75,7 +75,7 @@ public int getRowCount() { public void addCopyData(ConnectionHandler connectionHandler, byte[] payload) throws Exception { this.payload.write(payload, 0, payload.length); if (!commitSizeIsWithinLimit()) { - rollback(connectionHandler); + handleError(connectionHandler); throw new SQLException( "Commit size: " + this.payload.size() + " has exceeded the limit: " + COMMIT_LIMIT); } @@ -87,7 +87,7 @@ public void buildMutationList(ConnectionHandler connectionHandler) throws Except for (CSVRecord record : records) { // Check that the number of columns in a record matches the number of columns in the table if (record.size() != this.tableColumns.keySet().size()) { - rollback(connectionHandler); + handleError(connectionHandler); throw new SQLException( "Invalid COPY data: Row length mismatched. Expected " + this.tableColumns.keySet().size() @@ -122,7 +122,7 @@ public void buildMutationList(ConnectionHandler connectionHandler) throws Except break; } } catch (NumberFormatException | DateTimeParseException e) { - rollback(connectionHandler); + handleError(connectionHandler); throw new SQLException( "Invalid input syntax for type " + columnType.toString() @@ -131,10 +131,10 @@ public void buildMutationList(ConnectionHandler connectionHandler) throws Except + recordValue + "\""); } catch (IllegalArgumentException e) { - rollback(connectionHandler); + handleError(connectionHandler); throw new SQLException("Invalid input syntax for column \"" + columnName + "\""); } catch (Exception e) { - rollback(connectionHandler); + handleError(connectionHandler); throw e; } } @@ -143,7 +143,7 @@ public void buildMutationList(ConnectionHandler connectionHandler) throws Except this.rowCount++; // Increment the number of COPY rows by one } if (!mutationCountIsWithinLimit()) { - rollback(connectionHandler); + handleError(connectionHandler); throw new SQLException( "Mutation count: " + mutationCount + " has exceeded the limit: " + MUTATION_LIMIT); } @@ -176,7 +176,14 @@ private List parsePayloadData(byte[] payload) throws IOException { } else { parser = CSVParser.parse(copyData, this.format); } - return parser.getRecords(); + // Skip the last record if that is the '\.' end of file indicator. + List records = parser.getRecords(); + if (!records.isEmpty() + && records.get(records.size() - 1).size() == 1 + && "\\.".equals(records.get(records.size() - 1).get(0))) { + return records.subList(0, records.size() - 1); + } + return records; } /** @@ -198,9 +205,7 @@ public int writeToSpanner(ConnectionHandler connectionHandler) throws SQLExcepti return this.rowCount; } - public void rollback(ConnectionHandler connectionHandler) throws Exception { - Connection connection = connectionHandler.getJdbcConnection(); - connection.rollback(); + public void handleError(ConnectionHandler connectionHandler) throws Exception { this.mutations = new ArrayList<>(); this.mutationCount = 0; writeCopyDataToErrorFile(); diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java index 41528d20f..1da089f16 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/ControlMessage.java @@ -15,6 +15,7 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; import com.google.cloud.spanner.pgadapter.metadata.SendResultSetState; import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement; @@ -49,35 +50,42 @@ public ControlMessage(ConnectionHandler connection) throws IOException { */ public static ControlMessage create(ConnectionHandler connection) throws Exception { char nextMsg = (char) connection.getConnectionMetadata().getInputStream().readUnsignedByte(); - switch (nextMsg) { - case QueryMessage.IDENTIFIER: - return new QueryMessage(connection); - case ParseMessage.IDENTIFIER: - return new ParseMessage(connection); - case BindMessage.IDENTIFIER: - return new BindMessage(connection); - case DescribeMessage.IDENTIFIER: - return new DescribeMessage(connection); - case ExecuteMessage.IDENTIFIER: - return new ExecuteMessage(connection); - case CloseMessage.IDENTIFIER: - return new CloseMessage(connection); - case SyncMessage.IDENTIFIER: - return new SyncMessage(connection); - case TerminateMessage.IDENTIFIER: - return new TerminateMessage(connection); - case CopyDoneMessage.IDENTIFIER: - return new CopyDoneMessage(connection); - case CopyDataMessage.IDENTIFIER: - return new CopyDataMessage(connection); - case CopyFailMessage.IDENTIFIER: - return new CopyFailMessage(connection); - case FunctionCallMessage.IDENTIFIER: - return new FunctionCallMessage(connection); - case FlushMessage.IDENTIFIER: - return new FlushMessage(connection); - default: - throw new IllegalStateException("Unknown message"); + if (connection.getStatus() == ConnectionStatus.COPY_IN) { + switch (nextMsg) { + case CopyDoneMessage.IDENTIFIER: + return new CopyDoneMessage(connection); + case CopyDataMessage.IDENTIFIER: + return new CopyDataMessage(connection); + case CopyFailMessage.IDENTIFIER: + return new CopyFailMessage(connection); + default: + throw new IllegalStateException("Expected 0 or more Copy Data messages."); + } + } else { + switch (nextMsg) { + case QueryMessage.IDENTIFIER: + return new QueryMessage(connection); + case ParseMessage.IDENTIFIER: + return new ParseMessage(connection); + case BindMessage.IDENTIFIER: + return new BindMessage(connection); + case DescribeMessage.IDENTIFIER: + return new DescribeMessage(connection); + case ExecuteMessage.IDENTIFIER: + return new ExecuteMessage(connection); + case CloseMessage.IDENTIFIER: + return new CloseMessage(connection); + case SyncMessage.IDENTIFIER: + return new SyncMessage(connection); + case TerminateMessage.IDENTIFIER: + return new TerminateMessage(connection); + case FunctionCallMessage.IDENTIFIER: + return new FunctionCallMessage(connection); + case FlushMessage.IDENTIFIER: + return new FlushMessage(connection); + default: + throw new IllegalStateException("Unknown message"); + } } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDoneMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDoneMessage.java index e120c6914..1bac7e0f1 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDoneMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDoneMessage.java @@ -15,6 +15,7 @@ package com.google.cloud.spanner.pgadapter.wireprotocol; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; import com.google.cloud.spanner.pgadapter.statements.CopyStatement; import com.google.cloud.spanner.pgadapter.utils.MutationWriter; @@ -52,8 +53,7 @@ protected void sendPayload() throws Exception { // Spanner returned an error when trying to commit the batch of mutations. mw.writeCopyDataToErrorFile(); mw.closeErrorFile(); - // TODO: enable in next PR - // this.connection.setStatus(ConnectionStatus.IDLE); + this.connection.setStatus(ConnectionStatus.IDLE); this.connection.removeActiveStatement(this.statement); throw e; } @@ -61,8 +61,7 @@ protected void sendPayload() throws Exception { mw.closeErrorFile(); } new ReadyResponse(this.outputStream, ReadyResponse.Status.IDLE).send(); - // TODO: enable in next PR - // this.connection.setStatus(ConnectionStatus.IDLE); + this.connection.setStatus(ConnectionStatus.IDLE); this.connection.removeActiveStatement(this.statement); } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/QueryMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/QueryMessage.java index a60a85d18..fc8fbb482 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/QueryMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/QueryMessage.java @@ -16,10 +16,13 @@ import com.google.cloud.spanner.jdbc.CloudSpannerJdbcConnection; import com.google.cloud.spanner.pgadapter.ConnectionHandler; +import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; +import com.google.cloud.spanner.pgadapter.statements.CopyStatement; import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement; import com.google.cloud.spanner.pgadapter.statements.MatcherStatement; import com.google.cloud.spanner.pgadapter.utils.StatementParser; +import com.google.cloud.spanner.pgadapter.wireoutput.CopyInResponse; import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse; import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status; import com.google.cloud.spanner.pgadapter.wireoutput.RowDescriptionResponse; @@ -29,13 +32,17 @@ public class QueryMessage extends ControlMessage { protected static final char IDENTIFIER = 'Q'; + protected static final String COPY = "COPY"; private IntermediateStatement statement; public QueryMessage(ConnectionHandler connection) throws Exception { super(connection); String query = StatementParser.removeCommentsAndTrim(this.readAll()); - if (!connection.getServer().getOptions().requiresMatcher()) { + String command = StatementParser.parseCommand(query); + if (command.equalsIgnoreCase(COPY)) { + this.statement = new CopyStatement(query, this.connection.getJdbcConnection()); + } else if (!connection.getServer().getOptions().requiresMatcher()) { this.statement = new IntermediateStatement(query, this.connection.getJdbcConnection()); } else { this.statement = new MatcherStatement(query, this.connection); @@ -47,7 +54,9 @@ public QueryMessage(ConnectionHandler connection) throws Exception { protected void sendPayload() throws Exception { this.statement.execute(); this.handleQuery(); - this.connection.removeActiveStatement(this.statement); + if (!this.statement.getCommand().equalsIgnoreCase(COPY)) { + this.connection.removeActiveStatement(this.statement); + } } @Override @@ -81,7 +90,18 @@ public void handleQuery() throws Exception { if (this.statement.hasException()) { this.handleError(this.statement.getException()); } else { - if (this.statement.containsResultSet()) { + if (this.statement.getCommand().equalsIgnoreCase(COPY)) { + CopyStatement copyStatement = (CopyStatement) this.statement; + new CopyInResponse( + this.outputStream, + copyStatement.getTableColumns().size(), + copyStatement.getFormatCode()) + .send(); + this.connection.setStatus(ConnectionStatus.COPY_IN); + + // Return early as we do not respond with CommandComplete after a COPY command. + return; + } else if (this.statement.containsResultSet()) { new RowDescriptionResponse( this.outputStream, this.statement, diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java index 8f46b1542..7bb1c7051 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/JdbcMockServerTest.java @@ -16,10 +16,25 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; +import com.google.protobuf.ListValue; +import com.google.protobuf.Value; +import com.google.spanner.v1.CommitRequest; import com.google.spanner.v1.ExecuteBatchDmlRequest; import com.google.spanner.v1.ExecuteSqlRequest; +import com.google.spanner.v1.Mutation; +import com.google.spanner.v1.Mutation.OperationCase; +import com.google.spanner.v1.ResultSetMetadata; +import com.google.spanner.v1.StructType; +import com.google.spanner.v1.StructType.Field; +import com.google.spanner.v1.Type; +import com.google.spanner.v1.TypeCode; +import java.io.File; +import java.io.IOException; +import java.io.StringReader; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; @@ -29,6 +44,8 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.postgresql.copy.CopyManager; +import org.postgresql.core.BaseConnection; @RunWith(JUnit4.class) public class JdbcMockServerTest extends AbstractMockServerTest { @@ -58,6 +75,95 @@ public void testQuery() throws SQLException { } } + @Test + public void testCopyIn() throws SQLException, IOException { + setupCopyInformationSchemaResults(); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); + copyManager.copyIn("COPY users FROM STDIN;", new StringReader("5\t5\t5\n6\t6\t6\n7\t7\t7\n")); + } + + List commitRequests = mockSpanner.getRequestsOfType(CommitRequest.class); + assertEquals(1, commitRequests.size()); + CommitRequest commitRequest = commitRequests.get(0); + assertEquals(1, commitRequest.getMutationsCount()); + + Mutation mutation = commitRequest.getMutations(0); + assertEquals(OperationCase.INSERT, mutation.getOperationCase()); + assertEquals(3, mutation.getInsert().getValuesCount()); + } + + @Test + public void testCopyInWithInvalidRow() throws SQLException { + setupCopyInformationSchemaResults(); + + try (Connection connection = DriverManager.getConnection(createUrl())) { + CopyManager copyManager = new CopyManager(connection.unwrap(BaseConnection.class)); + // This row does not contain all the necessary columns. + SQLException exception = + assertThrows( + SQLException.class, + () -> copyManager.copyIn("COPY users FROM STDIN;", new StringReader("5\n"))); + assertTrue( + exception + .getMessage() + .contains("Row length mismatched. Expected 3 columns, but only found 1")); + } finally { + assertTrue(new File("output.txt").delete()); + } + + List commitRequests = mockSpanner.getRequestsOfType(CommitRequest.class); + assertTrue(commitRequests.isEmpty()); + } + + private void setupCopyInformationSchemaResults() { + ResultSetMetadata metadata = + ResultSetMetadata.newBuilder() + .setRowType( + StructType.newBuilder() + .addFields( + Field.newBuilder() + .setName("column_name") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .addFields( + Field.newBuilder() + .setName("spanner_type") + .setType(Type.newBuilder().setCode(TypeCode.STRING).build()) + .build()) + .build()) + .build(); + com.google.spanner.v1.ResultSet resultSet = + com.google.spanner.v1.ResultSet.newBuilder() + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("id").build()) + .addValues(Value.newBuilder().setStringValue("INT64").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("age").build()) + .addValues(Value.newBuilder().setStringValue("INT64").build()) + .build()) + .addRows( + ListValue.newBuilder() + .addValues(Value.newBuilder().setStringValue("name").build()) + .addValues(Value.newBuilder().setStringValue("STRING(MAX)").build()) + .build()) + .setMetadata(metadata) + .build(); + + mockSpanner.putStatementResult( + StatementResult.query( + com.google.cloud.spanner.Statement.newBuilder( + "/*GSQL*/SELECT column_name, spanner_type FROM information_schema.columns WHERE table_name = @p1") + .bind("p1") + .to("users") + .build(), + resultSet)); + } + @Test public void testTwoDmlStatements() throws SQLException { try (Connection connection = DriverManager.getConnection(createUrl())) { diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java index ca5d7d97c..a09ceb63a 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.when; import com.google.cloud.spanner.jdbc.CloudSpannerJdbcConnection; +import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus; import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode; import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata; import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata; @@ -1197,6 +1198,7 @@ public void testCopyDataMessage() throws Exception { CopyStatement copyStatement = Mockito.mock(CopyStatement.class); Mockito.when(connectionHandler.getActiveStatement()).thenReturn(copyStatement); Mockito.when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + Mockito.when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream); Mockito.when(connectionMetadata.getOutputStream()).thenReturn(outputStream); @@ -1219,6 +1221,7 @@ public void testMultipleCopyDataMessages() throws Exception { Mockito.when(connection.prepareStatement(ArgumentMatchers.anyString())) .thenReturn(preparedStatement); Mockito.when(connectionHandler.getJdbcConnection()).thenReturn(connection); + Mockito.when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); Mockito.when(statement.getUpdateCount()).thenReturn(1); byte[] messageMetadata = {'d'}; @@ -1295,6 +1298,7 @@ public void testCopyDoneMessage() throws Exception { CopyStatement copyStatement = Mockito.mock(CopyStatement.class); Mockito.when(connectionHandler.getActiveStatement()).thenReturn(copyStatement); Mockito.when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + Mockito.when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream); Mockito.when(connectionMetadata.getOutputStream()).thenReturn(outputStream); @@ -1330,6 +1334,7 @@ public void testCopyFailMessage() throws Exception { CopyStatement copyStatement = Mockito.mock(CopyStatement.class); Mockito.when(connectionHandler.getActiveStatement()).thenReturn(copyStatement); Mockito.when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + Mockito.when(connectionHandler.getStatus()).thenReturn(ConnectionStatus.COPY_IN); Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream); Mockito.when(connectionMetadata.getOutputStream()).thenReturn(outputStream);