diff --git a/core/src/main/java/com/intel/oap/vectorized/ArrowCompressedStreamReader.java b/core/src/main/java/com/intel/oap/vectorized/ArrowCompressedStreamReader.java
index ce0a9d767..b1336864e 100644
--- a/core/src/main/java/com/intel/oap/vectorized/ArrowCompressedStreamReader.java
+++ b/core/src/main/java/com/intel/oap/vectorized/ArrowCompressedStreamReader.java
@@ -17,12 +17,17 @@
 package com.intel.oap.vectorized;
 
+import io.netty.buffer.ArrowBuf;
+import org.apache.arrow.flatbuf.MessageHeader;
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.dictionary.Dictionary;
 import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.MessageResult;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.arrow.vector.util.DictionaryUtility;
@@ -36,11 +41,16 @@
  * ArrowRecordBatches.
  */
 public class ArrowCompressedStreamReader extends ArrowStreamReader {
+  private String compressType;
 
   public ArrowCompressedStreamReader(InputStream in, BufferAllocator allocator) {
     super(in, allocator);
   }
 
+  public String GetCompressType() {
+    return compressType;
+  }
+
   protected void initialize() throws IOException {
     Schema originalSchema = readSchema();
     List<Field> fields = new ArrayList<>();
@@ -60,6 +70,47 @@ protected void initialize() throws IOException {
     this.dictionaries = Collections.unmodifiableMap(dictionaries);
   }
 
+  /**
+   * Load the next ArrowRecordBatch into the vector schema root if available.
+   *
+   * @return true if a batch was read, false on EOS
+   * @throws IOException on error
+   */
+  public boolean loadNextBatch() throws IOException {
+    prepareLoadNextBatch();
+    MessageResult result = messageReader.readNext();
+
+    // Reached EOS
+    if (result == null) {
+      return false;
+    }
+    // Get the compression type from customMetadata. Currently customMetadata has only one entry.
+    compressType = result.getMessage().customMetadata(0).value();
+
+    if (result.getMessage().headerType() == MessageHeader.RecordBatch) {
+      ArrowBuf bodyBuffer = result.getBodyBuffer();
+
+      // For zero-length batches, need an empty buffer to deserialize the batch
+      if (bodyBuffer == null) {
+        bodyBuffer = allocator.getEmpty();
+      }
+
+      ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(result.getMessage(), bodyBuffer);
+      loadRecordBatch(batch);
+      checkDictionaries();
+      return true;
+    } else if (result.getMessage().headerType() == MessageHeader.DictionaryBatch) {
+      // If it's a dictionary message, read the dictionary out and continue reading until a record batch or EOS is reached.
+      ArrowDictionaryBatch dictionaryBatch = readDictionary(result);
+      loadDictionary(dictionaryBatch);
+      loadedDictionaryCount++;
+      return loadNextBatch();
+    } else {
+      throw new IOException("Expected RecordBatch or DictionaryBatch but header was "
+          + result.getMessage().headerType());
+    }
+  }
+
   @Override
   protected void loadRecordBatch(ArrowRecordBatch batch) {
     try {
diff --git a/core/src/main/java/com/intel/oap/vectorized/ShuffleSplitterJniWrapper.java b/core/src/main/java/com/intel/oap/vectorized/ShuffleSplitterJniWrapper.java
index 83a3d92ca..f3c2a39d2 100644
--- a/core/src/main/java/com/intel/oap/vectorized/ShuffleSplitterJniWrapper.java
+++ b/core/src/main/java/com/intel/oap/vectorized/ShuffleSplitterJniWrapper.java
@@ -96,10 +96,19 @@ public native long nativeMake(
    * @param numRows Rows per batch
    * @param bufAddrs Addresses of buffers
    * @param bufSizes Sizes of buffers
+   * @param firstRecordBatch whether this record batch is the first
+   *                         record batch in the first partition.
+   * @return the compressed size if firstRecordBatch is true, otherwise -1.
    */
-  public native void split(long splitterId, int numRows, long[] bufAddrs, long[] bufSizes)
+  public native long split(
+      long splitterId, int numRows, long[] bufAddrs, long[] bufSizes, boolean firstRecordBatch)
       throws IOException;
 
+  /**
+   * Update the compression type.
+   */
+  public native void setCompressType(long splitterId, String compressType);
+
   /**
    * Write the data remained in the buffers hold by native splitter to each partition's temporary
    * file. And stop processing splitting
diff --git a/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala b/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala
index 3cbc513cf..b4354e0f7 100644
--- a/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala
+++ b/core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala
@@ -66,8 +66,10 @@ class ColumnarPluginConfig(conf: SQLConf) {
   // and the cached buffers will be spilled when reach maximum memory.
   val columnarShufflePreferSpill: Boolean =
     conf.getConfString("spark.oap.sql.columnar.shuffle.preferSpill", "true").toBoolean
-  val columnarShuffleUseCustomizedCompression: Boolean =
-    conf.getConfString("spark.oap.sql.columnar.shuffle.customizedCompression", "false").toBoolean
+
+  // The supported customized compression codecs are lz4 and fastpfor.
+  val columnarShuffleUseCustomizedCompressionCodec: String =
+    conf.getConfString("spark.oap.sql.columnar.shuffle.customizedCompression.codec", "lz4")
   val isTesting: Boolean =
     conf.getConfString("spark.oap.sql.columnar.testing", "false").toBoolean
   val numaBindingInfo: ColumnarNumaBindingInfo = {
diff --git a/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala b/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala
index 440b2e55b..6b6a56e84 100644
--- a/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala
+++ b/core/src/main/scala/com/intel/oap/expression/ConverterUtils.scala
@@ -456,6 +456,12 @@ object ConverterUtils extends Logging {
     out.toByteArray
   }
 
+  @throws[IOException]
+  def getSchemaFromBytesBuf(schema: Array[Byte]): Schema = {
+    val in: ByteArrayInputStream = new ByteArrayInputStream(schema)
+    MessageSerializer.deserializeSchema(new ReadChannel(Channels.newChannel(in)))
+  }
+
   @throws[GandivaException]
   def getExprListBytesBuf(exprs: List[ExpressionTree]): Array[Byte] = {
     val builder: ExpressionList.Builder = GandivaTypes.ExpressionList.newBuilder
diff --git a/core/src/main/scala/com/intel/oap/vectorized/ArrowColumnarBatchSerializer.scala b/core/src/main/scala/com/intel/oap/vectorized/ArrowColumnarBatchSerializer.scala
index 2f29d071f..6011fa4bc 100644
--- a/core/src/main/scala/com/intel/oap/vectorized/ArrowColumnarBatchSerializer.scala
+++ b/core/src/main/scala/com/intel/oap/vectorized/ArrowColumnarBatchSerializer.scala
@@ -68,14 +68,7 @@ private class ArrowColumnarBatchSerializerInstance(
   private val compressionEnabled =
     SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)
 
-  private val compressionCodec =
-    if (ColumnarPluginConfig
-      .getConf
-      .columnarShuffleUseCustomizedCompression) {
-      "fastpfor"
-    } else {
-      SparkEnv.get.conf.get("spark.io.compression.codec", "lz4")
-    }
+
   private val allocator: BufferAllocator = SparkMemoryUtils.contextAllocator()
     .newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue)
 
@@ -232,7 +225,7 @@ private class ArrowColumnarBatchSerializerInstance(
         val builder = jniWrapper.decompress(
           schemaHolderId,
-          compressionCodec,
+          reader.asInstanceOf[ArrowCompressedStreamReader].GetCompressType(),
           root.getRowCount,
           bufAddrs.toArray,
           bufSizes.toArray,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
index f594a6e97..e7045c3dd 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ColumnarShuffleWriter.scala
@@ -21,8 +21,10 @@ import java.io.IOException
 
 import com.google.common.annotations.VisibleForTesting
 import com.intel.oap.ColumnarPluginConfig
+import com.intel.oap.expression.ConverterUtils
 import com.intel.oap.spark.sql.execution.datasources.v2.arrow.Spiller
 import com.intel.oap.vectorized.{ArrowWritableColumnVector, ShuffleSplitterJniWrapper, SplitResult}
+import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.MemoryConsumer
@@ -57,12 +59,11 @@ class ColumnarShuffleWriter[K, V](
   private val localDirs = blockManager.diskBlockManager.localDirs.mkString(",")
   private val nativeBufferSize = conf.getInt("spark.sql.execution.arrow.maxRecordsPerBatch", 4096)
 
-  private val compressionCodec = if (conf.getBoolean("spark.shuffle.compress", true)) {
-    if (ColumnarPluginConfig.getConf.columnarShuffleUseCustomizedCompression) {
-      "fastpfor"
-    } else {
-      conf.get("spark.io.compression.codec", "lz4")
-    }
+
+  private val customizedCompressCodec =
+    ColumnarPluginConfig.getConf.columnarShuffleUseCustomizedCompressionCodec
+  private val defaultCompressionCodec = if (conf.getBoolean("spark.shuffle.compress", true)) {
+    conf.get("spark.io.compression.codec", "lz4")
   } else {
     "uncompressed"
   }
@@ -76,6 +77,8 @@ class ColumnarShuffleWriter[K, V](
 
   private var partitionLengths: Array[Long] = _
 
+  private var firstRecordBatch: Boolean = true
+
   @throws[IOException]
   override def write(records: Iterator[Product2[K, V]]): Unit = {
     if (!records.hasNext) {
@@ -90,7 +93,7 @@ class ColumnarShuffleWriter[K, V](
       nativeSplitter = jniWrapper.make(
         dep.nativePartitioning,
         nativeBufferSize,
-        compressionCodec,
+        defaultCompressionCodec,
         dataTmp.getAbsolutePath,
         blockManager.subDirsPerLocalDir,
         localDirs,
@@ -128,7 +131,39 @@ class ColumnarShuffleWriter[K, V](
       dep.dataSize.add(bufSizes.sum)
 
       val startTime = System.nanoTime()
-      jniWrapper.split(nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray)
+
+      val existingIntType: Boolean = if (firstRecordBatch) {
+        // Check whether the record batch contains the Int data type.
+        val arrowSchema = ConverterUtils.getSchemaFromBytesBuf(dep.nativePartitioning.getSchema)
+        import scala.collection.JavaConverters._
+        arrowSchema.getFields.asScala.find(_.getType.getTypeID == ArrowTypeID.Int).nonEmpty
+      } else false
+
+      // Choose the compression type based on the compressed size of the first record batch.
+      if (firstRecordBatch && conf.getBoolean("spark.shuffle.compress", true) &&
+          customizedCompressCodec != defaultCompressionCodec && existingIntType) {
+        // Compute the compressed size with the default codec.
+        jniWrapper.setCompressType(nativeSplitter, defaultCompressionCodec)
+        val defaultCompressedSize = jniWrapper.split(
+          nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)
+
+        // Compute the compressed size with the customized codec.
+        jniWrapper.setCompressType(nativeSplitter, customizedCompressCodec)
+        val customizedCompressedSize = jniWrapper.split(
+          nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)
+
+        // Choose the compression algorithm based on the compressed sizes.
+        if (customizedCompressedSize != -1 && defaultCompressedSize != -1) {
+          if (customizedCompressedSize > defaultCompressedSize) {
+            jniWrapper.setCompressType(nativeSplitter, defaultCompressionCodec)
+          }
+        } else {
+          logError("Failed to compute the compressed size of the first record batch")
+        }
+      }
+      firstRecordBatch = false
+
+      jniWrapper.split(nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)
       dep.splitTime.add(System.nanoTime() - startTime)
       dep.numInputRows.add(cb.numRows)
       writeMetrics.incRecordsWritten(1)
diff --git a/cpp/src/jni/jni_wrapper.cc b/cpp/src/jni/jni_wrapper.cc
index e67039079..f5d0d0fc8 100644
--- a/cpp/src/jni/jni_wrapper.cc
+++ b/cpp/src/jni/jni_wrapper.cc
@@ -1395,24 +1395,42 @@ Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_nativeMake(
   return shuffle_splitter_holder_.Insert(std::shared_ptr<Splitter>(splitter));
 }
 
-JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_split(
-    JNIEnv* env, jobject, jlong splitter_id, jint num_rows, jlongArray buf_addrs,
-    jlongArray buf_sizes) {
+JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_setCompressType(
+    JNIEnv* env, jobject, jlong splitter_id, jstring compression_type_jstr) {
   auto splitter = shuffle_splitter_holder_.Lookup(splitter_id);
   if (!splitter) {
     std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
     env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
     return;
   }
+
+  if (compression_type_jstr != NULL) {
+    auto compression_type_result = GetCompressionType(env, compression_type_jstr);
+    if (compression_type_result.status().ok()) {
+      splitter->SetCompressType(compression_type_result.MoveValueUnsafe());
+    }
+  }
+  return;
+}
+
+JNIEXPORT jlong JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_split(
+    JNIEnv* env, jobject, jlong splitter_id, jint num_rows, jlongArray buf_addrs,
+    jlongArray buf_sizes, jboolean first_record_batch) {
+  auto splitter = shuffle_splitter_holder_.Lookup(splitter_id);
+  if (!splitter) {
+    std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
+    env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
+    return -1;
+  }
   if (buf_addrs == NULL) {
     env->ThrowNew(illegal_argument_exception_class,
                   std::string("Native split: buf_addrs can't be null").c_str());
-    return;
+    return -1;
   }
   if (buf_sizes == NULL) {
     env->ThrowNew(illegal_argument_exception_class,
                   std::string("Native split: buf_sizes can't be null").c_str());
-    return;
+    return -1;
   }
 
   int in_bufs_len = env->GetArrayLength(buf_addrs);
@@ -1420,7 +1438,7 @@ JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_s
     env->ThrowNew(
         illegal_argument_exception_class,
         std::string("Native split: length of buf_addrs and buf_sizes mismatch").c_str());
-    return;
+    return -1;
   }
 
   jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, JNI_FALSE);
@@ -1440,17 +1458,21 @@ JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_s
         std::string("Native split: make record batch failed, error message is " +
                     status.message())
             .c_str());
-    return;
+    return -1;
   }
 
-  status = splitter->Split(*in);
-
-  if (!status.ok()) {
-    // Throw IOException
-    env->ThrowNew(io_exception_class,
+  if (first_record_batch) {
+    return splitter->CompressedSize(*in);
+  } else {
+    status = splitter->Split(*in);
+    if (!status.ok()) {
+      // Throw IOException
+      env->ThrowNew(io_exception_class,
                   std::string("Native split: splitter split failed, error message is " +
                               status.message())
                       .c_str());
+    }
+    return -1;
   }
 }
 
diff --git a/cpp/src/shuffle/splitter.cc b/cpp/src/shuffle/splitter.cc
index 6006d0d1c..7673dc2b9 100644
--- a/cpp/src/shuffle/splitter.cc
+++ b/cpp/src/shuffle/splitter.cc
@@ -331,6 +331,22 @@ arrow::Status Splitter::Init() {
   return arrow::Status::OK();
 }
 
+int64_t Splitter::CompressedSize(const arrow::RecordBatch& rb) {
+  auto payload = std::make_shared<arrow::ipc::internal::IpcPayload>();
+  auto result = arrow::ipc::internal::GetRecordBatchPayload(
+      rb, options_.ipc_write_options, payload.get());
+  if (result.ok()) {
+    return payload.get()->body_length;
+  } else {
+    result.UnknownError("Failed to get the compressed size.");
+    return -1;
+  }
+}
+
+void Splitter::SetCompressType(arrow::Compression::type compressed_type) {
+  options_.ipc_write_options.compression = compressed_type;
+}
+
 arrow::Status Splitter::Split(const arrow::RecordBatch& rb) {
   EVAL_START("split", options_.thread_id)
   RETURN_NOT_OK(ComputeAndCountPartitionId(rb));
diff --git a/cpp/src/shuffle/splitter.h b/cpp/src/shuffle/splitter.h
index 796294cad..9ffdf95c8 100644
--- a/cpp/src/shuffle/splitter.h
+++ b/cpp/src/shuffle/splitter.h
@@ -52,6 +52,11 @@ class Splitter {
    * id. The largest partition buffer will be spilled if memory allocation failure occurs.
    */
   virtual arrow::Status Split(const arrow::RecordBatch&);
+
+  /**
+   * Compute the compressed size of a record batch.
+   */
+  virtual int64_t CompressedSize(const arrow::RecordBatch&);
 
   /**
    * For each partition, merge spilled file into shuffle data file and write any cached
@@ -64,6 +69,8 @@ class Splitter {
    */
   arrow::Status SpillPartition(int32_t partition_id);
 
+  void SetCompressType(arrow::Compression::type compressed_type);
+
   /**
    * Spill for fixed size of partition data
   */
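
For reference, the first-batch codec selection added to ColumnarShuffleWriter above can be sketched independently of the JNI layer. In the minimal Scala sketch below, the Splitter trait, its method names, and the fake size table are invented stand-ins for ShuffleSplitterJniWrapper; only the decision rule (probe both codecs on the first record batch and keep whichever compresses smaller) mirrors the patch. Where the patch merely logs an error and keeps the last codec set when a probe returns -1, the sketch falls back to the default codec, which is the more conservative choice.

// Minimal, self-contained sketch of the first-batch codec selection policy.
// The Splitter trait is a hypothetical stand-in for ShuffleSplitterJniWrapper.
object CodecSelectionSketch {

  trait Splitter {
    // Switch the codec used by subsequent probe calls.
    def setCompressType(codec: String): Unit
    // Probe-split the first batch and return its compressed size, or -1 on failure.
    def probeCompressedSize(): Long
  }

  // Probe both codecs on the first record batch and keep whichever compresses smaller.
  // Unlike the patch, this sketch falls back to the default codec when a probe fails.
  def chooseCodec(
      splitter: Splitter,
      defaultCodec: String,
      customizedCodec: String,
      shuffleCompress: Boolean,
      hasIntColumn: Boolean): String = {
    if (!shuffleCompress || customizedCodec == defaultCodec || !hasIntColumn) {
      defaultCodec
    } else {
      splitter.setCompressType(defaultCodec)
      val defaultSize = splitter.probeCompressedSize()
      splitter.setCompressType(customizedCodec)
      val customizedSize = splitter.probeCompressedSize()
      val chosen =
        if (defaultSize == -1 || customizedSize == -1) defaultCodec
        else if (customizedSize > defaultSize) defaultCodec
        else customizedCodec
      splitter.setCompressType(chosen)
      chosen
    }
  }

  def main(args: Array[String]): Unit = {
    // Fake splitter whose compressed sizes come from a fixed table, for illustration only.
    val fake = new Splitter {
      private val sizes = Map("lz4" -> 1200L, "fastpfor" -> 900L)
      private var current = "lz4"
      def setCompressType(codec: String): Unit = current = codec
      def probeCompressedSize(): Long = sizes.getOrElse(current, -1L)
    }
    val chosen = chooseCodec(fake, "lz4", "fastpfor", shuffleCompress = true, hasIntColumn = true)
    println(s"chosen codec: $chosen") // fastpfor, since 900 < 1200 in the fake table
  }
}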
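
The probe is only attempted when the shuffle schema contains integer columns, since fastpfor targets integer data. Below is a small, self-contained illustration of that schema check against the Arrow Java API; the object name and the sample schema are invented for the example, and the patch expresses the same predicate with find(...).nonEmpty over a schema deserialized by ConverterUtils.getSchemaFromBytesBuf rather than a hand-built one.

import java.util.Collections

import scala.collection.JavaConverters._

import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID

object SchemaIntCheckSketch {
  // True if any top-level field of the schema is an integer column.
  def containsIntType(schema: Schema): Boolean =
    schema.getFields.asScala.exists(_.getType.getTypeID == ArrowTypeID.Int)

  def main(args: Array[String]): Unit = {
    val fields = List(
      new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), Collections.emptyList[Field]()),
      new Field("name", FieldType.nullable(new ArrowType.Utf8()), Collections.emptyList[Field]())
    ).asJava
    val schema = new Schema(fields)
    println(containsIntType(schema)) // true: "id" is a 32-bit signed integer column
  }
}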