Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Copy command parsing in QueryMessage and basic psql e2e test #43

Merged
merged 21 commits into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .ci/e2e-expected/copy-from-stdin.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
id | age | name
----+-----+------
1 | 1 | John
2 | 20 | Joe
3 | 23 | Jack
7 | 7 | 7
8 | 8 | 8
9 | 9 | 9
10 | 10 | 10
11 | 11 | 11
12 | 12 | 12
13 | 13 | 13
14 | 14 | 14
15 | 15 | 15
16 | 16 | 16
17 | 17 | 17
(14 rows)

14 changes: 14 additions & 0 deletions .ci/evaluate-with-psql.sh
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,17 @@ echo "------Test \"-c option invalid begin/commit batching\"------"
/usr/lib/postgresql/"${PSQL_VERSION}"/bin/psql -h localhost -p 4242 -d "${GOOGLE_CLOUD_DATABASE_WITH_VERSION}" -c "$(cat .ci/e2e-batching/invalid-commit-batch.txt)" &> .ci/e2e-result/invalid-commit-batching.txt
diff -i -w -s .ci/e2e-result/invalid-commit-batching.txt .ci/e2e-expected/invalid-commit-batching.txt
RETURN_CODE=$((${RETURN_CODE}||$?))

echo "------Test \"COPY FROM STDIN\"------"
/usr/lib/postgresql/"${PSQL_VERSION}"/bin/psql -h localhost -p 4242 -d "${GOOGLE_CLOUD_DATABASE_WITH_VERSION}" -c "COPY users FROM STDIN;" <<EOF
12 12 12
13 13 13
14 14 14
15 15 15
16 16 16
17 17 17
\.
EOF
echo "SELECT * FROM users;" | /usr/lib/postgresql/"${PSQL_VERSION}"/bin/psql -h localhost -p 4242 -d "${GOOGLE_CLOUD_DATABASE_WITH_VERSION}" > .ci/e2e-result/copy-from-stdin.txt
diff -i -w -s .ci/e2e-result/copy-from-stdin.txt .ci/e2e-expected/copy-from-stdin.txt
RETURN_CODE=$((${RETURN_CODE}||$?))
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,16 @@ public synchronized IntermediateStatement getActiveStatement() {
return activeStatementsMap.get(this.connectionId);
}

public synchronized ConnectionStatus getStatus() {
return status;
}

public synchronized void setStatus(ConnectionStatus status) {
this.status = status;
}

