[SPARK-31425][SQL][CORE] UnsafeKVExternalSorter/VariableLengthRowBasedKeyValueBatch should also respect UnsafeAlignedOffset

### What changes were proposed in this pull request?

Make `UnsafeKVExternalSorter` / `VariableLengthRowBasedKeyValueBatch` also respect `UnsafeAlignedOffset` when reading records, and update some out-of-date comments.

### Why are the changes needed?

Since `BytesToBytesMap` respects `UnsafeAlignedOffset` when writing a record, `UnsafeKVExternalSorter` should also respect `UnsafeAlignedOffset` when reading the record back from `BytesToBytesMap`; otherwise it will cause data correctness issues.

Unlike `UnsafeKVExternalSorter`, which may read records written by `BytesToBytesMap`, `VariableLengthRowBasedKeyValueBatch` writes and reads its records itself. Thus, similar to #22053 and the [comment](#22053 (comment)) there, the fix for `VariableLengthRowBasedKeyValueBatch` is more of an improvement for SPARC platform support.
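
To make the failure mode concrete, here is a minimal, self-contained sketch (plain Java, no Spark classes; the buffer, payload, and offsets are made up for illustration) of how a record written with an 8-byte length prefix is misread by code that assumes a 4-byte prefix:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Hypothetical illustration only: a "page" whose writer stores the record length in an
// 8-byte word (UAO size 8, as on platforms without unaligned access) followed by the
// payload, and a reader that wrongly assumes a 4-byte length word.
public class UaoMismatchSketch {
  public static void main(String[] args) {
    ByteBuffer page = ByteBuffer.allocate(64).order(ByteOrder.LITTLE_ENDIAN);
    byte[] payload = {1, 2, 3, 4};

    // Writer: 8-byte length word, then the payload.
    page.putLong(0, payload.length);
    for (int i = 0; i < payload.length; i++) {
      page.put(8 + i, payload[i]);
    }

    // Correct reader skips 8 bytes; buggy reader skips only 4 and lands inside the
    // length word, so every byte it returns is wrong.
    byte[] ok = new byte[payload.length];
    byte[] bad = new byte[payload.length];
    for (int i = 0; i < payload.length; i++) {
      ok[i] = page.get(8 + i);
      bad[i] = page.get(4 + i);
    }

    System.out.println(java.util.Arrays.toString(ok));   // [1, 2, 3, 4]
    System.out.println(java.util.Arrays.toString(bad));  // [0, 0, 0, 0]
  }
}
```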

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Manually tested `HashAggregationQueryWithControlledFallbackSuite` with `UAO_SIZE=8` to simulate the SPARC platform. The tests only pass with this fix.

Closes #28195 from Ngone51/fix_uao.

Authored-by: yi.wu <yi.wu@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 40f9dbb)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Ngone51 authored and cloud-fan committed Apr 17, 2020
1 parent e7fef70 commit 33d25ba
Showing 6 changed files with 78 additions and 56 deletions.
UnsafeAlignedOffset.java
@@ -28,12 +28,20 @@ public class UnsafeAlignedOffset {

private static final int UAO_SIZE = Platform.unaligned() ? 4 : 8;

private static int TEST_UAO_SIZE = 0;

// used for test only
public static void setUaoSize(int size) {
assert size == 0 || size == 4 || size == 8;
TEST_UAO_SIZE = size;
}

public static int getUaoSize() {
return UAO_SIZE;
return TEST_UAO_SIZE == 0 ? UAO_SIZE : TEST_UAO_SIZE;
}

public static int getSize(Object object, long offset) {
switch (UAO_SIZE) {
switch (getUaoSize()) {
case 4:
return Platform.getInt(object, offset);
case 8:
@@ -46,7 +54,7 @@ public static int getSize(Object object, long offset) {
}

public static void putSize(Object object, long offset, int value) {
switch (UAO_SIZE) {
switch (getUaoSize()) {
case 4:
Platform.putInt(object, offset, value);
break;
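
For reference, a sketch of how this new test-only hook is intended to be used (based on the diff above; the surrounding test body is hypothetical and assumes the Spark `unsafe` module is on the classpath):

```java
import org.apache.spark.unsafe.UnsafeAlignedOffset;

// Hypothetical test scaffold: force the 8-byte UAO size to simulate a platform without
// unaligned access (e.g. SPARC), then reset so other tests see the real platform value.
public class UaoOverrideExample {
  public static void main(String[] args) {
    UnsafeAlignedOffset.setUaoSize(8);
    try {
      assert UnsafeAlignedOffset.getUaoSize() == 8;
      // ... exercise code that reads/writes length words via getSize/putSize ...
    } finally {
      UnsafeAlignedOffset.setUaoSize(0);   // 0 means "use the platform default"
    }
  }
}
```
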
BytesToBytesMap.java
@@ -54,13 +54,13 @@
* probably be using sorting instead of hashing for better cache locality.
*
* The key and values under the hood are stored together, in the following format:
* Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
* Bytes 4 to 8: len(k)
* Bytes 8 to 8 + len(k): key data
* Bytes 8 + len(k) to 8 + len(k) + len(v): value data
* Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair
* First uaoSize bytes: len(k) (key length in bytes) + len(v) (value length in bytes) + uaoSize
* Next uaoSize bytes: len(k)
* Next len(k) bytes: key data
* Next len(v) bytes: value data
* Last 8 bytes: pointer to next pair
*
* This means that the first four bytes store the entire record (key + value) length. This format
* It means first uaoSize bytes store the entire record (key + value + uaoSize) length. This format
* is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
* so we can pass records from this map directly into the sorter to sort records in place.
*/
@@ -707,7 +707,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
// (8 byte key length) (key) (value) (8 byte pointer to next value)
// (total length) (key length) (key) (value) (8 byte pointer to next value)
int uaoSize = UnsafeAlignedOffset.getUaoSize();
final long recordLength = (2L * uaoSize) + klen + vlen + 8;
if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
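
To make the layout above easier to follow, here is a small, self-contained sketch (plain Java, no Spark classes; the key/value lengths are arbitrary) that works out the field offsets and the total record length for both UAO sizes:

```java
// Illustrative only: field offsets of one BytesToBytesMap-style record,
// [totalLen][keyLen][key][value][8-byte next pointer], for a given UAO size.
public class RecordLayoutSketch {
  static void printLayout(int uaoSize, int klen, int vlen) {
    long totalLenField = 0;                        // stores klen + vlen + uaoSize
    long keyLenField   = totalLenField + uaoSize;  // stores klen
    long keyData       = keyLenField + uaoSize;
    long valueData     = keyData + klen;
    long nextPointer   = valueData + vlen;
    long recordLength  = nextPointer + 8;          // == 2 * uaoSize + klen + vlen + 8

    System.out.printf("uaoSize=%d: key@%d value@%d next@%d recordLength=%d%n",
        uaoSize, keyData, valueData, nextPointer, recordLength);
  }

  public static void main(String[] args) {
    printLayout(4, 16, 24);   // platforms with unaligned access (e.g. x86): recordLength=56
    printLayout(8, 16, 24);   // aligned-only platforms (e.g. SPARC):        recordLength=64
  }
}
```
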
UnsafeInMemorySorter.java
@@ -235,7 +235,7 @@ public void expandPointerArray(LongArray newArray) {

/**
* Inserts a record to be sorted. Assumes that the record pointer points to a record length
* stored as a 4-byte integer, followed by the record's bytes.
* stored as a uaoSize(4 or 8) bytes integer, followed by the record's bytes.
*
* @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
* @param keyPrefix a user-defined key prefix
VariableLengthRowBasedKeyValueBatch.java
@@ -19,11 +19,12 @@
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;

/**
* An implementation of `RowBasedKeyValueBatch` in which key-value records have variable lengths.
*
* The format for each record looks like this:
* The format for each record looks like this (in case of uaoSize = 4):
* [4 bytes total size = (klen + vlen + 4)] [4 bytes key size = klen]
* [UnsafeRow for key of length klen] [UnsafeRow for Value of length vlen]
* [8 bytes pointer to next]
@@ -41,18 +42,19 @@ public final class VariableLengthRowBasedKeyValueBatch extends RowBasedKeyValueB
@Override
public UnsafeRow appendRow(Object kbase, long koff, int klen,
Object vbase, long voff, int vlen) {
final long recordLength = 8L + klen + vlen + 8;
int uaoSize = UnsafeAlignedOffset.getUaoSize();
final long recordLength = 2 * uaoSize + klen + vlen + 8L;
// if run out of max supported rows or page size, return null
if (numRows >= capacity || page == null || page.size() - pageCursor < recordLength) {
return null;
}

long offset = page.getBaseOffset() + pageCursor;
final long recordOffset = offset;
Platform.putInt(base, offset, klen + vlen + 4);
Platform.putInt(base, offset + 4, klen);
UnsafeAlignedOffset.putSize(base, offset, klen + vlen + uaoSize);
UnsafeAlignedOffset.putSize(base, offset + uaoSize, klen);

offset += 8;
offset += 2 * uaoSize;
Platform.copyMemory(kbase, koff, base, offset, klen);
offset += klen;
Platform.copyMemory(vbase, voff, base, offset, vlen);
@@ -61,11 +63,11 @@ public UnsafeRow appendRow(Object kbase, long koff, int klen,

pageCursor += recordLength;

keyOffsets[numRows] = recordOffset + 8;
keyOffsets[numRows] = recordOffset + 2 * uaoSize;

keyRowId = numRows;
keyRow.pointTo(base, recordOffset + 8, klen);
valueRow.pointTo(base, recordOffset + 8 + klen, vlen);
keyRow.pointTo(base, recordOffset + 2 * uaoSize, klen);
valueRow.pointTo(base, recordOffset + 2 * uaoSize + klen, vlen);
numRows++;
return valueRow;
}
@@ -79,7 +81,7 @@ public UnsafeRow getKeyRow(int rowId) {
assert(rowId < numRows);
if (keyRowId != rowId) { // if keyRowId == rowId, desired keyRow is already cached
long offset = keyOffsets[rowId];
int klen = Platform.getInt(base, offset - 4);
int klen = UnsafeAlignedOffset.getSize(base, offset - UnsafeAlignedOffset.getUaoSize());
keyRow.pointTo(base, offset, klen);
// set keyRowId so we can check if desired row is cached
keyRowId = rowId;
@@ -99,9 +101,10 @@ public UnsafeRow getValueFromKey(int rowId) {
getKeyRow(rowId);
}
assert(rowId >= 0);
int uaoSize = UnsafeAlignedOffset.getUaoSize();
long offset = keyRow.getBaseOffset();
int klen = keyRow.getSizeInBytes();
int vlen = Platform.getInt(base, offset - 8) - klen - 4;
int vlen = UnsafeAlignedOffset.getSize(base, offset - uaoSize * 2) - klen - uaoSize;
valueRow.pointTo(base, offset + klen, vlen);
return valueRow;
}
@@ -141,14 +144,15 @@ public boolean next() {
return false;
}

totalLength = Platform.getInt(base, offsetInPage) - 4;
currentklen = Platform.getInt(base, offsetInPage + 4);
int uaoSize = UnsafeAlignedOffset.getUaoSize();
totalLength = UnsafeAlignedOffset.getSize(base, offsetInPage) - uaoSize;
currentklen = UnsafeAlignedOffset.getSize(base, offsetInPage + uaoSize);
currentvlen = totalLength - currentklen;

key.pointTo(base, offsetInPage + 8, currentklen);
value.pointTo(base, offsetInPage + 8 + currentklen, currentvlen);
key.pointTo(base, offsetInPage + 2 * uaoSize, currentklen);
value.pointTo(base, offsetInPage + 2 * uaoSize + currentklen, currentvlen);

offsetInPage += 8 + totalLength + 8;
offsetInPage += 2 * uaoSize + totalLength + 8;
recordsInPage -= 1;
return true;
}
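
The read-side arithmetic above (key length one UAO word before the key bytes, total length two UAO words before them) can be sanity-checked with a plain-Java sketch that mimics `putSize`/`getSize` on a `ByteBuffer` (no Spark classes; the lengths and UAO size are made up):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Illustrative only: write one [totalLen][keyLen][key][value] record with 8-byte length
// words, then recover klen and vlen the way getKeyRow/getValueFromKey do.
public class ReadBackSketch {
  public static void main(String[] args) {
    int uaoSize = 8;                               // simulate an aligned-only platform
    byte[] key = {10, 11, 12};
    byte[] value = {20, 21, 22, 23, 24};

    ByteBuffer page = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder());
    putSize(page, 0, key.length + value.length + uaoSize, uaoSize);   // totalLen word
    putSize(page, uaoSize, key.length, uaoSize);                      // keyLen word
    page.position(2 * uaoSize);
    page.put(key).put(value);                                         // next pointer omitted

    // keyOffset points at the first key byte, as keyOffsets[rowId] does in the batch.
    int keyOffset = 2 * uaoSize;
    int klen = getSize(page, keyOffset - uaoSize, uaoSize);
    int vlen = getSize(page, keyOffset - 2 * uaoSize, uaoSize) - klen - uaoSize;
    System.out.println("klen=" + klen + " vlen=" + vlen);             // klen=3 vlen=5
  }

  static void putSize(ByteBuffer b, int off, int v, int uaoSize) {
    if (uaoSize == 4) b.putInt(off, v); else b.putLong(off, v);
  }

  static int getSize(ByteBuffer b, int off, int uaoSize) {
    return uaoSize == 4 ? b.getInt(off) : (int) b.getLong(off);
  }
}
```
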
UnsafeKVExternalSorter.java
@@ -35,6 +35,7 @@
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryBlock;
@@ -141,9 +142,10 @@ public UnsafeKVExternalSorter(

// Get encoded memory address
// baseObject + baseOffset point to the beginning of the key data in the map, but that
// the KV-pair's length data is stored in the word immediately before that address
// the KV-pair's length data is stored at 2 * uaoSize bytes immediately before that address
MemoryBlock page = loc.getMemoryPage();
long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
long address = taskMemoryManager.encodePageNumberAndOffset(page,
baseOffset - 2 * UnsafeAlignedOffset.getUaoSize());

// Compute prefix
row.pointTo(baseObject, baseOffset, loc.getKeyLength());
@@ -262,10 +264,11 @@ public int compare(
Object baseObj2,
long baseOff2,
int baseLen2) {
int uaoSize = UnsafeAlignedOffset.getUaoSize();
// Note that since ordering doesn't need the total length of the record, we just pass 0
// into the row.
row1.pointTo(baseObj1, baseOff1 + 4, 0);
row2.pointTo(baseObj2, baseOff2 + 4, 0);
row1.pointTo(baseObj1, baseOff1 + uaoSize, 0);
row2.pointTo(baseObj2, baseOff2 + uaoSize, 0);
return ordering.compare(row1, row2);
}
}
@@ -289,11 +292,12 @@ public boolean next() throws IOException {
long recordOffset = underlying.getBaseOffset();
int recordLen = underlying.getRecordLength();

// Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
// Note that recordLen = keyLen + valueLen + uaoSize (for the keyLen itself)
int uaoSize = UnsafeAlignedOffset.getUaoSize();
int keyLen = Platform.getInt(baseObj, recordOffset);
int valueLen = recordLen - keyLen - 4;
key.pointTo(baseObj, recordOffset + 4, keyLen);
value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen);
int valueLen = recordLen - keyLen - uaoSize;
key.pointTo(baseObj, recordOffset + uaoSize, keyLen);
value.pointTo(baseObj, recordOffset + uaoSize + keyLen, valueLen);

return true;
} else {
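
A rough sketch of the offset bookkeeping in this file (plain Java, no Spark classes; numbers are illustrative, following the comments in the diff above): the pointer handed to the sorter starts two UAO words before the key data, and `KVSorterIterator` later derives the value length as `recordLen - keyLen - uaoSize`:

```java
// Illustrative only: offsets for one record handed from BytesToBytesMap to the sorter
// on an aligned-only platform (uaoSize = 8); the record is assumed to start at offset 0.
public class SorterHandoffSketch {
  public static void main(String[] args) {
    int uaoSize = 8;
    int keyLen = 12, valueLen = 20;

    long baseOffset = 2 * uaoSize;                    // first key byte in the map page
    long sorterRecordPtr = baseOffset - 2 * uaoSize;  // what gets encoded for the sorter (0)

    // After sorting, the iterator sees recordLen = keyLen + valueLen + uaoSize and a
    // recordOffset that points at the keyLen word; the key bytes follow, then the value.
    int recordLen = keyLen + valueLen + uaoSize;
    long keyStart = uaoSize;                          // relative to recordOffset
    long valueStart = keyStart + keyLen;
    int derivedValueLen = recordLen - keyLen - uaoSize;

    System.out.printf("recordPtr=%d key@+%d value@+%d valueLen=%d%n",
        sorterRecordPtr, keyStart, valueStart, derivedValueLen);
  }
}
```
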
AggregationQuerySuite.scala
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UnsafeAlignedOffset


class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
@@ -1055,30 +1056,35 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu
Seq("true", "false").foreach { enableTwoLevelMaps =>
withSQLConf(SQLConf.ENABLE_TWOLEVEL_AGG_MAP.key ->
enableTwoLevelMaps) {
(1 to 3).foreach { fallbackStartsAt =>
withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
// Create a new df to make sure its physical operator picks up
// spark.sql.TungstenAggregate.testFallbackStartsAt.
// todo: remove it?
val newActual = Dataset.ofRows(spark, actual.logicalPlan)

QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
case Some(errorMessage) =>
val newErrorMessage =
s"""
|The following aggregation query failed when using HashAggregate with
|controlled fallback (it falls back to bytes to bytes map once it has processed
|${fallbackStartsAt - 1} input rows and to sort-based aggregation once it has
|processed $fallbackStartsAt input rows). The query is ${actual.queryExecution}
|
|$errorMessage
""".stripMargin

fail(newErrorMessage)
case None => // Success
Seq(4, 8).foreach { uaoSize =>
UnsafeAlignedOffset.setUaoSize(uaoSize)
(1 to 3).foreach { fallbackStartsAt =>
withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {
// Create a new df to make sure its physical operator picks up
// spark.sql.TungstenAggregate.testFallbackStartsAt.
// todo: remove it?
val newActual = Dataset.ofRows(spark, actual.logicalPlan)

QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
case Some(errorMessage) =>
val newErrorMessage =
s"""
|The following aggregation query failed when using HashAggregate with
|controlled fallback (it falls back to bytes to bytes map once it has
|processed ${fallbackStartsAt - 1} input rows and to sort-based aggregation
|once it has processed $fallbackStartsAt input rows).
|The query is ${actual.queryExecution}
|$errorMessage
""".stripMargin

fail(newErrorMessage)
case None => // Success
}
}
}
// reset static uaoSize to avoid affect other tests
UnsafeAlignedOffset.setUaoSize(0)
}
}
}
