Skip to content

Commit

Permalink
fix: return correct transaction status
Browse files Browse the repository at this point in the history
PgAdapter should correctly return either 'I' (idle) or 'T' (transaction)
based on whether the session is still in a transaction. Some drivers
(psycopg) inspect and use the returned value to determine the state of the
connection.
  • Loading branch information
olavloite committed Jan 31, 2022
1 parent 869d44c commit 69c4017
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.google.cloud.spanner.pgadapter.wireprotocol;

import com.google.cloud.spanner.jdbc.CloudSpannerJdbcConnection;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status;
Expand All @@ -33,7 +34,11 @@ public FlushMessage(ConnectionHandler connection) throws Exception {

@Override
protected void sendPayload() throws Exception {
new ReadyResponse(this.outputStream, Status.IDLE).send();
boolean inTransaction =
connection.getJdbcConnection().unwrap(CloudSpannerJdbcConnection.class).isInTransaction();
new ReadyResponse(
this.outputStream, inTransaction ? Status.TRANSACTION : ReadyResponse.Status.IDLE)
.send();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

package com.google.cloud.spanner.pgadapter.wireprotocol;

import com.google.cloud.spanner.jdbc.CloudSpannerJdbcConnection;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode;
import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement;
import com.google.cloud.spanner.pgadapter.statements.MatcherStatement;
import com.google.cloud.spanner.pgadapter.utils.StatementParser;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status;
import com.google.cloud.spanner.pgadapter.wireoutput.RowDescriptionResponse;
import java.text.MessageFormat;

Expand Down Expand Up @@ -89,7 +91,11 @@ public void handleQuery() throws Exception {
.send();
}
this.sendSpannerResult(this.statement, QueryMode.SIMPLE, 0L);
new ReadyResponse(this.outputStream, ReadyResponse.Status.IDLE).send();
boolean inTransaction =
connection.getJdbcConnection().unwrap(CloudSpannerJdbcConnection.class).isInTransaction();
new ReadyResponse(
this.outputStream, inTransaction ? Status.TRANSACTION : ReadyResponse.Status.IDLE)
.send();
}
this.connection.cleanUp(this.statement);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package com.google.cloud.spanner.pgadapter.wireprotocol;

import com.google.cloud.spanner.jdbc.CloudSpannerJdbcConnection;
import com.google.cloud.spanner.pgadapter.ConnectionHandler;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse;
import com.google.cloud.spanner.pgadapter.wireoutput.ReadyResponse.Status;
Expand All @@ -33,7 +34,11 @@ public SyncMessage(ConnectionHandler connection) throws Exception {

@Override
protected void sendPayload() throws Exception {
new ReadyResponse(this.outputStream, Status.IDLE).send();
boolean inTransaction =
connection.getJdbcConnection().unwrap(CloudSpannerJdbcConnection.class).isInTransaction();
new ReadyResponse(
this.outputStream, inTransaction ? Status.TRANSACTION : ReadyResponse.Status.IDLE)
.send();
}

@Override
Expand Down
126 changes: 126 additions & 0 deletions src/test/java/com/google/cloud/spanner/pgadapter/ProtocolTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
package com.google.cloud.spanner.pgadapter;

import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import com.google.cloud.spanner.jdbc.CloudSpannerJdbcConnection;
import com.google.cloud.spanner.jdbc.JdbcConstants;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode;
import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata;
import com.google.cloud.spanner.pgadapter.metadata.DescribePortalMetadata;
Expand Down Expand Up @@ -955,6 +960,9 @@ public void testSyncMessage() throws Exception {
ByteArrayOutputStream result = new ByteArrayOutputStream();
DataOutputStream outputStream = new DataOutputStream(result);

Mockito.when(connectionHandler.getJdbcConnection()).thenReturn(connection);
Mockito.when(connection.unwrap(CloudSpannerJdbcConnection.class))
.thenReturn(mock(CloudSpannerJdbcConnection.class));
Mockito.when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata);
Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream);
Mockito.when(connectionMetadata.getOutputStream()).thenReturn(outputStream);
Expand All @@ -971,6 +979,38 @@ public void testSyncMessage() throws Exception {
Assert.assertEquals(outputResult.readByte(), 'I');
}

@Test
public void testSyncMessageInTransaction() throws Exception {
byte[] messageMetadata = {'S'};

byte[] length = intToBytes(4);

byte[] value = Bytes.concat(messageMetadata, length);

DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value));
ByteArrayOutputStream result = new ByteArrayOutputStream();
DataOutputStream outputStream = new DataOutputStream(result);

CloudSpannerJdbcConnection cloudSpannerConnection = mock(CloudSpannerJdbcConnection.class);
when(cloudSpannerConnection.isInTransaction()).thenReturn(true);
when(connectionHandler.getJdbcConnection()).thenReturn(connection);
when(connection.unwrap(CloudSpannerJdbcConnection.class)).thenReturn(cloudSpannerConnection);
when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata);
when(connectionMetadata.getInputStream()).thenReturn(inputStream);
when(connectionMetadata.getOutputStream()).thenReturn(outputStream);

WireMessage message = ControlMessage.create(connectionHandler);
assertEquals(SyncMessage.class, message.getClass());

message.send();