/** Status of a {@link ConnectionHandler} */
private enum ConnectionStatus {
public enum ConnectionStatus {
UNAUTHENTICATED,
IDLE,
COPY_IN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ private void queryInformationSchema() throws SQLException {
+ COLUMN_NAME
+ ", "
+ SPANNER_TYPE
+ " FROM information_schema.columns WHERE table_name = \"?\"");
+ " FROM information_schema.columns WHERE table_name = ?");
statement.setString(1, getTableName());
ResultSet result = statement.executeQuery();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public int getRowCount() {
public void addCopyData(ConnectionHandler connectionHandler, byte[] payload) throws Exception {
this.payload.write(payload, 0, payload.length);
if (!commitSizeIsWithinLimit()) {
rollback(connectionHandler);
handleError(connectionHandler);
throw new SQLException(
"Commit size: " + this.payload.size() + " has exceeded the limit: " + COMMIT_LIMIT);
}
Expand All @@ -87,7 +87,7 @@ public void buildMutationList(ConnectionHandler connectionHandler) throws Except
for (CSVRecord record : records) {
// Check that the number of columns in a record matches the number of columns in the table
if (record.size() != this.tableColumns.keySet().size()) {
rollback(connectionHandler);
handleError(connectionHandler);
throw new SQLException(
"Invalid COPY data: Row length mismatched. Expected "
+ this.tableColumns.keySet().size()
Expand Down Expand Up @@ -122,7 +122,7 @@ public void buildMutationList(ConnectionHandler connectionHandler) throws Except
break;
}
} catch (NumberFormatException | DateTimeParseException e) {
rollback(connectionHandler);
handleError(connectionHandler);
throw new SQLException(
"Invalid input syntax for type "
+ columnType.toString()
Expand All @@ -131,10 +131,10 @@ public void buildMutationList(ConnectionHandler connectionHandler) throws Except
+ recordValue
+ "\"");
} catch (IllegalArgumentException e) {
rollback(connectionHandler);
handleError(connectionHandler);
throw new SQLException("Invalid input syntax for column \"" + columnName + "\"");
} catch (Exception e) {
rollback(connectionHandler);
handleError(connectionHandler);
throw e;
}
}
Expand All @@ -143,7 +143,7 @@ public void buildMutationList(ConnectionHandler connectionHandler) throws Except
this.rowCount++; // Increment the number of COPY rows by one
}
if (!mutationCountIsWithinLimit()) {
rollback(connectionHandler);
handleError(connectionHandler);
throw new SQLException(
"Mutation count: " + mutationCount + " has exceeded the limit: " + MUTATION_LIMIT);
}
Expand Down Expand Up @@ -176,7 +176,14 @@ private List<CSVRecord> parsePayloadData(byte[] payload) throws IOException {
} else {
parser = CSVParser.parse(copyData, this.format);
}
return parser.getRecords();
// Skip the last record if that is the '\.' end of file indicator.
List<CSVRecord> records = parser.getRecords();
if (!records.isEmpty()
&& records.get(records.size() - 1).size() == 1
&& "\\.".equals(records.get(records.size() - 1).get(0))) {
return records.subList(0, records.size() - 1);
}
return records;
}

/**
Expand All @@ -198,9 +205,7 @@ public int writeToSpanner(ConnectionHandler connectionHandler) throws SQLExcepti
return this.rowCount;
}

public void rollback(ConnectionHandler connectionHandler) throws Exception {
Connection connection = connectionHandler.getJdbcConnection();
connection.rollback();
public void handleError(ConnectionHandler connectionHandler) throws Exception {
this.mutations = new ArrayList<>();
this.mutationCount = 0;
writeCopyDataToErrorFile();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.google.cloud.spanner.pgadapter.wireprotocol;

import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode;
import com.google.cloud.spanner.pgadapter.metadata.SendResultSetState;
import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement;
Expand Down Expand Up @@ -49,35 +50,42 @@ public ControlMessage(ConnectionHandler connection) throws IOException {
*/
public static ControlMessage create(ConnectionHandler connection) throws Exception {
char nextMsg = (char) connection.getConnectionMetadata().getInputStream().readUnsignedByte();
switch (nextMsg) {
case QueryMessage.IDENTIFIER:
return new QueryMessage(connection);
case ParseMessage.IDENTIFIER:
return new ParseMessage(connection);
case BindMessage.IDENTIFIER:
return new BindMessage(connection);
case DescribeMessage.IDENTIFIER:
return new DescribeMessage(connection);
case ExecuteMessage.IDENTIFIER:
return new ExecuteMessage(connection);
case CloseMessage.IDENTIFIER:
return new CloseMessage(connection);
case SyncMessage.IDENTIFIER:
return new SyncMessage(connection);
case TerminateMessage.IDENTIFIER:
return new TerminateMessage(connection);
case CopyDoneMessage.IDENTIFIER:
return new CopyDoneMessage(connection);
case CopyDataMessage.IDENTIFIER:
return new CopyDataMessage(connection);
case CopyFailMessage.IDENTIFIER:
return new CopyFailMessage(connection);
case FunctionCallMessage.IDENTIFIER:
return new FunctionCallMessage(connection);
case FlushMessage.IDENTIFIER:
return new FlushMessage(connection);
default:
throw new IllegalStateException("Unknown message");
if (connection.getStatus() == ConnectionStatus.COPY_IN) {
switch (nextMsg) {
case CopyDoneMessage.IDENTIFIER:
return new CopyDoneMessage(connection);
case CopyDataMessage.IDENTIFIER:
return new CopyDataMessage(connection);
case CopyFailMessage.IDENTIFIER:
return new CopyFailMessage(connection);
default:
tinaspark marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalStateException("Expected 0 or more Copy Data messages.");
}
} else {
switch (nextMsg) {
case QueryMessage.IDENTIFIER:
return new QueryMessage(connection);
case ParseMessage.IDENTIFIER:
return new ParseMessage(connection);
case BindMessage.IDENTIFIER:
return new BindMessage(connection);
case DescribeMessage.IDENTIFIER:
return new DescribeMessage(connection);
case ExecuteMessage.IDENTIFIER:
return new ExecuteMessage(connection);
case CloseMessage.IDENTIFIER:
return new CloseMessage(connection);
case SyncMessage.IDENTIFIER:
return new SyncMessage(connection);
case TerminateMessage.IDENTIFIER:
return new TerminateMessage(connection);
case FunctionCallMessage.IDENTIFIER:
return new FunctionCallMessage(connection);
case FlushMessage.IDENTIFIER:
return new FlushMessage(connection);
default:
throw new IllegalStateException("Unknown message");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package com.google.cloud.spanner.pgadapter.wireprotocol;

import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode;
import com.google.cloud.spanner.pgadapter.statements.CopyStatement;
import com.google.cloud.spanner.pgadapter.utils.MutationWriter;
Expand Down Expand Up @@ -52,17 +53,15 @@ protected void sendPayload() throws Exception {
// Spanner returned an error when trying to commit the batch of mutations.
mw.writeCopyDataToErrorFile();
mw.closeErrorFile();
// TODO: enable in next PR
// this.connection.setStatus(ConnectionStatus.IDLE);
this.connection.setStatus(ConnectionStatus.IDLE);
olavloite marked this conversation as resolved.
Show resolved Hide resolved
this.connection.removeActiveStatement(this.statement);
throw e;
}
} else {
mw.closeErrorFile();
}
new ReadyResponse(this.outputStream, ReadyResponse.Status.IDLE).send();
// TODO: enable in next PR
// this.connection.setStatus(ConnectionStatus.IDLE);
this.connection.setStatus(ConnectionStatus.IDLE);
olavloite marked this conversation as resolved.
Show resolved Hide resolved
this.connection.removeActiveStatement(this.statement);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

import com.google.cloud.spanner.jdbc.CloudSpannerJdbcConnection;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.ConnectionStatus;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode;
import com.google.cloud.spanner.pgadapter.statements.CopyStatement;
import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement;
import com.google.cloud.spanner.pgadapter.statements.MatcherStatement;
import com.google.cloud.spanner.pgadapter.utils.StatementParser;
import com.google.cloud.spanner.pgadapter.wireoutput.CopyInResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status;
import com.google.cloud.spanner.pgadapter.wireoutput.RowDescriptionResponse;
Expand All @@ -29,13 +32,17 @@
public class QueryMessage extends ControlMessage {

protected static final char IDENTIFIER = 'Q';
protected static final String COPY = "COPY";

private IntermediateStatement statement;

public QueryMessage(ConnectionHandler connection) throws Exception {
super(connection);
String query = StatementParser.removeCommentsAndTrim(this.readAll());
if (!connection.getServer().getOptions().requiresMatcher()) {
String command = StatementParser.parseCommand(query);
if (command.equalsIgnoreCase(COPY)) {
this.statement = new CopyStatement(query, this.connection.getJdbcConnection());
} else if (!connection.getServer().getOptions().requiresMatcher()) {
this.statement = new IntermediateStatement(query, this.connection.getJdbcConnection());
} else {
this.statement = new MatcherStatement(query, this.connection);
Expand All @@ -47,7 +54,9 @@ public QueryMessage(ConnectionHandler connection) throws Exception {
protected void sendPayload() throws Exception {
this.statement.execute();
this.handleQuery();
this.connection.removeActiveStatement(this.statement);
if (!this.statement.getCommand().equalsIgnoreCase(COPY)) {
this.connection.removeActiveStatement(this.statement);
}
}

@Override
Expand Down Expand Up @@ -81,7 +90,18 @@ public void handleQuery() throws Exception {
if (this.statement.hasException()) {
this.handleError(this.statement.getException());
} else {
if (this.statement.containsResultSet()) {
if (this.statement.getCommand().equalsIgnoreCase(COPY)) {
CopyStatement copyStatement = (CopyStatement) this.statement;
new CopyInResponse(
this.outputStream,
copyStatement.getTableColumns().size(),
copyStatement.getFormatCode())
.send();
this.connection.setStatus(ConnectionStatus.COPY_IN);

// Return early as we do not respond with CommandComplete after a COPY command.
return;
} else if (this.statement.containsResultSet()) {
new RowDescriptionResponse(
this.outputStream,
this.statement,
Expand Down
Loading