diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml
index 8a51bf65d6a88..0f6783cbd685b 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -140,6 +140,7 @@
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
@@ -224,6 +225,24 @@
+      <plugin>
+        <groupId>org.codehaus.mojo</groupId>
+        <artifactId>build-helper-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>add-sources</id>
+            <phase>generate-sources</phase>
+            <goals>
+              <goal>add-source</goal>
+            </goals>
+            <configuration>
+              <sources>
+                <source>src/main/scala-${scala.binary.version}</source>
+              </sources>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
\ No newline at end of file
diff --git a/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala b/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
new file mode 100644
index 0000000000000..c2e01d974e0e4
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import scala.collection.generic.{GenericCompanion, GenMapFactory}
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion
+
+/**
+ * A couple of Scala-version-specific collection utility functions.
+ */
+private[arrow] object ScalaCollectionUtils {
+ def getIterableCompanion(tag: ClassTag[_]): GenericCompanion[Iterable] = {
+ ArrowDeserializers.resolveCompanion[GenericCompanion[Iterable]](tag)
+ }
+ def getMapCompanion(tag: ClassTag[_]): GenMapFactory[Map] = {
+ resolveCompanion[GenMapFactory[Map]](tag)
+ }
+ def wrap[T](array: AnyRef): mutable.WrappedArray[T] = {
+ mutable.WrappedArray.make(array)
+ }
+}
diff --git a/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala b/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
new file mode 100644
index 0000000000000..8a80e34162283
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import scala.collection.{mutable, IterableFactory, MapFactory}
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion
+
+/**
+ * A couple of Scala-version-specific collection utility functions.
+ */
+private[arrow] object ScalaCollectionUtils {
+ def getIterableCompanion(tag: ClassTag[_]): IterableFactory[Iterable] = {
+ ArrowDeserializers.resolveCompanion[IterableFactory[Iterable]](tag)
+ }
+ def getMapCompanion(tag: ClassTag[_]): MapFactory[Map] = {
+ resolveCompanion[MapFactory[Map]](tag)
+ }
+ def wrap[T](array: AnyRef): mutable.WrappedArray[T] = {
+ mutable.WrappedArray.make(array.asInstanceOf[Array[T]])
+ }
+}
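Both variants expose the same three entry points, so the shared deserializer code stays version
agnostic. A minimal usage sketch (not part of the patch, assuming a Scala 2.13 build):

  import scala.reflect.classTag

  // Resolve the factory/companion for the concrete collection class carried by the ClassTag,
  // then rebuild a collection element by element. ScalaCollectionUtils is private[arrow],
  // so this only compiles inside that package.
  val companion = ScalaCollectionUtils.getIterableCompanion(classTag[List[_]])
  val builder = companion.newBuilder[Any]
  builder += 1
  builder += 2
  val rebuilt: Iterable[Any] = builder.result() // List(1, 2)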
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index a727c86f70fc6..1cdc2035de60b 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -16,53 +16,48 @@
*/
package org.apache.spark.sql.connect.client
-import java.util.Collections
+import java.util.Objects
-import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.FieldVector
-import org.apache.arrow.vector.ipc.ArrowStreamReader
+import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
+import org.apache.arrow.vector.types.pojo
import org.apache.spark.connect.proto
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
+import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator}
+import org.apache.spark.sql.connect.client.util.Cleanable
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
private[sql] class SparkResult[T](
responses: java.util.Iterator[proto.ExecutePlanResponse],
allocator: BufferAllocator,
encoder: AgnosticEncoder[T])
extends AutoCloseable
- with Cleanable {
+ with Cleanable { self =>
private[this] var numRecords: Int = 0
private[this] var structType: StructType = _
- private[this] var boundEncoder: ExpressionEncoder[T] = _
- private[this] var nextBatchIndex: Int = 0
- private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch]
-
- private def createEncoder(schema: StructType): ExpressionEncoder[T] = {
- val agnosticEncoder = createEncoder(encoder, schema).asInstanceOf[AgnosticEncoder[T]]
- ExpressionEncoder(agnosticEncoder)
- }
+ private[this] var arrowSchema: pojo.Schema = _
+ private[this] var nextResultIndex: Int = 0
+ private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])]
/**
* Update RowEncoder and recursively update the fields of the ProductEncoder if found.
*/
- private def createEncoder(enc: AgnosticEncoder[_], dataType: DataType): AgnosticEncoder[_] = {
+ private def createEncoder[E](
+ enc: AgnosticEncoder[E],
+ dataType: DataType): AgnosticEncoder[E] = {
enc match {
case UnboundRowEncoder =>
// Replace the row encoder with the encoder inferred from the schema.
- RowEncoder.encoderFor(dataType.asInstanceOf[StructType])
+ RowEncoder
+ .encoderFor(dataType.asInstanceOf[StructType])
+ .asInstanceOf[AgnosticEncoder[E]]
case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) =>
// Recursively continue updating the tuple product encoder
val schema = dataType.asInstanceOf[StructType]
@@ -76,53 +71,61 @@ private[sql] class SparkResult[T](
}
}
- private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
- while (responses.hasNext) {
+ private def processResponses(
+ stopOnSchema: Boolean = false,
+ stopOnArrowSchema: Boolean = false,
+ stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
+ var nonEmpty = false
+ var stop = false
+ while (!stop && responses.hasNext) {
val response = responses.next()
if (response.hasSchema) {
// The original schema should arrive before ArrowBatches.
structType =
DataTypeProtoConverter.toCatalystType(response.getSchema).asInstanceOf[StructType]
- } else if (response.hasArrowBatch) {
+ stop |= stopOnSchema
+ }
+ if (response.hasArrowBatch) {
val ipcStreamBytes = response.getArrowBatch.getData
- val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator)
- try {
- val root = reader.getVectorSchemaRoot
- if (structType == null) {
- // If the schema is not available yet, fallback to the schema from Arrow.
- structType = ArrowUtils.fromArrowSchema(root.getSchema)
- }
- // TODO: create encoders that directly operate on arrow vectors.
- if (boundEncoder == null) {
- boundEncoder = createEncoder(structType)
- .resolveAndBind(DataTypeUtils.toAttributes(structType))
- }
- while (reader.loadNextBatch()) {
- val rowCount = root.getRowCount
- if (rowCount > 0) {
- val vectors = root.getFieldVectors.asScala
- .map(v => new ArrowColumnVector(transferToNewVector(v)))
- .toArray[ColumnVector]
- idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, rowCount))
- nextBatchIndex += 1
- numRecords += rowCount
- if (stopOnFirstNonEmptyResponse) {
- return true
- }
- }
+ val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator)
+ if (arrowSchema == null) {
+ arrowSchema = reader.schema
+ stop |= stopOnArrowSchema
+ } else if (arrowSchema != reader.schema) {
+ throw new IllegalStateException(
+ s"""Schema Mismatch between expected and received schema:
+ |=== Expected Schema ===
+ |$arrowSchema
+ |=== Received Schema ===
+ |${reader.schema}
+ |""".stripMargin)
+ }
+ if (structType == null) {
+ // If the schema is not available yet, fallback to the arrow schema.
+ structType = ArrowUtils.fromArrowSchema(reader.schema)
+ }
+ var numRecordsInBatch = 0
+ val messages = Seq.newBuilder[ArrowMessage]
+ while (reader.hasNext) {
+ val message = reader.next()
+ message match {
+ case batch: ArrowRecordBatch =>
+ numRecordsInBatch += batch.getLength
+ case _ =>
}
- } finally {
- reader.close()
+ messages += message
+ }
+ // Skip the entire result if it is empty.
+ if (numRecordsInBatch > 0) {
+ numRecords += numRecordsInBatch
+ resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
+ nextResultIndex += 1
+ nonEmpty |= true
+ stop |= stopOnFirstNonEmptyResponse
}
}
}
- false
- }
-
- private def transferToNewVector(in: FieldVector): FieldVector = {
- val pair = in.getTransferPair(allocator)
- pair.transfer()
- pair.getTo.asInstanceOf[FieldVector]
+ nonEmpty
}
/**
@@ -130,7 +133,7 @@ private[sql] class SparkResult[T](
*/
def length: Int = {
// We need to process all responses to make sure numRecords is correct.
- processResponses(stopOnFirstNonEmptyResponse = false)
+ processResponses()
numRecords
}
@@ -139,7 +142,9 @@ private[sql] class SparkResult[T](
* the schema of the result.
*/
def schema: StructType = {
- processResponses(stopOnFirstNonEmptyResponse = true)
+ if (structType == null) {
+ processResponses(stopOnSchema = true)
+ }
structType
}
@@ -172,52 +177,93 @@ private[sql] class SparkResult[T](
private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = {
new java.util.Iterator[T] with AutoCloseable {
- private[this] var batchIndex: Int = -1
- private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator()
- private[this] var deserializer: Deserializer[T] = _
+ private[this] var iterator: CloseableIterator[T] = _
- override def hasNext: Boolean = {
- if (iterator.hasNext) {
- return true
- }
-
- val nextBatchIndex = batchIndex + 1
- if (destructive) {
- idxToBatches.remove(batchIndex).foreach(_.close())
+ private def initialize(): Unit = {
+ if (iterator == null) {
+ iterator = new ArrowDeserializingIterator(
+ createEncoder(encoder, schema),
+ new ConcatenatingArrowStreamReader(
+ allocator,
+ Iterator.single(new ResultMessageIterator(destructive)),
+ destructive))
}
+ }
- val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) {
- processResponses(stopOnFirstNonEmptyResponse = true)
- } else {
- true
- }
- if (hasNextBatch) {
- batchIndex = nextBatchIndex
- iterator = idxToBatches(nextBatchIndex).rowIterator()
- if (deserializer == null) {
- deserializer = boundEncoder.createDeserializer()
- }
- }
- hasNextBatch
+ override def hasNext: Boolean = {
+ initialize()
+ iterator.hasNext
}
override def next(): T = {
- if (!hasNext) {
- throw new NoSuchElementException
- }
- deserializer(iterator.next())
+ initialize()
+ iterator.next()
}
- override def close(): Unit = SparkResult.this.close()
+ override def close(): Unit = {
+ if (iterator != null) {
+ iterator.close()
+ }
+ }
}
}
/**
* Close this result, freeing any underlying resources.
*/
- override def close(): Unit = {
- idxToBatches.values.foreach(_.close())
+ override def close(): Unit = cleaner.close()
+
+ override val cleaner: AutoCloseable = new SparkResultCloseable(resultMap)
+
+ private class ResultMessageIterator(destructive: Boolean) extends AbstractMessageIterator {
+ private[this] var totalBytesRead = 0L
+ private[this] var nextResultIndex = 0
+ private[this] var current: Iterator[ArrowMessage] = Iterator.empty
+
+ override def bytesRead: Long = totalBytesRead
+
+ override def schema: pojo.Schema = {
+ if (arrowSchema == null) {
+ // We need a schema to proceed. Spark Connect will always
+ // return a result (with a schema) even if the result is empty.
+ processResponses(stopOnArrowSchema = true)
+ Objects.requireNonNull(arrowSchema)
+ }
+ arrowSchema
+ }
+
+ override def hasNext: Boolean = {
+ if (current.hasNext) {
+ return true
+ }
+ val hasNextResult = if (!resultMap.contains(nextResultIndex)) {
+ self.processResponses(stopOnFirstNonEmptyResponse = true)
+ } else {
+ true
+ }
+ if (hasNextResult) {
+ val Some((sizeInBytes, messages)) = if (destructive) {
+ resultMap.remove(nextResultIndex)
+ } else {
+ resultMap.get(nextResultIndex)
+ }
+ totalBytesRead += sizeInBytes
+ current = messages.iterator
+ nextResultIndex += 1
+ }
+ hasNextResult
+ }
+
+ override def next(): ArrowMessage = {
+ if (!hasNext) {
+ throw new NoSuchElementException()
+ }
+ current.next()
+ }
}
+}
- override def cleaner: AutoCloseable = AutoCloseables(idxToBatches.values.toSeq)
+private[client] class SparkResultCloseable(resultMap: mutable.Map[Int, (Long, Seq[ArrowMessage])])
+ extends AutoCloseable {
+ override def close(): Unit = resultMap.values.foreach(_._2.foreach(_.close()))
}
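SparkResult now buffers the raw ArrowMessages of each non-empty response in resultMap and only
deserializes them when an iterator is requested; the destructive iterator additionally drops each
buffered entry once it has been consumed. A rough usage sketch (not part of the patch; the
collectResult() call is assumed from the surrounding test code):

  // Hypothetical: `df` is any connect-client Dataset.
  val result = df.collectResult()
  try {
    val schema = result.schema          // drains responses only until a schema message arrives
    val it = result.destructiveIterator // removes entries from resultMap as they are consumed
    while (it.hasNext) {
      println(it.next())
    }
  } finally {
    result.close()
  }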
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
new file mode 100644
index 0000000000000..154866d699a34
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -0,0 +1,533 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import java.io.{ByteArrayInputStream, IOException}
+import java.lang.invoke.{MethodHandles, MethodType}
+import java.lang.reflect.Modifier
+import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger}
+import java.time._
+import java.util
+import java.util.{List => JList, Locale, Map => JMap}
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector, VectorSchemaRoot}
+import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.arrow.vector.util.Text
+
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
+import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.types.Decimal
+
+/**
+ * Helper class for converting Arrow batches into user objects.
+ */
+object ArrowDeserializers {
+ import ArrowEncoderUtils._
+
+ /**
+ * Create an Iterator of `T`. This iterator takes an Iterator of Arrow IPC streams and
+ * deserializes these streams into one or more instances of `T`.
+ */
+ def deserializeFromArrow[T](
+ input: Iterator[Array[Byte]],
+ encoder: AgnosticEncoder[T],
+ allocator: BufferAllocator): CloseableIterator[T] = {
+ try {
+ val reader = new ConcatenatingArrowStreamReader(
+ allocator,
+ input.map(bytes => new MessageIterator(new ByteArrayInputStream(bytes), allocator)),
+ destructive = true)
+ new ArrowDeserializingIterator(encoder, reader)
+ } catch {
+ case _: IOException =>
+ new EmptyDeserializingIterator(encoder)
+ }
+ }
+
+ /**
+ * Create a deserializer of `T` on top of the given `root`.
+ */
+ private[arrow] def deserializerFor[T](
+ encoder: AgnosticEncoder[T],
+ root: VectorSchemaRoot): Deserializer[T] = {
+ val data: AnyRef = if (encoder.isStruct) {
+ root
+ } else {
+ // The input schema is allowed to have multiple columns;
+ // by convention we bind to the first one.
+ root.getVector(0)
+ }
+ deserializerFor(encoder, data).asInstanceOf[Deserializer[T]]
+ }
+
+ private[arrow] def deserializerFor(
+ encoder: AgnosticEncoder[_],
+ data: AnyRef): Deserializer[Any] = {
+ (encoder, data) match {
+ case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) =>
+ new FieldDeserializer[Boolean, BitVector](v) {
+ def value(i: Int): Boolean = vector.get(i) != 0
+ }
+ case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) =>
+ new FieldDeserializer[Byte, TinyIntVector](v) {
+ def value(i: Int): Byte = vector.get(i)
+ }
+ case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) =>
+ new FieldDeserializer[Short, SmallIntVector](v) {
+ def value(i: Int): Short = vector.get(i)
+ }
+ case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) =>
+ new FieldDeserializer[Int, IntVector](v) {
+ def value(i: Int): Int = vector.get(i)
+ }
+ case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) =>
+ new FieldDeserializer[Long, BigIntVector](v) {
+ def value(i: Int): Long = vector.get(i)
+ }
+ case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) =>
+ new FieldDeserializer[Float, Float4Vector](v) {
+ def value(i: Int): Float = vector.get(i)
+ }
+ case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) =>
+ new FieldDeserializer[Double, Float8Vector](v) {
+ def value(i: Int): Double = vector.get(i)
+ }
+ case (NullEncoder, v: NullVector) =>
+ new FieldDeserializer[Any, NullVector](v) {
+ def value(i: Int): Any = null
+ }
+ case (StringEncoder, v: VarCharVector) =>
+ new FieldDeserializer[String, VarCharVector](v) {
+ def value(i: Int): String = getString(vector, i)
+ }
+ case (JavaEnumEncoder(tag), v: VarCharVector) =>
+ // It would be nice if we could get Enum.valueOf working...
+ val valueOf = methodLookup.findStatic(
+ tag.runtimeClass,
+ "valueOf",
+ MethodType.methodType(tag.runtimeClass, classOf[String]))
+ new FieldDeserializer[Enum[_], VarCharVector](v) {
+ def value(i: Int): Enum[_] = {
+ valueOf.invoke(getString(vector, i)).asInstanceOf[Enum[_]]
+ }
+ }
+ case (ScalaEnumEncoder(parent, _), v: VarCharVector) =>
+ val mirror = scala.reflect.runtime.currentMirror
+ val module = mirror.classSymbol(parent).module.asModule
+ val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration]
+ new FieldDeserializer[Enumeration#Value, VarCharVector](v) {
+ def value(i: Int): Enumeration#Value = enumeration.withName(getString(vector, i))
+ }
+ case (BinaryEncoder, v: VarBinaryVector) =>
+ new FieldDeserializer[Array[Byte], VarBinaryVector](v) {
+ def value(i: Int): Array[Byte] = vector.get(i)
+ }
+ case (SparkDecimalEncoder(_), v: DecimalVector) =>
+ new FieldDeserializer[Decimal, DecimalVector](v) {
+ def value(i: Int): Decimal = Decimal(vector.getObject(i))
+ }
+ case (ScalaDecimalEncoder(_), v: DecimalVector) =>
+ new FieldDeserializer[BigDecimal, DecimalVector](v) {
+ def value(i: Int): BigDecimal = BigDecimal(vector.getObject(i))
+ }
+ case (JavaDecimalEncoder(_, _), v: DecimalVector) =>
+ new FieldDeserializer[JBigDecimal, DecimalVector](v) {
+ def value(i: Int): JBigDecimal = vector.getObject(i)
+ }
+ case (ScalaBigIntEncoder, v: DecimalVector) =>
+ new FieldDeserializer[BigInt, DecimalVector](v) {
+ def value(i: Int): BigInt = new BigInt(vector.getObject(i).toBigInteger)
+ }
+ case (JavaBigIntEncoder, v: DecimalVector) =>
+ new FieldDeserializer[JBigInteger, DecimalVector](v) {
+ def value(i: Int): JBigInteger = vector.getObject(i).toBigInteger
+ }
+ case (DayTimeIntervalEncoder, v: DurationVector) =>
+ new FieldDeserializer[Duration, DurationVector](v) {
+ def value(i: Int): Duration = vector.getObject(i)
+ }
+ case (YearMonthIntervalEncoder, v: IntervalYearVector) =>
+ new FieldDeserializer[Period, IntervalYearVector](v) {
+ def value(i: Int): Period = vector.getObject(i).normalized()
+ }
+ case (DateEncoder(_), v: DateDayVector) =>
+ new FieldDeserializer[java.sql.Date, DateDayVector](v) {
+ def value(i: Int): java.sql.Date = DateTimeUtils.toJavaDate(vector.get(i))
+ }
+ case (LocalDateEncoder(_), v: DateDayVector) =>
+ new FieldDeserializer[LocalDate, DateDayVector](v) {
+ def value(i: Int): LocalDate = DateTimeUtils.daysToLocalDate(vector.get(i))
+ }
+ case (TimestampEncoder(_), v: TimeStampMicroTZVector) =>
+ new FieldDeserializer[java.sql.Timestamp, TimeStampMicroTZVector](v) {
+ def value(i: Int): java.sql.Timestamp = DateTimeUtils.toJavaTimestamp(vector.get(i))
+ }
+ case (InstantEncoder(_), v: TimeStampMicroTZVector) =>
+ new FieldDeserializer[Instant, TimeStampMicroTZVector](v) {
+ def value(i: Int): Instant = DateTimeUtils.microsToInstant(vector.get(i))
+ }
+ case (LocalDateTimeEncoder, v: TimeStampMicroVector) =>
+ new FieldDeserializer[LocalDateTime, TimeStampMicroVector](v) {
+ def value(i: Int): LocalDateTime = DateTimeUtils.microsToLocalDateTime(vector.get(i))
+ }
+
+ case (OptionEncoder(value), v) =>
+ val deserializer = deserializerFor(value, v)
+ new Deserializer[Any] {
+ override def get(i: Int): Any = Option(deserializer.get(i))
+ }
+
+ case (ArrayEncoder(element, _), v: ListVector) =>
+ val deserializer = deserializerFor(element, v.getDataVector)
+ new FieldDeserializer[AnyRef, ListVector](v) {
+ def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag)
+ }
+
+ case (IterableEncoder(tag, element, _, _), v: ListVector) =>
+ val deserializer = deserializerFor(element, v.getDataVector)
+ if (isSubClass(Classes.WRAPPED_ARRAY, tag)) {
+ // Wrapped array is a bit special because we need to use an array of the element type.
+ // Some parts of our codebase (unfortunately) rely on this for type inference on results.
+ new FieldDeserializer[mutable.WrappedArray[Any], ListVector](v) {
+ def value(i: Int): mutable.WrappedArray[Any] = {
+ val array = getArray(vector, i, deserializer)(element.clsTag)
+ ScalaCollectionUtils.wrap(array)
+ }
+ }
+ } else if (isSubClass(Classes.ITERABLE, tag)) {
+ val companion = ScalaCollectionUtils.getIterableCompanion(tag)
+ new FieldDeserializer[Iterable[Any], ListVector](v) {
+ def value(i: Int): Iterable[Any] = {
+ val builder = companion.newBuilder[Any]
+ loadListIntoBuilder(vector, i, deserializer, builder)
+ builder.result()
+ }
+ }
+ } else if (isSubClass(Classes.JLIST, tag)) {
+ val newInstance = resolveJavaListCreator(tag)
+ new FieldDeserializer[JList[Any], ListVector](v) {
+ def value(i: Int): JList[Any] = {
+ var index = v.getElementStartIndex(i)
+ val end = v.getElementEndIndex(i)
+ val list = newInstance(end - index)
+ while (index < end) {
+ list.add(deserializer.get(index))
+ index += 1
+ }
+ list
+ }
+ }
+ } else {
+ throw unsupportedCollectionType(tag.runtimeClass)
+ }
+
+ case (MapEncoder(tag, key, value, _), v: MapVector) =>
+ val structVector = v.getDataVector.asInstanceOf[StructVector]
+ val keyDeserializer = deserializerFor(key, structVector.getChild(MapVector.KEY_NAME))
+ val valueDeserializer =
+ deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME))
+ if (isSubClass(Classes.MAP, tag)) {
+ val companion = ScalaCollectionUtils.getMapCompanion(tag)
+ new FieldDeserializer[Map[Any, Any], MapVector](v) {
+ def value(i: Int): Map[Any, Any] = {
+ val builder = companion.newBuilder[Any, Any]
+ var index = v.getElementStartIndex(i)
+ val end = v.getElementEndIndex(i)
+ builder.sizeHint(end - index)
+ while (index < end) {
+ builder += (keyDeserializer.get(index) -> valueDeserializer.get(index))
+ index += 1
+ }
+ builder.result()
+ }
+ }
+ } else if (isSubClass(Classes.JMAP, tag)) {
+ val newInstance = resolveJavaMapCreator(tag)
+ new FieldDeserializer[JMap[Any, Any], MapVector](v) {
+ def value(i: Int): JMap[Any, Any] = {
+ val map = newInstance()
+ var index = v.getElementStartIndex(i)
+ val end = v.getElementEndIndex(i)
+ while (index < end) {
+ map.put(keyDeserializer.get(index), valueDeserializer.get(index))
+ index += 1
+ }
+ map
+ }
+ }
+ } else {
+ throw unsupportedCollectionType(tag.runtimeClass)
+ }
+
+ case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
+ // We should try to make this work with MethodHandles.
+ val Some(constructor) =
+ ScalaReflection.findConstructor(tag.runtimeClass, fields.map(_.enc.clsTag.runtimeClass))
+ val deserializers = if (isTuple(tag.runtimeClass)) {
+ fields.zip(vectors).map { case (field, vector) =>
+ deserializerFor(field.enc, vector)
+ }
+ } else {
+ val lookup = createFieldLookup(vectors)
+ fields.map { field =>
+ deserializerFor(field.enc, lookup(field.name))
+ }
+ }
+ new StructFieldSerializer[Any](struct) {
+ def value(i: Int): Any = {
+ constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef]))
+ }
+ }
+
+ case (r @ RowEncoder(fields), StructVectors(struct, vectors)) =>
+ val lookup = createFieldLookup(vectors)
+ val deserializers = fields.toArray.map { field =>
+ deserializerFor(field.enc, lookup(field.name))
+ }
+ new StructFieldSerializer[Any](struct) {
+ def value(i: Int): Any = {
+ val values = deserializers.map(_.get(i))
+ new GenericRowWithSchema(values, r.schema)
+ }
+ }
+
+ case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
+ val constructor =
+ methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit]))
+ val lookup = createFieldLookup(vectors)
+ val setters = fields.map { field =>
+ val vector = lookup(field.name)
+ val deserializer = deserializerFor(field.enc, vector)
+ val setter = methodLookup.findVirtual(
+ tag.runtimeClass,
+ field.writeMethod.get,
+ MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass))
+ (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i))
+ }
+ new StructFieldSerializer[Any](struct) {
+ def value(i: Int): Any = {
+ val instance = constructor.invoke()
+ setters.foreach(_(instance, i))
+ instance
+ }
+ }
+
+ case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
+ throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType)
+
+ case _ =>
+ throw new RuntimeException(
+ s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.")
+ }
+ }
+
+ private val methodLookup = MethodHandles.lookup()
+
+ /**
+ * Resolve the companion object for a Scala class. In our particular case the class we pass in
+ * is a Scala collection. We use the companion to create a builder for that collection.
+ */
+ private[arrow] def resolveCompanion[T](tag: ClassTag[_]): T = {
+ val mirror = scala.reflect.runtime.currentMirror
+ val module = mirror.classSymbol(tag.runtimeClass).companion.asModule
+ mirror.reflectModule(module).instance.asInstanceOf[T]
+ }
+
+ /**
+ * Create a function that creates a [[util.List]] instance. The int parameter of the creator
+ * function is a size hint.
+ *
+ * If the [[ClassTag]] `tag` points to an interface instead of a concrete class we try to use
+ * [[util.ArrayList]]. For concrete classes we try to use a constructor that takes a single
+ * [[Int]] argument; this is assumed to be a size hint. If no such constructor exists we
+ * fall back to a no-args constructor.
+ */
+ private def resolveJavaListCreator(tag: ClassTag[_]): Int => JList[Any] = {
+ val cls = tag.runtimeClass
+ val modifiers = cls.getModifiers
+ if (Modifier.isInterface(modifiers) || Modifier.isAbstract(modifiers)) {
+ // Abstract class or interface; we try to use ArrayList.
+ if (!cls.isAssignableFrom(classOf[util.ArrayList[_]])) {
+ unsupportedCollectionType(cls)
+ }
+ (size: Int) => new util.ArrayList[Any](size)
+ } else {
+ try {
+ // Try to use a constructor that (hopefully) takes a size argument.
+ val ctor = methodLookup.findConstructor(
+ tag.runtimeClass,
+ MethodType.methodType(classOf[Unit], Integer.TYPE))
+ size => ctor.invoke(size).asInstanceOf[JList[Any]]
+ } catch {
+ case _: java.lang.NoSuchMethodException =>
+ // Use a no-args constructor.
+ val ctor =
+ methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit]))
+ _ => ctor.invoke().asInstanceOf[JList[Any]]
+ }
+ }
+ }
+
+ /**
+ * Create a function that creates a [[util.Map]] instance.
+ *
+ * If the [[ClassTag]] `tag` points to an interface instead of a concrete class we try to use
+ * [[util.HashMap]]. For concrete classes we try to use a no-args constructor.
+ */
+ private def resolveJavaMapCreator(tag: ClassTag[_]): () => JMap[Any, Any] = {
+ val cls = tag.runtimeClass
+ val modifiers = cls.getModifiers
+ if (Modifier.isInterface(modifiers) || Modifier.isAbstract(modifiers)) {
+ // Abstract class or interface; we try to use HashMap.
+ if (!cls.isAssignableFrom(classOf[java.util.HashMap[_, _]])) {
+ unsupportedCollectionType(cls)
+ }
+ () => new util.HashMap[Any, Any]()
+ } else {
+ // Use a no-args constructor.
+ val ctor =
+ methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit]))
+ () => ctor.invoke().asInstanceOf[JMap[Any, Any]]
+ }
+ }
+
+ /**
+ * Create a function that can look up a [[FieldVector vector]] in `fields` by name. The lookup
+ * is case insensitive. If the schema contains fields whose names are duplicates under
+ * case-insensitive resolution, an exception is thrown. The returned function throws an
+ * exception when no column can be found for a given name.
+ *
+ * A small note on the binding process in general: over-complete schemas are currently allowed,
+ * meaning the data can have more columns than the encoder expects. In that case the extra
+ * (unbound) columns are ignored.
+ */
+ private def createFieldLookup(fields: Seq[FieldVector]): String => FieldVector = {
+ def toKey(k: String): String = k.toLowerCase(Locale.ROOT)
+ val lookup = mutable.Map.empty[String, FieldVector]
+ fields.foreach { field =>
+ val key = toKey(field.getName)
+ val old = lookup.put(key, field)
+ if (old.isDefined) {
+ throw QueryCompilationErrors.ambiguousColumnOrFieldError(
+ field.getName :: Nil,
+ fields.count(f => toKey(f.getName) == key))
+ }
+ }
+ name => {
+ lookup.getOrElse(toKey(name), throw QueryCompilationErrors.columnNotFoundError(name))
+ }
+ }
+
+ private def isTuple(cls: Class[_]): Boolean = cls.getName.startsWith("scala.Tuple")
+
+ private def getString(v: VarCharVector, i: Int): String = {
+ // This is currently a bit heavy on allocations:
+ // - byte array created in VarCharVector.get
+ // - CharBuffer created by the CharsetEncoder
+ // - char array in String
+ // By using direct buffers and reusing the char buffer
+ // we could get rid of the first two allocations.
+ Text.decode(v.get(i))
+ }
+
+ private def loadListIntoBuilder(
+ v: ListVector,
+ i: Int,
+ deserializer: Deserializer[Any],
+ builder: mutable.Builder[Any, _]): Unit = {
+ var index = v.getElementStartIndex(i)
+ val end = v.getElementEndIndex(i)
+ builder.sizeHint(end - index)
+ while (index < end) {
+ builder += deserializer.get(index)
+ index += 1
+ }
+ }
+
+ private def getArray(v: ListVector, i: Int, deserializer: Deserializer[Any])(implicit
+ tag: ClassTag[Any]): AnyRef = {
+ val builder = mutable.ArrayBuilder.make[Any]
+ loadListIntoBuilder(v, i, deserializer, builder)
+ builder.result()
+ }
+
+ abstract class Deserializer[+E] {
+ def get(i: Int): E
+ }
+
+ abstract class FieldDeserializer[E, V <: FieldVector](val vector: V) extends Deserializer[E] {
+ def value(i: Int): E
+ def isNull(i: Int): Boolean = vector.isNull(i)
+ override def get(i: Int): E = {
+ if (!isNull(i)) {
+ value(i)
+ } else {
+ null.asInstanceOf[E]
+ }
+ }
+ }
+
+ abstract class StructFieldSerializer[E](v: StructVector)
+ extends FieldDeserializer[E, StructVector](v) {
+ override def isNull(i: Int): Boolean = vector != null && vector.isNull(i)
+ }
+}
+
+class EmptyDeserializingIterator[E](val encoder: AgnosticEncoder[E])
+ extends CloseableIterator[E] {
+ override def close(): Unit = ()
+ override def hasNext: Boolean = false
+ override def next(): E = throw new NoSuchElementException()
+}
+
+class ArrowDeserializingIterator[E](
+ val encoder: AgnosticEncoder[E],
+ private[this] val reader: ArrowReader)
+ extends CloseableIterator[E] {
+ private[this] var index = 0
+ private[this] val root = reader.getVectorSchemaRoot
+ private[this] val deserializer = ArrowDeserializers.deserializerFor(encoder, root)
+
+ override def hasNext: Boolean = {
+ if (index >= root.getRowCount) {
+ if (reader.loadNextBatch()) {
+ index = 0
+ }
+ }
+ index < root.getRowCount
+ }
+
+ override def next(): E = {
+ if (!hasNext) {
+ throw new NoSuchElementException()
+ }
+ val result = deserializer.get(index)
+ index += 1
+ result
+ }
+
+ override def close(): Unit = reader.close()
+}
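deserializeFromArrow is the public entry point of this file: it wraps one or more Arrow IPC
streams in a ConcatenatingArrowStreamReader and yields decoded objects through
ArrowDeserializingIterator (or an EmptyDeserializingIterator when there is no input). A minimal
sketch (not part of the patch; PrimitiveLongEncoder stands in for any AgnosticEncoder, and the
empty input mirrors the "deserializing empty iterator" test below):

  import org.apache.arrow.memory.RootAllocator
  import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder

  val allocator = new RootAllocator()
  try {
    val batches: Iterator[Array[Byte]] = Iterator.empty // normally complete Arrow IPC streams
    val rows = ArrowDeserializers.deserializeFromArrow(batches, PrimitiveLongEncoder, allocator)
    try {
      rows.foreach(println) // CloseableIterator extends Iterator, so normal iteration works
    } finally {
      rows.close()
    }
  } finally {
    allocator.close()
  }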
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
index f6b140bae557b..ed27336985416 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
@@ -24,8 +24,11 @@ import org.apache.arrow.vector.complex.StructVector
private[arrow] object ArrowEncoderUtils {
object Classes {
+ val WRAPPED_ARRAY: Class[_] = classOf[scala.collection.mutable.WrappedArray[_]]
val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]]
+ val MAP: Class[_] = classOf[scala.collection.Map[_, _]]
val JLIST: Class[_] = classOf[java.util.List[_]]
+ val JMAP: Class[_] = classOf[java.util.Map[_, _]]
}
def isSubClass(cls: Class[_], tag: ClassTag[_]): Boolean = {
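For reference, a tiny sketch (not part of the patch) of how these class constants drive the
ClassTag-based dispatch in ArrowDeserializers above; isSubClass checks whether the tag's runtime
class is assignable to the given class:

  import scala.reflect.classTag

  // Inside the arrow package (ArrowEncoderUtils is private[arrow]):
  ArrowEncoderUtils.isSubClass(ArrowEncoderUtils.Classes.MAP, classTag[Map[String, Int]])  // expected: true
  ArrowEncoderUtils.isSubClass(ArrowEncoderUtils.Classes.JMAP, classTag[Map[String, Int]]) // expected: false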
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
new file mode 100644
index 0000000000000..90963c831c252
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.arrow
+
+import java.io.{InputStream, IOException}
+import java.nio.channels.Channels
+
+import org.apache.arrow.flatbuf.MessageHeader
+import org.apache.arrow.memory.{ArrowBuf, BufferAllocator}
+import org.apache.arrow.vector.ipc.{ArrowReader, ReadChannel}
+import org.apache.arrow.vector.ipc.message.{ArrowDictionaryBatch, ArrowMessage, ArrowRecordBatch, MessageChannelReader, MessageResult, MessageSerializer}
+import org.apache.arrow.vector.types.pojo.Schema
+
+/**
+ * An [[ArrowReader]] that concatenates multiple [[MessageIterator]]s into a single stream. Each
+ * iterator represents a single IPC stream. The concatenated streams must all have the same
+ * schema; if the schemas differ an exception is thrown.
+ *
+ * In some cases we want to retain the messages (see `SparkResult`). Normally a stream reader
+ * closes its messages when it consumes them. In order to prevent that from happening in
+ * non-destructive mode we clone the messages before passing them to the reading logic.
+ */
+class ConcatenatingArrowStreamReader(
+ allocator: BufferAllocator,
+ input: Iterator[AbstractMessageIterator],
+ destructive: Boolean)
+ extends ArrowReader(allocator) {
+
+ private[this] var totalBytesRead: Long = 0
+ private[this] var current: AbstractMessageIterator = _
+
+ override protected def readSchema(): Schema = {
+ // readSchema() should only be called once during initialization.
+ assert(current == null)
+ if (!input.hasNext) {
+ // ArrowStreamReader throws the same exception.
+ throw new IOException("Unexpected end of input. Missing schema.")
+ }
+ current = input.next()
+ current.schema
+ }
+
+ private def nextMessage(): ArrowMessage = {
+ // readSchema() should have been invoked at this point so 'current' should be initialized.
+ assert(current != null)
+ // Try to find a non-empty message iterator.
+ while (!current.hasNext && input.hasNext) {
+ totalBytesRead += current.bytesRead
+ current = input.next()
+ if (current.schema != getVectorSchemaRoot.getSchema) {
+ throw new IllegalStateException()
+ }
+ }
+ if (current.hasNext) {
+ current.next()
+ } else {
+ null
+ }
+ }
+
+ override def loadNextBatch(): Boolean = {
+ // Keep looping until we load a non-empty batch or until we exhaust the input.
+ var message = nextMessage()
+ while (message != null) {
+ message match {
+ case rb: ArrowRecordBatch =>
+ loadRecordBatch(cloneIfNonDestructive(rb))
+ if (getVectorSchemaRoot.getRowCount > 0) {
+ return true
+ }
+ case db: ArrowDictionaryBatch =>
+ loadDictionary(cloneIfNonDestructive(db))
+ }
+ message = nextMessage()
+ }
+ false
+ }
+
+ private def cloneIfNonDestructive(batch: ArrowRecordBatch): ArrowRecordBatch = {
+ if (destructive) {
+ return batch
+ }
+ cloneRecordBatch(batch)
+ }
+
+ private def cloneIfNonDestructive(batch: ArrowDictionaryBatch): ArrowDictionaryBatch = {
+ if (destructive) {
+ return batch
+ }
+ new ArrowDictionaryBatch(
+ batch.getDictionaryId,
+ cloneRecordBatch(batch.getDictionary),
+ batch.isDelta)
+ }
+
+ private def cloneRecordBatch(batch: ArrowRecordBatch): ArrowRecordBatch = {
+ new ArrowRecordBatch(
+ batch.getLength,
+ batch.getNodes,
+ batch.getBuffers,
+ batch.getBodyCompression,
+ true,
+ true)
+ }
+
+ override def bytesRead(): Long = {
+ if (current != null) {
+ totalBytesRead + current.bytesRead
+ } else {
+ 0
+ }
+ }
+
+ override def closeReadSource(): Unit = ()
+}
+
+trait AbstractMessageIterator extends Iterator[ArrowMessage] {
+ def schema: Schema
+ def bytesRead: Long
+}
+
+/**
+ * Decode an Arrow IPC stream into individual messages. Please note that this iterator MUST have a
+ * valid IPC stream as its input, otherwise construction will fail.
+ */
+class MessageIterator(input: InputStream, allocator: BufferAllocator)
+ extends AbstractMessageIterator {
+ private[this] val in = new ReadChannel(Channels.newChannel(input))
+ private[this] val reader = new MessageChannelReader(in, allocator)
+ private[this] var result: MessageResult = _
+
+ // Eagerly read the schema.
+ val schema: Schema = {
+ val result = reader.readNext()
+ if (result == null) {
+ throw new IOException("Unexpected end of input. Missing schema.")
+ }
+ MessageSerializer.deserializeSchema(result.getMessage)
+ }
+
+ override def bytesRead: Long = reader.bytesRead()
+
+ override def hasNext: Boolean = {
+ if (result == null) {
+ result = reader.readNext()
+ }
+ result != null
+ }
+
+ override def next(): ArrowMessage = {
+ if (!hasNext) {
+ throw new NoSuchElementException()
+ }
+ val message = result.getMessage.headerType() match {
+ case MessageHeader.RecordBatch =>
+ MessageSerializer.deserializeRecordBatch(result.getMessage, bodyBuffer(result))
+ case MessageHeader.DictionaryBatch =>
+ MessageSerializer.deserializeDictionaryBatch(result.getMessage, bodyBuffer(result))
+ }
+ result = null
+ message
+ }
+
+ private def bodyBuffer(result: MessageResult): ArrowBuf = {
+ var buffer = result.getBodyBuffer
+ if (buffer == null) {
+ buffer = allocator.getEmpty
+ }
+ buffer
+ }
+}
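The pieces compose as follows: each IPC payload becomes a MessageIterator, and
ConcatenatingArrowStreamReader presents them as a single ArrowReader that loads batches into a
shared VectorSchemaRoot. A rough sketch mirroring deserializeFromArrow above (not part of the
patch; `payloads` is a placeholder for complete Arrow IPC streams):

  import java.io.ByteArrayInputStream
  import org.apache.arrow.memory.RootAllocator

  val allocator = new RootAllocator()
  val payloads: Iterator[Array[Byte]] = ??? // one complete IPC stream per element
  val reader = new ConcatenatingArrowStreamReader(
    allocator,
    payloads.map(bytes => new MessageIterator(new ByteArrayInputStream(bytes), allocator)),
    destructive = true)
  try {
    val root = reader.getVectorSchemaRoot // triggers readSchema() on the first iterator
    while (reader.loadNextBatch()) {
      println(s"rows in this batch: ${root.getRowCount}")
    }
  } finally {
    reader.close()
    allocator.close()
  }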
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 73c04389c0597..07dd2a96bd8f7 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -39,7 +39,6 @@ import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.ColumnarBatch
class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
@@ -571,7 +570,8 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
(col("id") / lit(10.0d)).as("b"),
col("id"),
lit("world").as("d"),
- (col("id") % 2).cast("int").as("a"))
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ (col("id") % 2).cast("double").as("a"))
private def validateMyTypeResult(result: Array[MyType]): Unit = {
result.zipWithIndex.foreach { case (MyType(id, a, b), i) =>
@@ -818,10 +818,11 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
}
test("toJSON") {
+ // TODO SPARK-44449 make this int again when upcasting is in.
val expected = Array(
- """{"b":0.0,"id":0,"d":"world","a":0}""",
- """{"b":0.1,"id":1,"d":"world","a":1}""",
- """{"b":0.2,"id":2,"d":"world","a":0}""")
+ """{"b":0.0,"id":0,"d":"world","a":0.0}""",
+ """{"b":0.1,"id":1,"d":"world","a":1.0}""",
+ """{"b":0.2,"id":2,"d":"world","a":0.0}""")
val result = spark
.range(3)
.select(generateMyTypeColumns: _*)
@@ -893,14 +894,12 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
test("Dataset result destructive iterator") {
- // Helper methods for accessing private field `idxToBatches` from SparkResult
+ // Helper methods for accessing private field `resultMap` from SparkResult
- val _idxToBatches =
- PrivateMethod[mutable.Map[Int, ColumnarBatch]](Symbol("idxToBatches"))
+ val getResultMap =
+ PrivateMethod[mutable.Map[Int, Any]](Symbol("resultMap"))
- def getColumnarBatches(result: SparkResult[_]): Seq[ColumnarBatch] = {
- val idxToBatches = result invokePrivate _idxToBatches()
-
- // Sort by key to get stable results.
- idxToBatches.toSeq.sortBy(_._1).map(_._2)
+ def assertResultsMapEmpty(result: SparkResult[_]): Unit = {
+ val resultMap = result invokePrivate getResultMap()
+ assert(resultMap.isEmpty)
}
val df = spark
@@ -911,25 +910,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
try {
// build and verify the destructive iterator
val iterator = result.destructiveIterator
- // batches is empty before traversing the result iterator
- assert(getColumnarBatches(result).isEmpty)
- var previousBatch: ColumnarBatch = null
- val buffer = mutable.Buffer.empty[Long]
+ // resultMap is empty before traversing the result iterator
+ assertResultsMapEmpty(result)
+ val buffer = mutable.Set.empty[Long]
while (iterator.hasNext) {
- // always having 1 batch, since a columnar batch will be removed and closed after
- // its data got consumed.
- val batches = getColumnarBatches(result)
- assert(batches.size === 1)
- assert(batches.head != previousBatch)
- previousBatch = batches.head
-
- buffer.append(iterator.next())
+ // resultMap is empty during iteration because results get removed immediately on access.
+ assertResultsMapEmpty(result)
+ buffer += iterator.next()
}
- // Batches should be closed and removed after traversing all the records.
- assert(getColumnarBatches(result).isEmpty)
+ // resultMap is empty afterward because all results have been removed.
+ assertResultsMapEmpty(result)
- val expectedResult = Seq(6L, 7L, 8L)
- assert(buffer.size === 3 && expectedResult.forall(buffer.contains))
+ val expectedResult = Set(6L, 7L, 8L)
+ assert(buffer.size === 3 && expectedResult == buffer)
} finally {
result.close()
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index e15069f2d9e96..ab3e13da53178 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -68,10 +68,11 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("keyAs - keys") {
+ // TODO SPARK-44449 make this long again when upcasting is in.
// It is okay to cast from Long to Double, but not Long to Int.
val values = spark
.range(10)
- .groupByKey(v => v % 2)
+ .groupByKey(v => (v % 2).toDouble)
.keyAs[Double]
.keys
.collectAsList()
@@ -232,9 +233,10 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("agg, keyAs") {
+ // TODO SPARK-44449 make this long again when upcasting is in.
val ds = spark
.range(10)
- .groupByKey(v => v % 2)
+ .groupByKey(v => (v % 2).toDouble)
.keyAs[Double]
.agg(count("*"))
@@ -244,7 +246,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
test("typed aggregation: expr") {
val session: SparkSession = spark
import session.implicits._
- val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS()
checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long]),
@@ -254,7 +257,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("typed aggregation: expr, expr") {
- val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS()
checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
@@ -264,7 +268,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("typed aggregation: expr, expr, expr") {
- val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS()
checkDatasetUnorderly(
ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")),
@@ -274,7 +279,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("typed aggregation: expr, expr, expr, expr") {
- val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS()
checkDatasetUnorderly(
ds.groupByKey(_._1)
@@ -289,7 +295,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("typed aggregation: expr, expr, expr, expr, expr") {
- val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS()
checkDatasetUnorderly(
ds.groupByKey(_._1)
@@ -305,7 +312,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("typed aggregation: expr, expr, expr, expr, expr, expr") {
- val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS()
checkDatasetUnorderly(
ds.groupByKey(_._1)
@@ -322,7 +330,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") {
- val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS()
checkDatasetUnorderly(
ds.groupByKey(_._1)
@@ -340,7 +349,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
}
test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") {
- val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ // TODO SPARK-44449 make this int again when upcasting is in.
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS()
checkDatasetUnorderly(
ds.groupByKey(_._1)
@@ -473,9 +483,9 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1))
.toDF("key", "seq", "value")
val grouped = ds.groupBy($"value").as[String, (String, Int, Int)]
- val keys = grouped.keyAs[String].keys.sort($"value")
-
- checkDataset(keys, "1", "2", "10", "20")
+ // TODO SPARK-44449 make this string again when upcasting is in.
+ val keys = grouped.keyAs[Int].keys.sort($"value")
+ checkDataset(keys, 1, 2, 10, 20)
}
test("flatMapGroupsWithState") {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 58758a1384031..800ce43a60df0 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -208,8 +208,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
}
test("UDF Registration") {
+ // TODO SPARK-44449 make this long again when upcasting is in.
val input = """
- |class A(x: Int) { def get = x * 100 }
+ |class A(x: Int) { def get: Long = x * 100 }
|val myUdf = udf((x: Int) => new A(x).get)
|spark.udf.register("dummyUdf", myUdf)
|spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
@@ -219,8 +220,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
}
test("UDF closure registration") {
+ // TODO SPARK-44449 make this int again when upcasting is in.
val input = """
- |class A(x: Int) { def get = x * 15 }
+ |class A(x: Int) { def get: Long = x * 15 }
|spark.udf.register("directUdf", (x: Int) => new A(x).get)
|spark.sql("select directUdf(id) from range(5)").as[Long].collect()
""".stripMargin
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 0c327484e477d..16eec3eee3110 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -21,24 +21,19 @@ import java.util
import java.util.{Collections, Objects}
import scala.beans.BeanProperty
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.classTag
-import scala.util.control.NonFatal
-import com.google.protobuf.ByteString
import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
import org.apache.arrow.vector.VarBinaryVector
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkUnsupportedOperationException
-import org.apache.spark.connect.proto
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, CalendarIntervalEncoder, DateEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, RowEncoder, StringEncoder, TimestampEncoder, UDTEncoder}
import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder}
-import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
import org.apache.spark.sql.connect.client.util.ConnectFunSuite
import org.apache.spark.sql.types.{ArrayType, DataType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType}
@@ -96,15 +91,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
}
val resultIterator =
- try {
- deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator)
- } catch {
- case NonFatal(e) =>
- arrowIterator.close()
- serializerAllocator.close()
- deserializerAllocator.close()
- throw e
- }
+ ArrowDeserializers.deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator)
new CloseableIterator[T] {
override def close(): Unit = {
arrowIterator.close()
@@ -117,25 +104,6 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
}
}
- // Temporary hack until we merge the deserializer.
- private def deserializeFromArrow[E](
- batches: Iterator[Array[Byte]],
- encoder: AgnosticEncoder[E],
- allocator: BufferAllocator): CloseableIterator[E] = {
- val responses = batches.map { batch =>
- val builder = proto.ExecutePlanResponse.newBuilder()
- builder.getArrowBatchBuilder.setData(ByteString.copyFrom(batch))
- builder.build()
- }
- val result = new SparkResult[E](responses.asJava, allocator, encoder)
- new CloseableIterator[E] {
- private val itr = result.iterator
- override def close(): Unit = itr.close()
- override def hasNext: Boolean = itr.hasNext
- override def next(): E = itr.next()
- }
- }
-
private def roundTripAndCheck[T](
encoder: AgnosticEncoder[T],
toInputIterator: () => Iterator[Any],
@@ -246,6 +214,15 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
assert(inspector.sizeInBytes > 0)
}
+ test("deserializing empty iterator") {
+ withAllocator { allocator =>
+ val iterator =
+ ArrowDeserializers.deserializeFromArrow(Iterator.empty, singleIntEncoder, allocator)
+ assert(iterator.isEmpty)
+ assert(allocator.getAllocatedMemory == 0)
+ }
+ }
+
test("single batch") {
val inspector = new CountingBatchInspector
roundTripAndCheckIdentical(singleIntEncoder, inspectBatch = inspector) { () =>
@@ -533,15 +510,22 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
val maybeNull = MaybeNull(11)
Iterator.tabulate(100) { i =>
val bean = new JavaMapData
- bean.setDummyToDoubleListMap(maybeNull {
- val map = new util.HashMap[DummyBean, java.util.List[java.lang.Double]]
- (0 until (i % 5)).foreach { j =>
- val dummy = new DummyBean
- dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j)))
+ bean.setMetricMap(maybeNull {
+ val map = new util.HashMap[String, util.List[java.lang.Double]]
+ (0 until (i % 20)).foreach { i =>
val values = Array.tabulate(i % 40) { j =>
Double.box(j.toDouble)
}
- map.put(dummy, maybeNull(util.Arrays.asList(values: _*)))
+ map.put("k" + i, maybeNull(util.Arrays.asList(values: _*)))
+ }
+ map
+ })
+ bean.setDummyToStringMap(maybeNull {
+ val map = new util.HashMap[DummyBean, String]
+ (0 until (i % 5)).foreach { j =>
+ val dummy = new DummyBean
+ dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j)))
+ map.put(dummy, maybeNull("s" + i + "v" + j))
}
map
})
@@ -675,6 +659,57 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
.add("Ca", "array")
.add("Cb", "binary")))
+ test("bind to schema") {
+ // Binds to a wider schema. The narrow schema has fewer (nested) fields, has a slightly
+ // different field order, and uses differently cased names in a couple of places.
+ withAllocator { allocator =>
+ val input = Row(
+ 887,
+ "foo",
+ Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f),
+ Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false)))
+ val expected = Row(
+ "foo",
+ Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)),
+ Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte)))
+ val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator)
+ val result =
+ ArrowDeserializers.deserializeFromArrow(arrowBatches, narrowSchemaEncoder, allocator)
+ val actual = result.next()
+ assert(result.isEmpty)
+ assert(expected === actual)
+ result.close()
+ arrowBatches.close()
+ }
+ }
+
+ test("unknown field") {
+ withAllocator { allocator =>
+ val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator)
+ intercept[AnalysisException] {
+ ArrowDeserializers.deserializeFromArrow(arrowBatches, wideSchemaEncoder, allocator)
+ }
+ arrowBatches.close()
+ }
+ }
+
+ test("duplicate fields") {
+ val duplicateSchemaEncoder = toRowEncoder(
+ new StructType()
+ .add("foO", "string")
+ .add("Foo", "string"))
+ val fooSchemaEncoder = toRowEncoder(
+ new StructType()
+ .add("foo", "string"))
+ withAllocator { allocator =>
+ val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator)
+ intercept[AnalysisException] {
+ ArrowDeserializers.deserializeFromArrow(arrowBatches, fooSchemaEncoder, allocator)
+ }
+ arrowBatches.close()
+ }
+ }
+
/* ******************************************************************** *
* Arrow serialization/deserialization specific errors
* ******************************************************************** */
@@ -833,17 +868,23 @@ case class MapData(intStringMap: Map[Int, String], metricMap: Map[String, Array[
class JavaMapData {
@scala.beans.BeanProperty
- var dummyToDoubleListMap: java.util.Map[DummyBean, java.util.List[java.lang.Double]] = _
+ var dummyToStringMap: java.util.Map[DummyBean, String] = _
+
+ @scala.beans.BeanProperty
+ var metricMap: java.util.HashMap[String, java.util.List[java.lang.Double]] = _
def canEqual(other: Any): Boolean = other.isInstanceOf[JavaMapData]
override def equals(other: Any): Boolean = other match {
case that: JavaMapData if that canEqual this =>
- dummyToDoubleListMap == that.dummyToDoubleListMap
+ dummyToStringMap == that.dummyToStringMap &&
+ metricMap == that.metricMap
case _ => false
}
- override def hashCode(): Int = Objects.hashCode(dummyToDoubleListMap)
+ override def hashCode(): Int = {
+ java.util.Arrays.deepHashCode(Array(dummyToStringMap, metricMap))
+ }
}
class DummyBean {