// ReadyResponse
DataInputStream outputResult = inputStreamFromOutputStream(result);
assertEquals('Z', outputResult.readByte());
assertEquals(5, outputResult.readInt());
assertEquals('T', outputResult.readByte());
}

@Test
public void testFlushMessage() throws Exception {
byte[] messageMetadata = {'H'};
Expand All @@ -983,6 +1023,9 @@ public void testFlushMessage() throws Exception {
ByteArrayOutputStream result = new ByteArrayOutputStream();
DataOutputStream outputStream = new DataOutputStream(result);

Mockito.when(connectionHandler.getJdbcConnection()).thenReturn(connection);
Mockito.when(connection.unwrap(CloudSpannerJdbcConnection.class))
.thenReturn(mock(CloudSpannerJdbcConnection.class));
Mockito.when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata);
Mockito.when(connectionMetadata.getInputStream()).thenReturn(inputStream);
Mockito.when(connectionMetadata.getOutputStream()).thenReturn(outputStream);
Expand All @@ -999,6 +1042,89 @@ public void testFlushMessage() throws Exception {
Assert.assertEquals(outputResult.readByte(), 'I');
}

@Test
public void testFlushMessageInTransaction() throws Exception {
byte[] messageMetadata = {'H'};

byte[] length = intToBytes(4);

byte[] value = Bytes.concat(messageMetadata, length);

DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value));
ByteArrayOutputStream result = new ByteArrayOutputStream();
DataOutputStream outputStream = new DataOutputStream(result);

CloudSpannerJdbcConnection cloudSpannerConnection = mock(CloudSpannerJdbcConnection.class);
when(cloudSpannerConnection.isInTransaction()).thenReturn(true);
when(connectionHandler.getJdbcConnection()).thenReturn(connection);
when(connection.unwrap(CloudSpannerJdbcConnection.class)).thenReturn(cloudSpannerConnection);
when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata);
when(connectionMetadata.getInputStream()).thenReturn(inputStream);
when(connectionMetadata.getOutputStream()).thenReturn(outputStream);

WireMessage message = ControlMessage.create(connectionHandler);
assertEquals(FlushMessage.class, message.getClass());

message.send();

// ReadyResponse
DataInputStream outputResult = inputStreamFromOutputStream(result);
assertEquals('Z', outputResult.readByte());
assertEquals(5, outputResult.readInt());
assertEquals('T', outputResult.readByte());
}

@Test
public void testQueryMessageInTransaction() throws Exception {
byte[] messageMetadata = {'Q', 0, 0, 0, 24};
String payload = "SELECT * FROM users\0";
byte[] value = Bytes.concat(messageMetadata, payload.getBytes());

DataInputStream inputStream = new DataInputStream(new ByteArrayInputStream(value));
ByteArrayOutputStream result = new ByteArrayOutputStream();
DataOutputStream outputStream = new DataOutputStream(result);

String expectedSQL = "SELECT * FROM users";

CloudSpannerJdbcConnection cloudSpannerConnection = mock(CloudSpannerJdbcConnection.class);
when(cloudSpannerConnection.isInTransaction()).thenReturn(true);
when(connectionHandler.getJdbcConnection()).thenReturn(connection);
when(connection.unwrap(CloudSpannerJdbcConnection.class)).thenReturn(cloudSpannerConnection);
when(connection.createStatement()).thenReturn(statement);
// TODO: Remove the following mock result, as this is wrong, but is currently needed because of
// a bug in the way that PgAdapter determines the result type of a statement. That is fixed in
// https://github.com/GoogleCloudPlatform/pgadapter/pull/23
when(statement.getUpdateCount()).thenReturn(JdbcConstants.STATEMENT_NO_RESULT);
when(connectionHandler.getConnectionMetadata()).thenReturn(connectionMetadata);
when(connectionHandler.getServer()).thenReturn(server);
when(server.getOptions()).thenReturn(mock(OptionsMetadata.class));
when(connectionMetadata.getInputStream()).thenReturn(inputStream);
when(connectionMetadata.getOutputStream()).thenReturn(outputStream);

WireMessage message = ControlMessage.create(connectionHandler);
assertEquals(QueryMessage.class, message.getClass());
assertEquals(expectedSQL, ((QueryMessage) message).getStatement().getSql());

message.send();

// NoData response (query does not return any results).
DataInputStream outputResult = inputStreamFromOutputStream(result);
assertEquals('C', outputResult.readByte()); // CommandComplete
assertEquals('\0', outputResult.readByte());
assertEquals('\0', outputResult.readByte());
assertEquals('\0', outputResult.readByte());
// 11 = 4 + "SELECT".length() + 1 (header + command length + null terminator)
assertEquals(11, outputResult.readByte());
byte[] command = new byte[6];
assertEquals(6, outputResult.read(command, 0, 6));
assertEquals("SELECT", new String(command));
assertEquals('\0', outputResult.readByte());
// ReadyResponse in transaction ('T')
assertEquals('Z', outputResult.readByte());
assertEquals(5, outputResult.readInt());
assertEquals('T', outputResult.readByte());
}

@Test
public void testTerminateMessage() throws Exception {
byte[] messageMetadata = {'X'};
Expand Down

0 comments on commit 69c4017

Please sign in to comment.