feat: Enable Comet broadcast by default #213

Merged (13 commits) on Apr 5, 2024. Showing changes from all commits.
51 changes: 0 additions & 51 deletions common/src/main/java/org/apache/comet/CometArrowStreamWriter.java

This file was deleted.

@@ -48,7 +48,7 @@ protected CometDecodedVector(ValueVector vector, Field valueField, boolean useDe
}

@Override
ValueVector getValueVector() {
public ValueVector getValueVector() {
return valueVector;
}

@@ -153,7 +153,7 @@ public ColumnVector getChild(int i) {
}

@Override
ValueVector getValueVector() {
public ValueVector getValueVector() {
return delegate.getValueVector();
}

@@ -163,7 +163,7 @@ public CometVector slice(int offset, int length) {
}

@Override
DictionaryProvider getDictionaryProvider() {
public DictionaryProvider getDictionaryProvider() {
return delegate.getDictionaryProvider();
}
}
@@ -58,7 +58,7 @@ public CometDictionaryVector(
}

@Override
DictionaryProvider getDictionaryProvider() {
public DictionaryProvider getDictionaryProvider() {
return this.provider;
}

@@ -138,7 +138,7 @@ public byte[] getBinary(int rowId) {
}

@Override
CDataDictionaryProvider getDictionaryProvider() {
public CDataDictionaryProvider getDictionaryProvider() {
return null;
}

4 changes: 2 additions & 2 deletions common/src/main/java/org/apache/comet/vector/CometVector.java
@@ -189,11 +189,11 @@ public void close() {
getValueVector().close();
}

DictionaryProvider getDictionaryProvider() {
public DictionaryProvider getDictionaryProvider() {
throw new UnsupportedOperationException("Not implemented");
}

abstract ValueVector getValueVector();
public abstract ValueVector getValueVector();

/**
* Returns a zero-copying new vector that contains the values from [offset, offset + length).
9 changes: 5 additions & 4 deletions common/src/main/scala/org/apache/comet/CometConf.scala
@@ -139,12 +139,13 @@ object CometConf {
.booleanConf
.createWithDefault(false)

val COMET_EXEC_BROADCAST_ENABLED: ConfigEntry[Boolean] =
val COMET_EXEC_BROADCAST_FORCE_ENABLED: ConfigEntry[Boolean] =
conf(s"$COMET_EXEC_CONFIG_PREFIX.broadcast.enabled")
.doc(
"Whether to enable broadcasting for Comet native operators. By default, " +
"this config is false. Note that this feature is not fully supported yet " +
"and only enabled for test purpose.")
"Whether to force enabling broadcasting for Comet native operators. By default, " +
"this config is false. Comet broadcast feature will be enabled automatically by " +
"Comet extension. But for unit tests, we need this feature to force enabling it " +
"for invalid cases. So this config is only used for unit test.")
.booleanConf
.createWithDefault(false)

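As a usage note (not from the diff): a test or application would typically flip this flag through the Spark SQL conf. The sketch below is a hypothetical illustration; it assumes the config entry exposes its key via `.key`, in the style of Spark's own config framework, and that Comet is available on the classpath.

```scala
import org.apache.spark.sql.SparkSession

import org.apache.comet.CometConf

object ForceCometBroadcastExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .appName("comet-broadcast-force-example")
      .getOrCreate()

    // Force-enable native broadcast even where the Comet extension would not
    // turn it on automatically; per the doc string, intended for unit tests only.
    spark.conf.set(CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.key, "true")

    spark.range(10).createOrReplaceTempView("t")
    // A broadcast hint makes the planner consider a broadcast exchange, which
    // Comet can then attempt to replace with its native broadcast operator.
    spark.sql("SELECT /*+ BROADCAST(t2) */ * FROM t t1 JOIN t t2 ON t1.id = t2.id").show()

    spark.stop()
  }
}
```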
85 changes: 3 additions & 82 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -19,89 +19,21 @@

package org.apache.comet.vector

import java.io.OutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictionaryProvider, Data}
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.spark.SparkException
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometArrowStreamWriter

class NativeUtil {
import Utils._

private val allocator = new RootAllocator(Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
private val importer = new ArrowImporter(allocator)

/**
* Serializes a list of `ColumnarBatch` into an output stream.
*
* @param batches
* the output batches, each batch is a list of Arrow vectors wrapped in `CometVector`
* @param out
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
var writer: Option[CometArrowStreamWriter] = None
var rowCount = 0

batches.foreach { batch =>
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(dictionaryProvider)

if (writer.isEmpty) {
writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out)))
writer.get.start()
writer.get.writeBatch()
} else {
writer.get.writeMoreBatch(root)
}

root.clear()
rowCount += batch.numRows()
}

writer.map(_.end())

rowCount
}

def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
val fieldVectors = (0 until batch.numCols()).map { index =>
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector
if (valueVector.getField.getDictionary != null) {
if (provider.isEmpty) {
provider = Some(a.getDictionaryProvider)
} else {
if (provider.get != a.getDictionaryProvider) {
throw new SparkException(
"Comet execution only takes Arrow Arrays with the same dictionary provider")
}
}
}

getFieldVector(valueVector)

case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}
(fieldVectors, provider)
}

/**
* Exports a Comet `ColumnarBatch` into a list of memory addresses that can be consumed by the
* native execution.
@@ -199,15 +131,4 @@ class NativeUtil {

new ColumnarBatch(arrayVectors.toArray, maxNumRows)
}

private def getFieldVector(valueVector: ValueVector): FieldVector = {
valueVector match {
case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector |
_: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
_: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
_: FixedSizeBinaryVector | _: TimeStampMicroVector) =>
v.asInstanceOf[FieldVector]
case _ => throw new SparkException(s"Unsupported Arrow Vector: ${valueVector.getClass}")
}
}
}
@@ -51,6 +51,7 @@ case class StreamReader(channel: ReadableByteChannel) extends AutoCloseable {
// Native shuffle always uses decimal128.
CometVector.getVector(vector, true, arrowReader).asInstanceOf[ColumnVector]
}.toArray

val batch = new ColumnarBatch(columns)
batch.setNumRows(root.getRowCount)
batch
@@ -19,14 +19,26 @@

package org.apache.spark.sql.comet.util

import java.io.File
import java.io.{DataOutputStream, File}
import java.nio.ByteBuffer
import java.nio.channels.Channels

import scala.collection.JavaConverters._

import org.apache.arrow.c.CDataDictionaryProvider
import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, FixedSizeBinaryVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
import org.apache.arrow.vector.complex.MapVector
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types._
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

import org.apache.comet.vector.CometVector

object Utils {
def getConfPath(confFileName: String): String = {
@@ -161,4 +173,79 @@ object Utils {
toArrowField(field.name, field.dataType, field.nullable, timeZoneId)
}.asJava)
}

/**
* Serializes an iterator of `ColumnarBatch`, producing one `ChunkedByteBuffer` per batch. This
* method must be in the `spark` package because `ChunkedByteBufferOutputStream` is a
* Spark-private class. As it uses Arrow classes, it must stay in the `common` module.
*
* @param batches
* the batches to serialize, each one a list of Arrow vectors wrapped in `CometVector`
* @return
* an iterator of (row count, serialized bytes) pairs, one per input batch
*/
def serializeBatches(batches: Iterator[ColumnarBatch]): Iterator[(Long, ChunkedByteBuffer)] = {
batches.map { batch =>
val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider

val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
val cbbos = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate)
Comment on lines +187 to +192 (PR author):
I need to move serializeBatches into the spark package because ChunkedByteBufferOutputStream is a Spark-private class. I cannot move serializeBatches to the spark module because it uses Arrow packages (we shade Arrow in the common module).

val out = new DataOutputStream(codec.compressedOutputStream(cbbos))

val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(dictionaryProvider)

val writer = new ArrowStreamWriter(root, provider, Channels.newChannel(out))
writer.start()
writer.writeBatch()

root.clear()
writer.end()

out.flush()
out.close()

if (out.size() > 0) {
(batch.numRows(), cbbos.toChunkedByteBuffer)
} else {
(batch.numRows(), new ChunkedByteBuffer(Array.empty[ByteBuffer]))
}
}
}

def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
val fieldVectors = (0 until batch.numCols()).map { index =>
batch.column(index) match {
case a: CometVector =>
val valueVector = a.getValueVector
if (valueVector.getField.getDictionary != null) {
if (provider.isEmpty) {
provider = Some(a.getDictionaryProvider)
}
}

getFieldVector(valueVector)

case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
s"${c.getClass}")
}
}
(fieldVectors, provider)
}

def getFieldVector(valueVector: ValueVector): FieldVector = {
valueVector match {
case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector |
_: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector |
_: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector |
_: FixedSizeBinaryVector | _: TimeStampMicroVector) =>
v.asInstanceOf[FieldVector]
case _ => throw new SparkException(s"Unsupported Arrow Vector: ${valueVector.getClass}")
}
}
}
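To make the round trip concrete, here is a hedged sketch (not part of this PR) of how the (row count, ChunkedByteBuffer) pairs produced by serializeBatches could be read back: decompress with the same codec, then hand the bytes to Arrow's IPC stream reader. Like serializeBatches itself, such code would have to live under an org.apache.spark package because ChunkedByteBuffer is Spark-private; the object and method names below are placeholders.

```scala
package org.apache.spark.sql.comet.util

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.SparkEnv
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.io.ChunkedByteBuffer

object DeserializeSketch {
  /** Reads back one batch that was serialized by `Utils.serializeBatches`. */
  def readBatch(rowCount: Long, bytes: ChunkedByteBuffer): Unit = {
    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
    val allocator = new RootAllocator(Long.MaxValue)
    // The writer compressed the Arrow IPC stream, so decompress before reading.
    val in = codec.compressedInputStream(bytes.toInputStream())
    val reader = new ArrowStreamReader(in, allocator)
    try {
      while (reader.loadNextBatch()) {
        val root = reader.getVectorSchemaRoot
        // Each buffer holds a single batch, so the row count should match.
        assert(root.getRowCount == rowCount)
        // ... wrap root's field vectors into ColumnVectors, as StreamReader does ...
      }
    } finally {
      reader.close()
      allocator.close()
    }
  }
}
```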
23 changes: 21 additions & 2 deletions dev/diffs/3.4.2.diff
@@ -249,7 +249,7 @@ index 9ddb4abe98b..1bebe99f1cc 100644
sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index f33432ddb6f..fe9f74ff8f1 100644
index f33432ddb6f..6160c8d241a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen
@@ -270,7 +270,26 @@ index f33432ddb6f..fe9f74ff8f1 100644
case _ => Nil
}
}
@@ -1729,6 +1733,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
@@ -1238,7 +1242,8 @@ abstract class DynamicPartitionPruningSuiteBase
}
}

- test("Plan broadcast pruning only when the broadcast can be reused") {
+ test("Plan broadcast pruning only when the broadcast can be reused",
+ IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
Given("dynamic pruning filter on the build side")
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") {
val df = sql(
@@ -1485,7 +1490,7 @@ abstract class DynamicPartitionPruningSuiteBase
}

test("SPARK-38148: Do not add dynamic partition pruning if there exists static partition " +
- "pruning") {
+ "pruning", IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) {
withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") {
Seq(
"f.store_id = 1" -> false,
@@ -1729,6 +1734,8 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat
case s: BatchScanExec =>
// we use f1 col for v2 tables due to schema pruning
s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1")))
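For readers unfamiliar with the IgnoreComet(...) argument in the hunks above: it follows ScalaTest's test-tag mechanism, where a Tag passed to test(...) lets a runner exclude the test. Its real definition is added elsewhere in the Spark diff and is not shown in this PR; the sketch below is only an assumed approximation of that pattern, with the tag name chosen for illustration.

```scala
import org.scalatest.Tag

// Assumed sketch: a tag carrying the reason a Spark test is skipped when Comet
// is enabled. The actual definition lives in the Spark test diff, not this PR.
case class IgnoreComet(reason: String) extends Tag("DisableComet")

// Usage, as in the hunks above:
//   test("some Spark test",
//     IgnoreComet("TODO: Support SubqueryBroadcastExec in Comet: #242")) { ... }
// Tagged tests can then be excluded, for example with ScalaTest's
// "-l DisableComet" option or an overridden test helper.
```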