diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java index c0ea84195..d40e685c5 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/statements/CopyStatement.java @@ -127,6 +127,11 @@ public MutationWriter getMutationWriter() { return this.mutationWriter; } + /** @return 0 for text/csv formatting and 1 for binary */ + public int getFormatCode() { + return (options.getFormat() == CopyTreeParser.CopyOptions.Format.BINARY) ? 1 : 0; + } + private void verifyCopyColumns() throws SQLException { if (options.getColumnNames().size() == 0) { // Use all columns if none were specified. diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java index 6c92071c3..e817fa789 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/utils/MutationWriter.java @@ -19,6 +19,7 @@ import com.google.cloud.spanner.jdbc.CloudSpannerJdbcConnection; import com.google.cloud.spanner.pgadapter.ConnectionHandler; import com.google.spanner.v1.TypeCode; +import java.io.ByteArrayOutputStream; import java.io.File; import java.io.FileWriter; import java.io.IOException; @@ -42,18 +43,17 @@ public class MutationWriter { private boolean hasHeader; private boolean isHeaderParsed; private int mutationCount; - private int batchSize; private int rowCount; private List mutations; private String tableName; private Map tableColumns; private CSVFormat format; private FileWriter fileWriter; + private ByteArrayOutputStream payload = new ByteArrayOutputStream(); public MutationWriter( String tableName, Map tableColumns, CSVFormat format, boolean hasHeader) { this.mutationCount = 0; - this.batchSize = 0; this.hasHeader = hasHeader; this.isHeaderParsed = false; this.tableName = tableName; @@ -72,25 +72,22 @@ public int getRowCount() { return this.rowCount; } - /** Build mutation to add to mutations list with data contained within a CopyData payload */ - public void buildMutation(ConnectionHandler connectionHandler, byte[] payload) throws Exception { - List records = parsePayloadData(payload); - if (!records.isEmpty() - && !payloadFitsInCurrentBatch(records.size() * records.get(0).size(), payload.length)) { - rollback(connectionHandler, payload); - long mutationCount = this.mutationCount + records.size() * records.get(0).size(); - long commitSize = this.batchSize + payload.length; + public void addCopyData(ConnectionHandler connectionHandler, byte[] payload) throws Exception { + this.payload.write(payload, 0, payload.length); + if (!commitSizeIsWithinLimit()) { + rollback(connectionHandler); throw new SQLException( - "Mutation count: " - + mutationCount - + " or mutation commit size: " - + commitSize - + " has exceeded the limit."); + "Commit size: " + this.payload.size() + " has exceeded the limit: " + COMMIT_LIMIT); } + } + + /** Build mutation to add to mutations list with data contained within a CopyData payload */ + public void buildMutationList(ConnectionHandler connectionHandler) throws Exception { + List records = parsePayloadData(this.payload.toByteArray()); for (CSVRecord record : records) { // Check that the number of columns in a record matches the number of columns in the table if (record.size() != this.tableColumns.keySet().size()) { - rollback(connectionHandler, payload); + rollback(connectionHandler); throw new SQLException( "Invalid COPY data: Row length mismatched. Expected " + this.tableColumns.keySet().size() @@ -125,7 +122,7 @@ public void buildMutation(ConnectionHandler connectionHandler, byte[] payload) t break; } } catch (NumberFormatException | DateTimeParseException e) { - rollback(connectionHandler, payload); + rollback(connectionHandler); throw new SQLException( "Invalid input syntax for type " + columnType.toString() @@ -134,10 +131,10 @@ public void buildMutation(ConnectionHandler connectionHandler, byte[] payload) t + recordValue + "\""); } catch (IllegalArgumentException e) { - rollback(connectionHandler, payload); + rollback(connectionHandler); throw new SQLException("Invalid input syntax for column \"" + columnName + "\""); } catch (Exception e) { - rollback(connectionHandler, payload); + rollback(connectionHandler); throw e; } } @@ -145,16 +142,28 @@ public void buildMutation(ConnectionHandler connectionHandler, byte[] payload) t this.mutationCount += record.size(); // Increment the number of mutations being added this.rowCount++; // Increment the number of COPY rows by one } - this.batchSize += payload.length; // Increment the batch size based on payload length + if (!mutationCountIsWithinLimit()) { + rollback(connectionHandler); + throw new SQLException( + "Mutation count: " + mutationCount + " has exceeded the limit: " + MUTATION_LIMIT); + } } /** - * @return True if adding payload to current batch will fit under mutation limit and batch size - * limit, false otherwise. + * @return True if current payload will fit within COMMIT_LIMIT. This is only an estimate and the + * actual commit size may still be rejected by Spanner. */ - private boolean payloadFitsInCurrentBatch(int rowMutationCount, int payloadLength) { - return (this.mutationCount + rowMutationCount <= MUTATION_LIMIT - && this.batchSize + payloadLength <= COMMIT_LIMIT); + private boolean commitSizeIsWithinLimit() { + return this.payload.size() <= COMMIT_LIMIT; + } + + /** + * @return True if current mutation count will fit within MUTATION_LIMIT. This is only an estimate + * and the actual number of mutations may be different which could result in spanner rejecting + * the transaction. + */ + private boolean mutationCountIsWithinLimit() { + return this.mutationCount <= MUTATION_LIMIT; } /** @return list of CSVRecord rows parsed with CSVParser from CopyData payload byte array */ @@ -186,44 +195,34 @@ public int writeToSpanner(ConnectionHandler connectionHandler) throws SQLExcepti // Reset mutations, mutation counter, and batch size count for a new batch this.mutations = new ArrayList<>(); this.mutationCount = 0; - this.batchSize = 0; return this.rowCount; } - public void rollback(ConnectionHandler connectionHandler, byte[] payload) throws Exception { + public void rollback(ConnectionHandler connectionHandler) throws Exception { Connection connection = connectionHandler.getJdbcConnection(); connection.rollback(); this.mutations = new ArrayList<>(); this.mutationCount = 0; - this.batchSize = 0; - createErrorFile(payload); + writeCopyDataToErrorFile(); + this.payload.reset(); } - public void createErrorFile(byte[] payload) throws IOException { + private void createErrorFile() throws IOException { File unsuccessfulCopy = new File(ERROR_FILE); - if (unsuccessfulCopy.createNewFile()) { - this.fileWriter = new FileWriter(ERROR_FILE); - writeToErrorFile(payload); - } else { - System.err.println("File " + unsuccessfulCopy.getName() + " already exists"); - } - } - - public void writeToErrorFile(byte[] payload) throws IOException { - if (this.fileWriter != null) { - this.fileWriter.write(new String(payload, StandardCharsets.UTF_8).trim() + "\n"); - } + this.fileWriter = new FileWriter(unsuccessfulCopy, false); } - public void writeMutationsToErrorFile() throws IOException { - File unsuccessfulCopy = new File(ERROR_FILE); - if (unsuccessfulCopy.createNewFile()) { - this.fileWriter = new FileWriter(ERROR_FILE); - } - - for (Mutation mutation : this.mutations) { - this.fileWriter.write(mutation.toString()); + /** + * Copy data will be written to an error file if size limits were exceeded or a problem was + * encountered. Copy data will also written if an error was encountered while generating the + * mutaiton list or if Spanner returns an error upon commiting the mutations. + */ + public void writeCopyDataToErrorFile() throws IOException { + if (this.fileWriter == null) { + createErrorFile(); } + this.fileWriter.write( + new String(this.payload.toByteArray(), StandardCharsets.UTF_8).trim() + "\n"); } public void closeErrorFile() throws IOException { diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDataMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDataMessage.java index 25a6a190b..0606fefa6 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDataMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDataMessage.java @@ -50,14 +50,14 @@ protected void sendPayload() throws Exception { MutationWriter mw = this.statement.getMutationWriter(); if (!statement.hasException()) { try { - mw.buildMutation(this.connection, this.payload); + mw.addCopyData(this.connection, this.payload); } catch (SQLException e) { - mw.writeToErrorFile(this.payload); + mw.writeCopyDataToErrorFile(); statement.handleExecutionException(e); throw e; } } else { - mw.writeToErrorFile(this.payload); + mw.writeCopyDataToErrorFile(); } } diff --git a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDoneMessage.java b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDoneMessage.java index f26bcdcf7..e120c6914 100644 --- a/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDoneMessage.java +++ b/src/main/java/com/google/cloud/spanner/pgadapter/wireprotocol/CopyDoneMessage.java @@ -43,14 +43,17 @@ protected void sendPayload() throws Exception { MutationWriter mw = this.statement.getMutationWriter(); if (!statement.hasException()) { try { + mw.buildMutationList(this.connection); int rowCount = mw.writeToSpanner(this.connection); // Write any remaining mutations to Spanner statement.addUpdateCount(rowCount); // Increase the row count of number of rows copied. this.sendSpannerResult(this.statement, QueryMode.SIMPLE, 0L); } catch (Exception e) { // Spanner returned an error when trying to commit the batch of mutations. - mw.writeMutationsToErrorFile(); + mw.writeCopyDataToErrorFile(); mw.closeErrorFile(); + // TODO: enable in next PR + // this.connection.setStatus(ConnectionStatus.IDLE); this.connection.removeActiveStatement(this.statement); throw e; } @@ -58,6 +61,8 @@ protected void sendPayload() throws Exception { mw.closeErrorFile(); } new ReadyResponse(this.outputStream, ReadyResponse.Status.IDLE).send(); + // TODO: enable in next PR + // this.connection.setStatus(ConnectionStatus.IDLE); this.connection.removeActiveStatement(this.statement); } diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java index 41a286bce..ca5d7d97c 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java @@ -108,6 +108,7 @@ public class ProtocolTest { @Mock private ConnectionMetadata connectionMetadata; @Mock private DataOutputStream outputStream; @Mock private ResultSet resultSet; + @Mock private MutationWriter mutationWriter; private byte[] intToBytes(int value) { byte[] parameters = new byte[4]; @@ -1199,8 +1200,8 @@ public void testCopyDataMessage() throws Exception { Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream); Mockito.when(connectionMetadata.getOutputStream()).thenReturn(outputStream); - MutationWriter mb = Mockito.mock(MutationWriter.class); - Mockito.when(copyStatement.getMutationWriter()).thenReturn(mb); + MutationWriter mw = Mockito.mock(MutationWriter.class); + Mockito.when(copyStatement.getMutationWriter()).thenReturn(mw); WireMessage message = ControlMessage.create(connectionHandler); Assert.assertEquals(message.getClass(), CopyDataMessage.class); @@ -1209,7 +1210,76 @@ public void testCopyDataMessage() throws Exception { CopyDataMessage messageSpy = (CopyDataMessage) Mockito.spy(message); messageSpy.send(); - Mockito.verify(mb, Mockito.times(1)).buildMutation(connectionHandler, payload); + Mockito.verify(mw, Mockito.times(1)).addCopyData(connectionHandler, payload); + } + + @Test + public void testMultipleCopyDataMessages() throws Exception { + Mockito.when(connection.createStatement()).thenReturn(statement); + Mockito.when(connection.prepareStatement(ArgumentMatchers.anyString())) + .thenReturn(preparedStatement); + Mockito.when(connectionHandler.getJdbcConnection()).thenReturn(connection); + Mockito.when(statement.getUpdateCount()).thenReturn(1); + + byte[] messageMetadata = {'d'}; + byte[] payload1 = "1\t'one'\n2\t".getBytes(); + byte[] payload2 = "'two'\n3\t'th".getBytes(); + byte[] payload3 = "ree'\n4\t'four'\n".getBytes(); + byte[] length1 = intToBytes(4 + payload1.length); + byte[] length2 = intToBytes(4 + payload2.length); + byte[] length3 = intToBytes(4 + payload3.length); + byte[] value1 = Bytes.concat(messageMetadata, length1, payload1); + byte[] value2 = Bytes.concat(messageMetadata, length2, payload2); + byte[] value3 = Bytes.concat(messageMetadata, length3, payload3); + + DataInputStream inputStream1 = new DataInputStream(new ByteArrayInputStream(value1)); + DataInputStream inputStream2 = new DataInputStream(new ByteArrayInputStream(value2)); + DataInputStream inputStream3 = new DataInputStream(new ByteArrayInputStream(value3)); + + ResultSet spannerType = Mockito.mock(ResultSet.class); + Mockito.when(spannerType.getString("column_name")).thenReturn("key", "value"); + Mockito.when(spannerType.getString("spanner_type")).thenReturn("INT64", "STRING"); + Mockito.when(spannerType.next()).thenReturn(true, true, false); + Mockito.when(preparedStatement.executeQuery()).thenReturn(spannerType); + + CopyStatement copyStatement = new CopyStatement("COPY keyvalue FROM STDIN;", connection); + copyStatement.execute(); + + Mockito.when(connectionHandler.getActiveStatement()).thenReturn(copyStatement); + Mockito.when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata); + Mockito.when(connectionMetadata.getOutputStream()).thenReturn(outputStream); + + { + Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream1); + WireMessage message = ControlMessage.create(connectionHandler); + Assert.assertEquals(message.getClass(), CopyDataMessage.class); + Assert.assertArrayEquals(((CopyDataMessage) message).getPayload(), payload1); + CopyDataMessage copyDataMessage = (CopyDataMessage) message; + copyDataMessage.send(); + } + { + Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream2); + WireMessage message = ControlMessage.create(connectionHandler); + Assert.assertEquals(message.getClass(), CopyDataMessage.class); + Assert.assertArrayEquals(((CopyDataMessage) message).getPayload(), payload2); + CopyDataMessage copyDataMessage = (CopyDataMessage) message; + copyDataMessage.send(); + } + { + Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream3); + WireMessage message = ControlMessage.create(connectionHandler); + Assert.assertEquals(message.getClass(), CopyDataMessage.class); + Assert.assertArrayEquals(((CopyDataMessage) message).getPayload(), payload3); + CopyDataMessage copyDataMessage = (CopyDataMessage) message; + copyDataMessage.send(); + } + + MutationWriter mw = copyStatement.getMutationWriter(); + mw.buildMutationList(connectionHandler); + Assert.assertEquals( + mw.getMutations().toString(), + "[insert(keyvalue{key=1,value='one'}), insert(keyvalue{key=2,value='two'}), " + + "insert(keyvalue{key=3,value='three'}), insert(keyvalue{key=4,value='four'})]"); } @Test @@ -1290,7 +1360,8 @@ public void testCopyFromFilePipe() throws Exception { copyStatement.execute(); MutationWriter mw = copyStatement.getMutationWriter(); - mw.buildMutation(connectionHandler, payload); + mw.addCopyData(connectionHandler, payload); + mw.buildMutationList(connectionHandler); Assert.assertEquals(copyStatement.getFormatType(), "TEXT"); Assert.assertEquals(copyStatement.getDelimiterChar(), '\t'); @@ -1327,7 +1398,8 @@ public void testCopyBatchSizeLimit() throws Exception { MutationWriter mw = copyStatement.getMutationWriter(); MutationWriter mwSpy = Mockito.spy(mw); Mockito.when(mwSpy.writeToSpanner(connectionHandler)).thenReturn(10, 2); - mwSpy.buildMutation(connectionHandler, payload); + mwSpy.addCopyData(connectionHandler, payload); + mwSpy.buildMutationList(connectionHandler); mwSpy.writeToSpanner(connectionHandler); Assert.assertEquals(copyStatement.getFormatType(), "TEXT"); @@ -1335,7 +1407,6 @@ public void testCopyBatchSizeLimit() throws Exception { Assert.assertEquals(copyStatement.hasException(), false); Assert.assertEquals(mwSpy.getRowCount(), 12); - // Verify writeToSpanner is called once inside buildMutation when batch size is exceeded Mockito.verify(mwSpy, Mockito.times(1)).writeToSpanner(connectionHandler); copyStatement.close(); } @@ -1370,7 +1441,8 @@ public void testCopyDataRowLengthMismatchLimit() throws Exception { Assert.assertThrows( SQLException.class, () -> { - mwSpy.buildMutation(connectionHandler, payload); + mwSpy.addCopyData(connectionHandler, payload); + mwSpy.buildMutationList(connectionHandler); ; }); Assert.assertEquals( @@ -1408,14 +1480,13 @@ public void testCopyResumeErrorOutputFile() throws Exception { Assert.assertThrows( SQLException.class, () -> { - mwSpy.buildMutation(connectionHandler, payload); + mwSpy.addCopyData(connectionHandler, payload); + mwSpy.buildMutationList(connectionHandler); mwSpy.writeToSpanner(connectionHandler); }); Assert.assertEquals(thrown.getMessage(), "Invalid input syntax for type INT64:\"'5'\""); - Mockito.verify(mwSpy, Mockito.times(1)).createErrorFile(payload); - Mockito.verify(mwSpy, Mockito.times(1)).writeToErrorFile(payload); - + Mockito.verify(mwSpy, Mockito.times(1)).writeCopyDataToErrorFile(); File outputFile = new File("output.txt"); Assert.assertTrue(outputFile.exists()); Assert.assertTrue(outputFile.isFile()); @@ -1454,14 +1525,13 @@ public void testCopyResumeErrorStartOutputFile() throws Exception { Assert.assertThrows( SQLException.class, () -> { - mwSpy.buildMutation(connectionHandler, payload); + mwSpy.addCopyData(connectionHandler, payload); + mwSpy.buildMutationList(connectionHandler); mwSpy.writeToSpanner(connectionHandler); }); Assert.assertEquals(thrown.getMessage(), "Invalid input syntax for type INT64:\"'1'\""); - Mockito.verify(mwSpy, Mockito.times(1)).createErrorFile(payload); - Mockito.verify(mwSpy, Mockito.times(1)).writeToErrorFile(payload); - + Mockito.verify(mwSpy, Mockito.times(1)).writeCopyDataToErrorFile(); File outputFile = new File("output.txt"); Assert.assertTrue(outputFile.exists()); Assert.assertTrue(outputFile.isFile()); diff --git a/src/test/java/com/google/cloud/spanner/pgadapter/StatementTest.java b/src/test/java/com/google/cloud/spanner/pgadapter/StatementTest.java index 739cb235f..c1f206081 100644 --- a/src/test/java/com/google/cloud/spanner/pgadapter/StatementTest.java +++ b/src/test/java/com/google/cloud/spanner/pgadapter/StatementTest.java @@ -453,7 +453,8 @@ public void testCopyBuildMutation() throws Exception { byte[] payload = "2\t3\n".getBytes(); MutationWriter mw = statement.getMutationWriter(); - mw.buildMutation(connectionHandler, payload); + mw.addCopyData(connectionHandler, payload); + mw.buildMutationList(connectionHandler); Assert.assertEquals(statement.getFormatType(), "TEXT"); Assert.assertEquals(statement.getDelimiterChar(), '\t'); @@ -489,7 +490,8 @@ public void testCopyInvalidBuildMutation() throws Exception { Assert.assertThrows( SQLException.class, () -> { - mw.buildMutation(connectionHandler, payload); + mw.addCopyData(connectionHandler, payload); + mw.buildMutationList(connectionHandler); }); Assert.assertEquals( thrown.getMessage(),