
Commit

[NSE-68][Shuffle] Adaptive compression select in Shuffle. (#69)
* Choose the proper compression algorithm by comparing the compressed sizes of the first record batch in the first partition (sketched below)

* Add the data type check on the Java side

* Add a note for the available customized compression codecs
JkSelf authored Feb 14, 2021
1 parent d4d322a commit 2576a03
Showing 9 changed files with 173 additions and 32 deletions.
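
The commit message above describes the selection strategy: probe the first record batch of the first partition with both codecs and keep whichever compresses smaller. The following is a minimal Scala sketch of that decision rule only, not the committed code; probeCompressedSize is a hypothetical stand-in for the native setCompressType + split(..., firstRecordBatch = true) probe shown in the diff below.

// Simplified sketch of the adaptive codec selection. probeCompressedSize is a
// hypothetical stand-in for setCompressType + split(..., firstRecordBatch = true).
def chooseCodec(
    defaultCodec: String,                 // e.g. spark.io.compression.codec, "lz4"
    customizedCodec: String,              // e.g. "fastpfor"
    schemaHasIntColumn: Boolean,          // fastpfor targets integer data
    shuffleCompressEnabled: Boolean,      // spark.shuffle.compress
    probeCompressedSize: String => Long): String = {
  if (!shuffleCompressEnabled || defaultCodec == customizedCodec || !schemaHasIntColumn) {
    defaultCodec
  } else {
    val defaultSize = probeCompressedSize(defaultCodec)
    val customizedSize = probeCompressedSize(customizedCodec)
    // -1 means the native side could not compute a size; the committed writer logs an
    // error in that case and keeps the last codec that was set (the customized one).
    if (defaultSize != -1 && customizedSize != -1 && customizedSize > defaultSize) defaultCodec
    else customizedCodec
  }
}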
@@ -17,12 +17,17 @@

package com.intel.oap.vectorized;

import io.netty.buffer.ArrowBuf;
import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
@@ -36,11 +41,16 @@
* ArrowRecordBatches.
*/
public class ArrowCompressedStreamReader extends ArrowStreamReader {
private String compressType;

public ArrowCompressedStreamReader(InputStream in, BufferAllocator allocator) {
super(in, allocator);
}

public String GetCompressType() {
return compressType;
}

protected void initialize() throws IOException {
Schema originalSchema = readSchema();
List<Field> fields = new ArrayList<>();
@@ -60,6 +70,47 @@ protected void initialize() throws IOException {
this.dictionaries = Collections.unmodifiableMap(dictionaries);
}

/**
* Load the next ArrowRecordBatch to the vector schema root if available.
*
* @return true if a batch was read, false on EOS
* @throws IOException on error
*/
public boolean loadNextBatch() throws IOException {
prepareLoadNextBatch();
MessageResult result = messageReader.readNext();

// Reached EOS
if (result == null) {
return false;
}
// Get the compression type from customMetadata. Currently the customMetadata has only one entry.
compressType = result.getMessage().customMetadata(0).value();

if (result.getMessage().headerType() == MessageHeader.RecordBatch) {
ArrowBuf bodyBuffer = result.getBodyBuffer();

// For zero-length batches, need an empty buffer to deserialize the batch
if (bodyBuffer == null) {
bodyBuffer = allocator.getEmpty();
}

ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(result.getMessage(), bodyBuffer);
loadRecordBatch(batch);
checkDictionaries();
return true;
} else if (result.getMessage().headerType() == MessageHeader.DictionaryBatch) {
// If it's a dictionary message, read the dictionary out and continue reading until a record batch or EOS is reached.
ArrowDictionaryBatch dictionaryBatch = readDictionary(result);
loadDictionary(dictionaryBatch);
loadedDictionaryCount++;
return loadNextBatch();
} else {
throw new IOException("Expected RecordBatch or DictionaryBatch but header was " +
result.getMessage().headerType());
}
}

@Override
protected void loadRecordBatch(ArrowRecordBatch batch) {
try {
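On the read path, the deserializer (shown later in this diff) calls GetCompressType() to learn which codec each batch was written with. Below is a minimal, hypothetical read loop over this class; compressedStream and allocator are assumptions standing in for whatever the surrounding shuffle reader provides.

// Hypothetical read loop; compressedStream and allocator are assumptions,
// everything else comes from the class above.
import com.intel.oap.vectorized.ArrowCompressedStreamReader

val reader = new ArrowCompressedStreamReader(compressedStream, allocator)
try {
  while (reader.loadNextBatch()) {
    val codec = reader.GetCompressType()   // e.g. "lz4" or "fastpfor", taken from customMetadata
    val root = reader.getVectorSchemaRoot  // the batch is already loaded into the schema root
    // ... hand root and codec to the native decompress path ...
  }
} finally {
  reader.close()
}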
@@ -96,10 +96,19 @@ public native long nativeMake(
* @param numRows Rows per batch
* @param bufAddrs Addresses of buffers
* @param bufSizes Sizes of buffers
* @param firstRecordBatch whether this record batch is the first record batch in the first partition.
* @return the compressed size if firstRecordBatch is true, otherwise -1.
*/
public native void split(long splitterId, int numRows, long[] bufAddrs, long[] bufSizes)
public native long split(
long splitterId, int numRows, long[] bufAddrs, long[] bufSizes, boolean firstRecordBatch)
throws IOException;

/**
* Update the compression type of the native splitter.
*/
public native void setCompressType(long splitterId, String compressType);

/**
* Write the data remaining in the buffers held by the native splitter to each partition's
* temporary file, and stop splitting.
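Together with setCompressType, the new return value turns split into a cheap probe for the first record batch. Below is a hedged sketch of the probe-then-commit call sequence; jniWrapper, splitterId, numRows, bufAddrs, and bufSizes are assumed to already exist, as they do in ColumnarShuffleWriter further down in this diff.

// Hypothetical probe-then-commit sequence; all identifiers except the JNI
// method names are assumptions taken from the writer shown later in this diff.
jniWrapper.setCompressType(splitterId, "lz4")
val lz4Size = jniWrapper.split(splitterId, numRows, bufAddrs, bufSizes, true)       // probe only

jniWrapper.setCompressType(splitterId, "fastpfor")
val fastpforSize = jniWrapper.split(splitterId, numRows, bufAddrs, bufSizes, true)  // probe only

// Keep the codec that produced the smaller first batch.
if (lz4Size != -1 && fastpforSize != -1 && fastpforSize > lz4Size) {
  jniWrapper.setCompressType(splitterId, "lz4")
}
// Real split: with firstRecordBatch = false the rows are actually partitioned and cached.
jniWrapper.split(splitterId, numRows, bufAddrs, bufSizes, false)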
6 changes: 4 additions & 2 deletions core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala
@@ -66,8 +66,10 @@ class ColumnarPluginConfig(conf: SQLConf) {
// and the cached buffers will be spilled when reaching the maximum memory.
val columnarShufflePreferSpill: Boolean =
conf.getConfString("spark.oap.sql.columnar.shuffle.preferSpill", "true").toBoolean
val columnarShuffleUseCustomizedCompression: Boolean =
conf.getConfString("spark.oap.sql.columnar.shuffle.customizedCompression", "false").toBoolean

// The supported customized compression codecs are lz4 and fastpfor.
val columnarShuffleUseCustomizedCompressionCodec: String =
conf.getConfString("spark.oap.sql.columnar.shuffle.customizedCompression.codec", "lz4")
val isTesting: Boolean =
conf.getConfString("spark.oap.sql.columnar.testing", "false").toBoolean
val numaBindingInfo: ColumnarNumaBindingInfo = {
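Both keys referenced above are ordinary Spark SQL confs. A hypothetical session setup that nominates fastpfor as the customized candidate; only the spark.oap.* key comes from this plugin, the others are stock Spark settings, and the final codec is still chosen adaptively at shuffle time.

// Hypothetical SparkSession setup; the codec choice itself is made adaptively
// at shuffle time by comparing compressed sizes of the first record batch.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("adaptive-shuffle-compression")
  .config("spark.shuffle.compress", "true")                                           // compression on
  .config("spark.io.compression.codec", "lz4")                                        // default codec
  .config("spark.oap.sql.columnar.shuffle.customizedCompression.codec", "fastpfor")   // candidate codec
  .getOrCreate()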
@@ -456,6 +456,12 @@ object ConverterUtils extends Logging {
out.toByteArray
}

@throws[IOException]
def getSchemaFromBytesBuf(schema: Array[Byte]): Schema = {
val in: ByteArrayInputStream = new ByteArrayInputStream(schema)
MessageSerializer.deserializeSchema(new ReadChannel(Channels.newChannel(in)))
}

@throws[GandivaException]
def getExprListBytesBuf(exprs: List[ExpressionTree]): Array[Byte] = {
val builder: ExpressionList.Builder = GandivaTypes.ExpressionList.newBuilder
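getSchemaFromBytesBuf simply reverses Arrow's IPC schema serialization; the writer later uses it to recover the schema from dep.nativePartitioning.getSchema. A small round-trip sketch follows; the field names are arbitrary examples.

// Round-trip sketch: serialize an Arrow schema to bytes, then read it back
// through getSchemaFromBytesBuf. Field names here are arbitrary examples.
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels
import java.util.Arrays

import com.intel.oap.expression.ConverterUtils
import org.apache.arrow.vector.ipc.WriteChannel
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}

val schema = new Schema(Arrays.asList(
  Field.nullable("id", new ArrowType.Int(32, true)),
  Field.nullable("name", ArrowType.Utf8.INSTANCE)))

val out = new ByteArrayOutputStream()
MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema)

val restored: Schema = ConverterUtils.getSchemaFromBytesBuf(out.toByteArray)
assert(restored == schema)  // the deserialized schema matches the original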
@@ -68,14 +68,7 @@ private class ArrowColumnarBatchSerializerInstance(

private val compressionEnabled =
SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)
private val compressionCodec =
if (ColumnarPluginConfig
.getConf
.columnarShuffleUseCustomizedCompression) {
"fastpfor"
} else {
SparkEnv.get.conf.get("spark.io.compression.codec", "lz4")
}

private val allocator: BufferAllocator = SparkMemoryUtils.contextAllocator()
.newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue)

@@ -232,7 +225,7 @@

val builder = jniWrapper.decompress(
schemaHolderId,
compressionCodec,
reader.asInstanceOf[ArrowCompressedStreamReader].GetCompressType(),
root.getRowCount,
bufAddrs.toArray,
bufSizes.toArray,
@@ -21,8 +21,10 @@ import java.io.IOException

import com.google.common.annotations.VisibleForTesting
import com.intel.oap.ColumnarPluginConfig
import com.intel.oap.expression.ConverterUtils
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.Spiller
import com.intel.oap.vectorized.{ArrowWritableColumnVector, ShuffleSplitterJniWrapper, SplitResult}
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.memory.MemoryConsumer
@@ -57,12 +59,11 @@ class ColumnarShuffleWriter[K, V](
private val localDirs = blockManager.diskBlockManager.localDirs.mkString(",")
private val nativeBufferSize =
conf.getInt("spark.sql.execution.arrow.maxRecordsPerBatch", 4096)
private val compressionCodec = if (conf.getBoolean("spark.shuffle.compress", true)) {
if (ColumnarPluginConfig.getConf.columnarShuffleUseCustomizedCompression) {
"fastpfor"
} else {
conf.get("spark.io.compression.codec", "lz4")
}

private val customizedCompressCodec =
ColumnarPluginConfig.getConf.columnarShuffleUseCustomizedCompressionCodec
private val defaultCompressionCodec = if (conf.getBoolean("spark.shuffle.compress", true)) {
conf.get("spark.io.compression.codec", "lz4")
} else {
"uncompressed"
}
@@ -76,6 +77,8 @@

private var partitionLengths: Array[Long] = _

private var firstRecordBatch: Boolean = true

@throws[IOException]
override def write(records: Iterator[Product2[K, V]]): Unit = {
if (!records.hasNext) {
@@ -90,7 +93,7 @@
nativeSplitter = jniWrapper.make(
dep.nativePartitioning,
nativeBufferSize,
compressionCodec,
defaultCompressionCodec,
dataTmp.getAbsolutePath,
blockManager.subDirsPerLocalDir,
localDirs,
@@ -128,7 +131,39 @@
dep.dataSize.add(bufSizes.sum)

val startTime = System.nanoTime()
jniWrapper.split(nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray)

val existingIntType: Boolean = if (firstRecordBatch) {
// Check whether the record batch contains the Int data type.
val arrowSchema = ConverterUtils.getSchemaFromBytesBuf(dep.nativePartitioning.getSchema)
import scala.collection.JavaConverters._
arrowSchema.getFields.asScala.find(_.getType.getTypeID == ArrowTypeID.Int).nonEmpty
} else false

// Choose the compression type based on the compressed size of the first record batch.
if (firstRecordBatch && conf.getBoolean("spark.shuffle.compress", true) &&
customizedCompressCodec != defaultCompressionCodec && existingIntType) {
// Compute the compressed size with the default codec.
jniWrapper.setCompressType(nativeSplitter, defaultCompressionCodec)
val defaultCompressedSize = jniWrapper.split(
nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)

// Compute the compressed size with the customized codec.
jniWrapper.setCompressType(nativeSplitter, customizedCompressCodec)
val customizedCompressedSize = jniWrapper.split(
nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)

// Choose the compression algorithm based on the compressed sizes.
if (customizedCompressedSize != -1 && defaultCompressedSize != -1) {
if (customizedCompressedSize > defaultCompressedSize) {
jniWrapper.setCompressType(nativeSplitter, defaultCompressionCodec)
}
} else {
logError("Failed to compute the compress size in the first record batch")
}
}
firstRecordBatch = false

jniWrapper.split(nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)
dep.splitTime.add(System.nanoTime() - startTime)
dep.numInputRows.add(cb.numRows)
writeMetrics.incRecordsWritten(1)
46 changes: 34 additions & 12 deletions cpp/src/jni/jni_wrapper.cc
@@ -1395,32 +1395,50 @@ Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_nativeMake(
return shuffle_splitter_holder_.Insert(std::shared_ptr<Splitter>(splitter));
}

JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_split(
JNIEnv* env, jobject, jlong splitter_id, jint num_rows, jlongArray buf_addrs,
jlongArray buf_sizes) {
JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_setCompressType(
JNIEnv* env, jobject, jlong splitter_id, jstring compression_type_jstr) {
auto splitter = shuffle_splitter_holder_.Lookup(splitter_id);
if (!splitter) {
std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
return;
}

if (compression_type_jstr != NULL) {
auto compression_type_result = GetCompressionType(env, compression_type_jstr);
if (compression_type_result.status().ok()) {
splitter->SetCompressType(compression_type_result.MoveValueUnsafe());
}
}
return;
}

JNIEXPORT jlong JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_split(
JNIEnv* env, jobject, jlong splitter_id, jint num_rows, jlongArray buf_addrs,
jlongArray buf_sizes, jboolean first_record_batch) {
auto splitter = shuffle_splitter_holder_.Lookup(splitter_id);
if (!splitter) {
std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
return -1;
}
if (buf_addrs == NULL) {
env->ThrowNew(illegal_argument_exception_class,
std::string("Native split: buf_addrs can't be null").c_str());
return;
return -1;
}
if (buf_sizes == NULL) {
env->ThrowNew(illegal_argument_exception_class,
std::string("Native split: buf_sizes can't be null").c_str());
return;
return -1;
}

int in_bufs_len = env->GetArrayLength(buf_addrs);
if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
env->ThrowNew(
illegal_argument_exception_class,
std::string("Native split: length of buf_addrs and buf_sizes mismatch").c_str());
return;
return -1;
}

jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, JNI_FALSE);
@@ -1440,17 +1458,21 @@ JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_s
std::string("Native split: make record batch failed, error message is " +
status.message())
.c_str());
return;
return -1;
}

status = splitter->Split(*in);

if (!status.ok()) {
// Throw IOException
env->ThrowNew(io_exception_class,
if (first_record_batch) {
return splitter->CompressedSize(*in);
} else {
status = splitter->Split(*in);
if (!status.ok()) {
// Throw IOException
env->ThrowNew(io_exception_class,
std::string("Native split: splitter split failed, error message is " +
status.message())
.c_str());
}
return -1;
}
}

Expand Down
16 changes: 16 additions & 0 deletions cpp/src/shuffle/splitter.cc
@@ -331,6 +331,22 @@ arrow::Status Splitter::Init() {
return arrow::Status::OK();
}

int64_t Splitter::CompressedSize(const arrow::RecordBatch& rb) {
auto payload = std::make_shared<arrow::ipc::internal::IpcPayload>();
auto result = arrow::ipc::internal::GetRecordBatchPayload(
rb, options_.ipc_write_options, payload.get());
if (result.ok()) {
return payload.get()->body_length;
} else {
result.UnknownError("Failed to get the compressed size.");
return -1;
}
}

void Splitter::SetCompressType(arrow::Compression::type compressed_type) {
options_.ipc_write_options.compression = compressed_type;
}

arrow::Status Splitter::Split(const arrow::RecordBatch& rb) {
EVAL_START("split", options_.thread_id)
RETURN_NOT_OK(ComputeAndCountPartitionId(rb));
7 changes: 7 additions & 0 deletions cpp/src/shuffle/splitter.h
@@ -52,6 +52,11 @@ class Splitter {
* id. The largest partition buffer will be spilled if memory allocation failure occurs.
*/
virtual arrow::Status Split(const arrow::RecordBatch&);

/**
* Compute the compressed size of a record batch.
*/
virtual int64_t CompressedSize(const arrow::RecordBatch&);

/**
* For each partition, merge spilled file into shuffle data file and write any cached
@@ -64,6 +69,8 @@
*/
arrow::Status SpillPartition(int32_t partition_id);

void SetCompressType(arrow::Compression::type compressed_type);

/**
* Spill for fixed size of partition data
*/
