Skip to content

Commit

Permalink
fix: synchronize lazy ResultSet decoding (#3267)
Browse files Browse the repository at this point in the history
Using one of the options DecodeMode.LAZY_PER_ROW or DecodeMode.LAZY_PER_COL
in combination with multi-threaded access to the ResultSet could lead to
ClassCastExceptions, as the underlying decode methods were not synchronized.
Multiple threads could access the row data at the same time, with one thread
expecting the raw proto data and another expecting the decoded data, and
receive the other type of data.
  • Loading branch information
olavloite authored and lqiu96 committed Sep 10, 2024
1 parent eac0c71 commit 174497e
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.spanner.v1.ResultSetMetadata;
import com.google.spanner.v1.ResultSetStats;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;

Expand All @@ -35,6 +36,7 @@ class GrpcResultSet extends AbstractResultSet<List<Object>> implements ProtobufR
private final DecodeMode decodeMode;
private ResultSetMetadata metadata;
private GrpcStruct currRow;
private List<Object> rowData;
private SpannerException error;
private ResultSetStats statistics;
private boolean closed;
Expand Down Expand Up @@ -85,7 +87,15 @@ public boolean next() throws SpannerException {
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.FAILED_PRECONDITION, AbstractReadContext.NO_TRANSACTION_RETURNED_MSG);
}
currRow = new GrpcStruct(iterator.type(), new ArrayList<>(), decodeMode);
if (rowData == null) {
rowData = new ArrayList<>(metadata.getRowType().getFieldsCount());
if (decodeMode != DecodeMode.DIRECT) {
rowData = Collections.synchronizedList(rowData);
}
} else {
rowData.clear();
}
currRow = new GrpcStruct(iterator.type(), rowData, decodeMode);
}
boolean hasNext = currRow.consumeRow(iterator);
if (!hasNext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand All @@ -60,7 +61,7 @@ class GrpcStruct extends Struct implements Serializable {
private final List<Object> rowData;
private final DecodeMode decodeMode;
private final BitSet colDecoded;
private boolean rowDecoded;
private final AtomicBoolean rowDecoded;

/**
* Builds an immutable version of this struct using {@link Struct#newBuilder()} which is used as a
Expand Down Expand Up @@ -224,7 +225,7 @@ private GrpcStruct(
this.type = type;
this.rowData = rowData;
this.decodeMode = decodeMode;
this.rowDecoded = rowDecoded;
this.rowDecoded = new AtomicBoolean(rowDecoded);
this.colDecoded = colDecoded;
}

Expand All @@ -234,29 +235,31 @@ public String toString() {
}

boolean consumeRow(Iterator<com.google.protobuf.Value> iterator) {
rowData.clear();
if (decodeMode == DecodeMode.LAZY_PER_ROW) {
rowDecoded = false;
} else if (decodeMode == DecodeMode.LAZY_PER_COL) {
colDecoded.clear();
}
if (!iterator.hasNext()) {
return false;
}
for (Type.StructField fieldType : getType().getStructFields()) {
synchronized (rowData) {
rowData.clear();
if (decodeMode == DecodeMode.LAZY_PER_ROW) {
rowDecoded.set(false);
} else if (decodeMode == DecodeMode.LAZY_PER_COL) {
colDecoded.clear();
}
if (!iterator.hasNext()) {
throw newSpannerException(
ErrorCode.INTERNAL,
"Invalid value stream: end of stream reached before row is complete");
return false;
}
com.google.protobuf.Value value = iterator.next();
if (decodeMode == DecodeMode.DIRECT) {
rowData.add(decodeValue(fieldType.getType(), value));
} else {
rowData.add(value);
for (Type.StructField fieldType : getType().getStructFields()) {
if (!iterator.hasNext()) {
throw newSpannerException(
ErrorCode.INTERNAL,
"Invalid value stream: end of stream reached before row is complete");
}
com.google.protobuf.Value value = iterator.next();
if (decodeMode == DecodeMode.DIRECT) {
rowData.add(decodeValue(fieldType.getType(), value));
} else {
rowData.add(value);
}
}
return true;
}
return true;
}

private static Object decodeValue(Type fieldType, com.google.protobuf.Value proto) {
Expand Down Expand Up @@ -367,12 +370,16 @@ private static void checkType(
}

Struct immutableCopy() {
return new GrpcStruct(
type,
new ArrayList<>(rowData),
this.decodeMode,
this.rowDecoded,
this.colDecoded == null ? null : (BitSet) this.colDecoded.clone());
synchronized (rowData) {
return new GrpcStruct(
type,
this.decodeMode == DecodeMode.DIRECT
? new ArrayList<>(rowData)
: Collections.synchronizedList(new ArrayList<>(rowData)),
this.decodeMode,
this.rowDecoded.get(),
this.colDecoded == null ? null : (BitSet) this.colDecoded.clone());
}
}

@Override
Expand All @@ -382,9 +389,14 @@ public Type getType() {

@Override
public boolean isNull(int columnIndex) {
if ((decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded)
|| (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex))) {
return ((com.google.protobuf.Value) rowData.get(columnIndex)).hasNullValue();
if (decodeMode == DecodeMode.LAZY_PER_ROW || decodeMode == DecodeMode.LAZY_PER_COL) {
synchronized (rowData) {
if ((decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded.get())
|| (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex))) {
return ((com.google.protobuf.Value) rowData.get(columnIndex)).hasNullValue();
}
return rowData.get(columnIndex) == null;
}
}
return rowData.get(columnIndex) == null;
}
Expand Down Expand Up @@ -496,14 +508,18 @@ private boolean isUnrecognizedType(int columnIndex) {
}

boolean canGetProtoValue(int columnIndex) {
return isUnrecognizedType(columnIndex)
|| (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded)
|| (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex));
synchronized (rowData) {
return isUnrecognizedType(columnIndex)
|| (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded.get())
|| (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex));
}
}

protected com.google.protobuf.Value getProtoValueInternal(int columnIndex) {
checkProtoValueSupported(columnIndex);
return (com.google.protobuf.Value) rowData.get(columnIndex);
synchronized (rowData) {
checkProtoValueSupported(columnIndex);
return (com.google.protobuf.Value) rowData.get(columnIndex);
}
}

private void checkProtoValueSupported(int columnIndex) {
Expand All @@ -515,30 +531,56 @@ private void checkProtoValueSupported(int columnIndex) {
decodeMode != DecodeMode.DIRECT,
"Getting proto value is not supported when DecodeMode#DIRECT is used.");
Preconditions.checkState(
!(decodeMode == DecodeMode.LAZY_PER_ROW && rowDecoded),
!(decodeMode == DecodeMode.LAZY_PER_ROW && rowDecoded.get()),
"Getting proto value after the row has been decoded is not supported.");
Preconditions.checkState(
!(decodeMode == DecodeMode.LAZY_PER_COL && colDecoded.get(columnIndex)),
"Getting proto value after the column has been decoded is not supported.");
}

private void ensureDecoded(int columnIndex) {
if (decodeMode == DecodeMode.LAZY_PER_ROW && !rowDecoded) {
for (int i = 0; i < rowData.size(); i++) {
rowData.set(
i,
decodeValue(
type.getStructFields().get(i).getType(),
(com.google.protobuf.Value) rowData.get(i)));
if (decodeMode == DecodeMode.LAZY_PER_ROW) {
synchronized (rowData) {
if (!rowDecoded.get()) {
for (int i = 0; i < rowData.size(); i++) {
rowData.set(
i,
decodeValue(
type.getStructFields().get(i).getType(),
(com.google.protobuf.Value) rowData.get(i)));
}
}
rowDecoded.set(true);
}
} else if (decodeMode == DecodeMode.LAZY_PER_COL) {
boolean decoded;
Object value;
synchronized (rowData) {
decoded = colDecoded.get(columnIndex);
value = rowData.get(columnIndex);
}
if (!decoded) {
// Use the column as a lock during decoding to ensure that we decode once (mostly), but also
// that multiple different columns can be decoded in parallel if requested.
synchronized (type.getStructFields().get(columnIndex)) {
// Note: It can be that we decode the value twice if two threads request this at the same
// time, but the synchronization on rowData above and below makes sure that we always get
// and set a consistent value (and only set it once).
if (!colDecoded.get(columnIndex)) {
value =
decodeValue(
type.getStructFields().get(columnIndex).getType(),
(com.google.protobuf.Value) value);
decoded = true;
}
}
if (decoded) {
synchronized (rowData) {
rowData.set(columnIndex, value);
colDecoded.set(columnIndex);
}
}
}
rowDecoded = true;
} else if (decodeMode == DecodeMode.LAZY_PER_COL && !colDecoded.get(columnIndex)) {
rowData.set(
columnIndex,
decodeValue(
type.getStructFields().get(columnIndex).getType(),
(com.google.protobuf.Value) rowData.get(columnIndex)));
colDecoded.set(columnIndex);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@
import com.google.cloud.spanner.ResultSet;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.After;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -41,7 +47,7 @@ public void clearRequests() {
}

@Test
public void testAllDecodeModes() {
public void testAllDecodeModes() throws Exception {
int numRows = 10;
RandomResultSetGenerator generator = new RandomResultSetGenerator(numRows);
String sql = "select * from random";
Expand All @@ -50,57 +56,85 @@ public void testAllDecodeModes() {
MockSpannerServiceImpl.StatementResult.query(statement, generator.generate()));

try (Connection connection = createConnection()) {
for (boolean readonly : new boolean[] {true, false}) {
for (boolean autocommit : new boolean[] {true, false}) {
connection.setReadOnly(readonly);
connection.setAutocommit(autocommit);
for (boolean multiThreaded : new boolean[] {true, false}) {
for (boolean readonly : new boolean[] {true, false}) {
for (boolean autocommit : new boolean[] {true, false}) {
connection.setReadOnly(readonly);
connection.setAutocommit(autocommit);

int receivedRows = 0;
// DecodeMode#DIRECT is not supported in read/write transactions, as the protobuf value is
// used for checksum calculation.
try (ResultSet direct =
connection.executeQuery(
statement,
!readonly && !autocommit
? Options.decodeMode(DecodeMode.LAZY_PER_ROW)
: Options.decodeMode(DecodeMode.DIRECT));
ResultSet lazyPerRow =
connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_ROW));
ResultSet lazyPerCol =
connection.executeQuery(statement, Options.decodeMode(DecodeMode.LAZY_PER_COL))) {
while (direct.next() && lazyPerRow.next() && lazyPerCol.next()) {
assertEquals(direct.getColumnCount(), lazyPerRow.getColumnCount());
assertEquals(direct.getColumnCount(), lazyPerCol.getColumnCount());
for (int col = 0; col < direct.getColumnCount(); col++) {
// Test getting the entire row as a struct both as the first thing we do, and as the
// last thing we do. This ensures that the method works as expected both when a row
// is lazily decoded by this method, and when it has already been decoded by another
// method.
if (col % 2 == 0) {
assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct());
assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct());
}
assertEquals(direct.isNull(col), lazyPerRow.isNull(col));
assertEquals(direct.isNull(col), lazyPerCol.isNull(col));
assertEquals(direct.getValue(col), lazyPerRow.getValue(col));
assertEquals(direct.getValue(col), lazyPerCol.getValue(col));
if (col % 2 == 1) {
assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct());
assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct());
int receivedRows = 0;
// DecodeMode#DIRECT is not supported in read/write transactions, as the protobuf
// value is used for checksum calculation.
try (ResultSet direct =
connection.executeQuery(
statement,
!readonly && !autocommit
? Options.decodeMode(DecodeMode.LAZY_PER_ROW)
: Options.decodeMode(DecodeMode.DIRECT));
ResultSet lazyPerRow =
connection.executeQuery(
statement, Options.decodeMode(DecodeMode.LAZY_PER_ROW));
ResultSet lazyPerCol =
connection.executeQuery(
statement, Options.decodeMode(DecodeMode.LAZY_PER_COL))) {
while (direct.next() && lazyPerRow.next() && lazyPerCol.next()) {
assertEquals(direct.getColumnCount(), lazyPerRow.getColumnCount());
assertEquals(direct.getColumnCount(), lazyPerCol.getColumnCount());
if (multiThreaded) {
ExecutorService service = Executors.newFixedThreadPool(direct.getColumnCount());
List<Future<?>> futures = new ArrayList<>(direct.getColumnCount());
for (int col = 0; col < direct.getColumnCount(); col++) {
final int colNumber = col;
futures.add(
service.submit(
() -> checkRowValues(colNumber, direct, lazyPerRow, lazyPerCol)));
}
service.shutdown();
for (Future<?> future : futures) {
future.get();
}
} else {
for (int col = 0; col < direct.getColumnCount(); col++) {
checkRowValues(col, direct, lazyPerRow, lazyPerCol);
}
}
receivedRows++;
}
receivedRows++;
assertEquals(numRows, receivedRows);
}
if (!autocommit) {
connection.commit();
}
assertEquals(numRows, receivedRows);
}
if (!autocommit) {
connection.commit();
}
}
}
}
}

/**
 * Asserts that, for the current row, the two lazily decoded result sets return the same data as
 * {@code direct} for column {@code col}.
 *
 * <p>The whole-row struct comparison runs before the per-column checks for even columns and
 * after them for odd columns, so that fetching the row as a struct is exercised both when it is
 * the first access to the row and when the column has already been read through another method.
 *
 * @param col index of the column to compare
 * @param direct result set whose values are used as the expected values
 * @param lazyPerRow result set that is compared against {@code direct}
 * @param lazyPerCol result set that is compared against {@code direct}
 */
private void checkRowValues(
    int col, ResultSet direct, ResultSet lazyPerRow, ResultSet lazyPerCol) {
  // Read a randomly chosen column first, so that concurrent callers may end up decoding the
  // same column in parallel.
  int randomColumn = ThreadLocalRandom.current().nextInt(lazyPerCol.getColumnCount());
  lazyPerCol.getValue(randomColumn);

  if (col % 2 == 0) {
    // Even column: fetch the row as a struct before touching the individual column.
    assertCurrentRowStructsMatch(direct, lazyPerRow, lazyPerCol);
  }

  assertEquals(direct.isNull(col), lazyPerRow.isNull(col));
  assertEquals(direct.isNull(col), lazyPerCol.isNull(col));
  assertEquals(direct.getValue(col), lazyPerRow.getValue(col));
  assertEquals(direct.getValue(col), lazyPerCol.getValue(col));

  if (col % 2 == 1) {
    // Odd column: fetch the row as a struct after the individual column has been read.
    assertCurrentRowStructsMatch(direct, lazyPerRow, lazyPerCol);
  }
}

/** Asserts that the current row of both lazy result sets equals the current row of direct. */
private static void assertCurrentRowStructsMatch(
    ResultSet direct, ResultSet lazyPerRow, ResultSet lazyPerCol) {
  assertEquals(direct.getCurrentRowAsStruct(), lazyPerRow.getCurrentRowAsStruct());
  assertEquals(direct.getCurrentRowAsStruct(), lazyPerCol.getCurrentRowAsStruct());
}

@Test
public void testDecodeModeDirect_failsInReadWriteTransaction() {
int numRows = 1;
Expand Down

0 comments on commit 174497e

Please sign in to comment.