From 5939b75b5fe701cb63fedc64f57c9f0a15ef9202 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 19 Jul 2023 09:26:26 -0400 Subject: [PATCH] [SPARK-44396][CONNECT] Direct Arrow Deserialization ### What changes were proposed in this pull request? This PR adds direct Arrow-to-user-object deserialization to the Spark Connect Scala Client. ### Why are the changes needed? We want to decouple the Scala client from Catalyst. We need a way to encode user objects to and from Arrow. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests to `ArrowEncoderSuite`. Closes #42011 from hvanhovell/SPARK-44396. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- connector/connect/client/jvm/pom.xml | 19 + .../client/arrow/ScalaCollectionUtils.scala | 38 ++ .../client/arrow/ScalaCollectionUtils.scala | 37 ++ .../sql/connect/client/SparkResult.scala | 230 +++++--- .../client/arrow/ArrowDeserializer.scala | 533 ++++++++++++++++++ .../client/arrow/ArrowEncoderUtils.scala | 3 + .../ConcatenatingArrowStreamReader.scala | 185 ++++++ .../apache/spark/sql/ClientE2ETestSuite.scala | 49 +- .../KeyValueGroupedDatasetE2ETestSuite.scala | 36 +- .../spark/sql/application/ReplE2ESuite.scala | 6 +- .../client/arrow/ArrowEncoderSuite.scala | 127 +++-- 11 files changed, 1085 insertions(+), 178 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala create mode 100644 connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index 93cc782ab1354..60ed0f3ba46e4 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -140,6 +140,7 @@ + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> @@ -224,6 +225,24 @@ + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <executions> + <execution> + <id>add-sources</id> + <phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>src/main/scala-${scala.binary.version}</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> \ No newline at end of file diff --git a/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala b/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala new file mode 100644 index 0000000000000..c2e01d974e0e4 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala-2.12/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client.arrow + +import scala.collection.generic.{GenericCompanion, GenMapFactory} +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion + +/** + * A couple of scala version specific collection utility functions. + */ +private[arrow] object ScalaCollectionUtils { + def getIterableCompanion(tag: ClassTag[_]): GenericCompanion[Iterable] = { + ArrowDeserializers.resolveCompanion[GenericCompanion[Iterable]](tag) + } + def getMapCompanion(tag: ClassTag[_]): GenMapFactory[Map] = { + resolveCompanion[GenMapFactory[Map]](tag) + } + def wrap[T](array: AnyRef): mutable.WrappedArray[T] = { + mutable.WrappedArray.make(array) + } +} diff --git a/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala b/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala new file mode 100644 index 0000000000000..8a80e34162283 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala-2.13/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client.arrow + +import scala.collection.{mutable, IterableFactory, MapFactory} +import scala.reflect.ClassTag + +import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion + +/** + * A couple of scala version specific collection utility functions. 
+ */ +private[arrow] object ScalaCollectionUtils { + def getIterableCompanion(tag: ClassTag[_]): IterableFactory[Iterable] = { + ArrowDeserializers.resolveCompanion[IterableFactory[Iterable]](tag) + } + def getMapCompanion(tag: ClassTag[_]): MapFactory[Map] = { + resolveCompanion[MapFactory[Map]](tag) + } + def wrap[T](array: AnyRef): mutable.WrappedArray[T] = { + mutable.WrappedArray.make(array.asInstanceOf[Array[T]]) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index a727c86f70fc6..1cdc2035de60b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -16,53 +16,48 @@ */ package org.apache.spark.sql.connect.client -import java.util.Collections +import java.util.Objects -import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.arrow.memory.BufferAllocator -import org.apache.arrow.vector.FieldVector -import org.apache.arrow.vector.ipc.ArrowStreamReader +import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch} +import org.apache.arrow.vector.types.pojo import org.apache.spark.connect.proto -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer -import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable} +import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator} +import org.apache.spark.sql.connect.client.util.Cleanable import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} private[sql] class SparkResult[T]( responses: java.util.Iterator[proto.ExecutePlanResponse], allocator: BufferAllocator, encoder: AgnosticEncoder[T]) extends AutoCloseable - with Cleanable { + with Cleanable { self => private[this] var numRecords: Int = 0 private[this] var structType: StructType = _ - private[this] var boundEncoder: ExpressionEncoder[T] = _ - private[this] var nextBatchIndex: Int = 0 - private val idxToBatches = mutable.Map.empty[Int, ColumnarBatch] - - private def createEncoder(schema: StructType): ExpressionEncoder[T] = { - val agnosticEncoder = createEncoder(encoder, schema).asInstanceOf[AgnosticEncoder[T]] - ExpressionEncoder(agnosticEncoder) - } + private[this] var arrowSchema: pojo.Schema = _ + private[this] var nextResultIndex: Int = 0 + private val resultMap = mutable.Map.empty[Int, (Long, Seq[ArrowMessage])] /** * Update RowEncoder and recursively update the fields of the ProductEncoder if found. 
*/ - private def createEncoder(enc: AgnosticEncoder[_], dataType: DataType): AgnosticEncoder[_] = { + private def createEncoder[E]( + enc: AgnosticEncoder[E], + dataType: DataType): AgnosticEncoder[E] = { enc match { case UnboundRowEncoder => // Replace the row encoder with the encoder inferred from the schema. - RowEncoder.encoderFor(dataType.asInstanceOf[StructType]) + RowEncoder + .encoderFor(dataType.asInstanceOf[StructType]) + .asInstanceOf[AgnosticEncoder[E]] case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) => // Recursively continue updating the tuple product encoder val schema = dataType.asInstanceOf[StructType] @@ -76,53 +71,61 @@ private[sql] class SparkResult[T]( } } - private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = { - while (responses.hasNext) { + private def processResponses( + stopOnSchema: Boolean = false, + stopOnArrowSchema: Boolean = false, + stopOnFirstNonEmptyResponse: Boolean = false): Boolean = { + var nonEmpty = false + var stop = false + while (!stop && responses.hasNext) { val response = responses.next() if (response.hasSchema) { // The original schema should arrive before ArrowBatches. structType = DataTypeProtoConverter.toCatalystType(response.getSchema).asInstanceOf[StructType] - } else if (response.hasArrowBatch) { + stop |= stopOnSchema + } + if (response.hasArrowBatch) { val ipcStreamBytes = response.getArrowBatch.getData - val reader = new ArrowStreamReader(ipcStreamBytes.newInput(), allocator) - try { - val root = reader.getVectorSchemaRoot - if (structType == null) { - // If the schema is not available yet, fallback to the schema from Arrow. - structType = ArrowUtils.fromArrowSchema(root.getSchema) - } - // TODO: create encoders that directly operate on arrow vectors. - if (boundEncoder == null) { - boundEncoder = createEncoder(structType) - .resolveAndBind(DataTypeUtils.toAttributes(structType)) - } - while (reader.loadNextBatch()) { - val rowCount = root.getRowCount - if (rowCount > 0) { - val vectors = root.getFieldVectors.asScala - .map(v => new ArrowColumnVector(transferToNewVector(v))) - .toArray[ColumnVector] - idxToBatches.put(nextBatchIndex, new ColumnarBatch(vectors, rowCount)) - nextBatchIndex += 1 - numRecords += rowCount - if (stopOnFirstNonEmptyResponse) { - return true - } - } + val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator) + if (arrowSchema == null) { + arrowSchema = reader.schema + stop |= stopOnArrowSchema + } else if (arrowSchema != reader.schema) { + throw new IllegalStateException( + s"""Schema Mismatch between expected and received schema: + |=== Expected Schema === + |$arrowSchema + |=== Received Schema === + |${reader.schema} + |""".stripMargin) + } + if (structType == null) { + // If the schema is not available yet, fallback to the arrow schema. + structType = ArrowUtils.fromArrowSchema(reader.schema) + } + var numRecordsInBatch = 0 + val messages = Seq.newBuilder[ArrowMessage] + while (reader.hasNext) { + val message = reader.next() + message match { + case batch: ArrowRecordBatch => + numRecordsInBatch += batch.getLength + case _ => } - } finally { - reader.close() + messages += message + } + // Skip the entire result if it is empty. 
+ if (numRecordsInBatch > 0) { + numRecords += numRecordsInBatch + resultMap.put(nextResultIndex, (reader.bytesRead, messages.result())) + nextResultIndex += 1 + nonEmpty |= true + stop |= stopOnFirstNonEmptyResponse } } } - false - } - - private def transferToNewVector(in: FieldVector): FieldVector = { - val pair = in.getTransferPair(allocator) - pair.transfer() - pair.getTo.asInstanceOf[FieldVector] + nonEmpty } /** @@ -130,7 +133,7 @@ private[sql] class SparkResult[T]( */ def length: Int = { // We need to process all responses to make sure numRecords is correct. - processResponses(stopOnFirstNonEmptyResponse = false) + processResponses() numRecords } @@ -139,7 +142,9 @@ private[sql] class SparkResult[T]( * the schema of the result. */ def schema: StructType = { - processResponses(stopOnFirstNonEmptyResponse = true) + if (structType == null) { + processResponses(stopOnSchema = true) + } structType } @@ -172,52 +177,93 @@ private[sql] class SparkResult[T]( private def buildIterator(destructive: Boolean): java.util.Iterator[T] with AutoCloseable = { new java.util.Iterator[T] with AutoCloseable { - private[this] var batchIndex: Int = -1 - private[this] var iterator: java.util.Iterator[InternalRow] = Collections.emptyIterator() - private[this] var deserializer: Deserializer[T] = _ + private[this] var iterator: CloseableIterator[T] = _ - override def hasNext: Boolean = { - if (iterator.hasNext) { - return true - } - - val nextBatchIndex = batchIndex + 1 - if (destructive) { - idxToBatches.remove(batchIndex).foreach(_.close()) + private def initialize(): Unit = { + if (iterator == null) { + iterator = new ArrowDeserializingIterator( + createEncoder(encoder, schema), + new ConcatenatingArrowStreamReader( + allocator, + Iterator.single(new ResultMessageIterator(destructive)), + destructive)) } + } - val hasNextBatch = if (!idxToBatches.contains(nextBatchIndex)) { - processResponses(stopOnFirstNonEmptyResponse = true) - } else { - true - } - if (hasNextBatch) { - batchIndex = nextBatchIndex - iterator = idxToBatches(nextBatchIndex).rowIterator() - if (deserializer == null) { - deserializer = boundEncoder.createDeserializer() - } - } - hasNextBatch + override def hasNext: Boolean = { + initialize() + iterator.hasNext } override def next(): T = { - if (!hasNext) { - throw new NoSuchElementException - } - deserializer(iterator.next()) + initialize() + iterator.next() } - override def close(): Unit = SparkResult.this.close() + override def close(): Unit = { + if (iterator != null) { + iterator.close() + } + } } } /** * Close this result, freeing any underlying resources. */ - override def close(): Unit = { - idxToBatches.values.foreach(_.close()) + override def close(): Unit = cleaner.close() + + override val cleaner: AutoCloseable = new SparkResultCloseable(resultMap) + + private class ResultMessageIterator(destructive: Boolean) extends AbstractMessageIterator { + private[this] var totalBytesRead = 0L + private[this] var nextResultIndex = 0 + private[this] var current: Iterator[ArrowMessage] = Iterator.empty + + override def bytesRead: Long = totalBytesRead + + override def schema: pojo.Schema = { + if (arrowSchema == null) { + // We need a schema to proceed. Spark Connect will always + // return a result (with a schema) even if the result is empty. 
+ processResponses(stopOnArrowSchema = true) + Objects.requireNonNull(arrowSchema) + } + arrowSchema + } + + override def hasNext: Boolean = { + if (current.hasNext) { + return true + } + val hasNextResult = if (!resultMap.contains(nextResultIndex)) { + self.processResponses(stopOnFirstNonEmptyResponse = true) + } else { + true + } + if (hasNextResult) { + val Some((sizeInBytes, messages)) = if (destructive) { + resultMap.remove(nextResultIndex) + } else { + resultMap.get(nextResultIndex) + } + totalBytesRead += sizeInBytes + current = messages.iterator + nextResultIndex += 1 + } + hasNextResult + } + + override def next(): ArrowMessage = { + if (!hasNext) { + throw new NoSuchElementException() + } + current.next() + } } +} - override def cleaner: AutoCloseable = AutoCloseables(idxToBatches.values.toSeq) +private[client] class SparkResultCloseable(resultMap: mutable.Map[Int, (Long, Seq[ArrowMessage])]) + extends AutoCloseable { + override def close(): Unit = resultMap.values.foreach(_._2.foreach(_.close())) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala new file mode 100644 index 0000000000000..154866d699a34 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.client.arrow + +import java.io.{ByteArrayInputStream, IOException} +import java.lang.invoke.{MethodHandles, MethodType} +import java.lang.reflect.Modifier +import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger} +import java.time._ +import java.util +import java.util.{List => JList, Locale, Map => JMap} + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector, VectorSchemaRoot} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.util.Text + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.types.Decimal + +/** + * Helper class for converting arrow batches into user objects. + */ +object ArrowDeserializers { + import ArrowEncoderUtils._ + + /** + * Create an Iterator of `T`. This iterator takes an Iterator of Arrow IPC Streams, and + * deserializes these streams into one or more instances of `T` + */ + def deserializeFromArrow[T]( + input: Iterator[Array[Byte]], + encoder: AgnosticEncoder[T], + allocator: BufferAllocator): CloseableIterator[T] = { + try { + val reader = new ConcatenatingArrowStreamReader( + allocator, + input.map(bytes => new MessageIterator(new ByteArrayInputStream(bytes), allocator)), + destructive = true) + new ArrowDeserializingIterator(encoder, reader) + } catch { + case _: IOException => + new EmptyDeserializingIterator(encoder) + } + } + + /** + * Create a deserializer of `T` on top of the given `root`. + */ + private[arrow] def deserializerFor[T]( + encoder: AgnosticEncoder[T], + root: VectorSchemaRoot): Deserializer[T] = { + val data: AnyRef = if (encoder.isStruct) { + root + } else { + // The input schema is allowed to have multiple columns, + // by convention we bind to the first one. 
+ root.getVector(0) + } + deserializerFor(encoder, data).asInstanceOf[Deserializer[T]] + } + + private[arrow] def deserializerFor( + encoder: AgnosticEncoder[_], + data: AnyRef): Deserializer[Any] = { + (encoder, data) match { + case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => + new FieldDeserializer[Boolean, BitVector](v) { + def value(i: Int): Boolean = vector.get(i) != 0 + } + case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) => + new FieldDeserializer[Byte, TinyIntVector](v) { + def value(i: Int): Byte = vector.get(i) + } + case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) => + new FieldDeserializer[Short, SmallIntVector](v) { + def value(i: Int): Short = vector.get(i) + } + case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) => + new FieldDeserializer[Int, IntVector](v) { + def value(i: Int): Int = vector.get(i) + } + case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) => + new FieldDeserializer[Long, BigIntVector](v) { + def value(i: Int): Long = vector.get(i) + } + case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) => + new FieldDeserializer[Float, Float4Vector](v) { + def value(i: Int): Float = vector.get(i) + } + case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) => + new FieldDeserializer[Double, Float8Vector](v) { + def value(i: Int): Double = vector.get(i) + } + case (NullEncoder, v: NullVector) => + new FieldDeserializer[Any, NullVector](v) { + def value(i: Int): Any = null + } + case (StringEncoder, v: VarCharVector) => + new FieldDeserializer[String, VarCharVector](v) { + def value(i: Int): String = getString(vector, i) + } + case (JavaEnumEncoder(tag), v: VarCharVector) => + // It would be nice if we can get Enum.valueOf working... 
+ val valueOf = methodLookup.findStatic( + tag.runtimeClass, + "valueOf", + MethodType.methodType(tag.runtimeClass, classOf[String])) + new FieldDeserializer[Enum[_], VarCharVector](v) { + def value(i: Int): Enum[_] = { + valueOf.invoke(getString(vector, i)).asInstanceOf[Enum[_]] + } + } + case (ScalaEnumEncoder(parent, _), v: VarCharVector) => + val mirror = scala.reflect.runtime.currentMirror + val module = mirror.classSymbol(parent).module.asModule + val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration] + new FieldDeserializer[Enumeration#Value, VarCharVector](v) { + def value(i: Int): Enumeration#Value = enumeration.withName(getString(vector, i)) + } + case (BinaryEncoder, v: VarBinaryVector) => + new FieldDeserializer[Array[Byte], VarBinaryVector](v) { + def value(i: Int): Array[Byte] = vector.get(i) + } + case (SparkDecimalEncoder(_), v: DecimalVector) => + new FieldDeserializer[Decimal, DecimalVector](v) { + def value(i: Int): Decimal = Decimal(vector.getObject(i)) + } + case (ScalaDecimalEncoder(_), v: DecimalVector) => + new FieldDeserializer[BigDecimal, DecimalVector](v) { + def value(i: Int): BigDecimal = BigDecimal(vector.getObject(i)) + } + case (JavaDecimalEncoder(_, _), v: DecimalVector) => + new FieldDeserializer[JBigDecimal, DecimalVector](v) { + def value(i: Int): JBigDecimal = vector.getObject(i) + } + case (ScalaBigIntEncoder, v: DecimalVector) => + new FieldDeserializer[BigInt, DecimalVector](v) { + def value(i: Int): BigInt = new BigInt(vector.getObject(i).toBigInteger) + } + case (JavaBigIntEncoder, v: DecimalVector) => + new FieldDeserializer[JBigInteger, DecimalVector](v) { + def value(i: Int): JBigInteger = vector.getObject(i).toBigInteger + } + case (DayTimeIntervalEncoder, v: DurationVector) => + new FieldDeserializer[Duration, DurationVector](v) { + def value(i: Int): Duration = vector.getObject(i) + } + case (YearMonthIntervalEncoder, v: IntervalYearVector) => + new FieldDeserializer[Period, IntervalYearVector](v) { + def value(i: Int): Period = vector.getObject(i).normalized() + } + case (DateEncoder(_), v: DateDayVector) => + new FieldDeserializer[java.sql.Date, DateDayVector](v) { + def value(i: Int): java.sql.Date = DateTimeUtils.toJavaDate(vector.get(i)) + } + case (LocalDateEncoder(_), v: DateDayVector) => + new FieldDeserializer[LocalDate, DateDayVector](v) { + def value(i: Int): LocalDate = DateTimeUtils.daysToLocalDate(vector.get(i)) + } + case (TimestampEncoder(_), v: TimeStampMicroTZVector) => + new FieldDeserializer[java.sql.Timestamp, TimeStampMicroTZVector](v) { + def value(i: Int): java.sql.Timestamp = DateTimeUtils.toJavaTimestamp(vector.get(i)) + } + case (InstantEncoder(_), v: TimeStampMicroTZVector) => + new FieldDeserializer[Instant, TimeStampMicroTZVector](v) { + def value(i: Int): Instant = DateTimeUtils.microsToInstant(vector.get(i)) + } + case (LocalDateTimeEncoder, v: TimeStampMicroVector) => + new FieldDeserializer[LocalDateTime, TimeStampMicroVector](v) { + def value(i: Int): LocalDateTime = DateTimeUtils.microsToLocalDateTime(vector.get(i)) + } + + case (OptionEncoder(value), v) => + val deserializer = deserializerFor(value, v) + new Deserializer[Any] { + override def get(i: Int): Any = Option(deserializer.get(i)) + } + + case (ArrayEncoder(element, _), v: ListVector) => + val deserializer = deserializerFor(element, v.getDataVector) + new FieldDeserializer[AnyRef, ListVector](v) { + def value(i: Int): AnyRef = getArray(vector, i, deserializer)(element.clsTag) + } + + case (IterableEncoder(tag, element, _, 
_), v: ListVector) => + val deserializer = deserializerFor(element, v.getDataVector) + if (isSubClass(Classes.WRAPPED_ARRAY, tag)) { + // Wrapped array is a bit special because we need to use an array of the element type. + // Some parts of our codebase (unfortunately) rely on this for type inference on results. + new FieldDeserializer[mutable.WrappedArray[Any], ListVector](v) { + def value(i: Int): mutable.WrappedArray[Any] = { + val array = getArray(vector, i, deserializer)(element.clsTag) + ScalaCollectionUtils.wrap(array) + } + } + } else if (isSubClass(Classes.ITERABLE, tag)) { + val companion = ScalaCollectionUtils.getIterableCompanion(tag) + new FieldDeserializer[Iterable[Any], ListVector](v) { + def value(i: Int): Iterable[Any] = { + val builder = companion.newBuilder[Any] + loadListIntoBuilder(vector, i, deserializer, builder) + builder.result() + } + } + } else if (isSubClass(Classes.JLIST, tag)) { + val newInstance = resolveJavaListCreator(tag) + new FieldDeserializer[JList[Any], ListVector](v) { + def value(i: Int): JList[Any] = { + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + val list = newInstance(end - index) + while (index < end) { + list.add(deserializer.get(index)) + index += 1 + } + list + } + } + } else { + throw unsupportedCollectionType(tag.runtimeClass) + } + + case (MapEncoder(tag, key, value, _), v: MapVector) => + val structVector = v.getDataVector.asInstanceOf[StructVector] + val keyDeserializer = deserializerFor(key, structVector.getChild(MapVector.KEY_NAME)) + val valueDeserializer = + deserializerFor(value, structVector.getChild(MapVector.VALUE_NAME)) + if (isSubClass(Classes.MAP, tag)) { + val companion = ScalaCollectionUtils.getMapCompanion(tag) + new FieldDeserializer[Map[Any, Any], MapVector](v) { + def value(i: Int): Map[Any, Any] = { + val builder = companion.newBuilder[Any, Any] + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + builder.sizeHint(end - index) + while (index < end) { + builder += (keyDeserializer.get(index) -> valueDeserializer.get(index)) + index += 1 + } + builder.result() + } + } + } else if (isSubClass(Classes.JMAP, tag)) { + val newInstance = resolveJavaMapCreator(tag) + new FieldDeserializer[JMap[Any, Any], MapVector](v) { + def value(i: Int): JMap[Any, Any] = { + val map = newInstance() + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + while (index < end) { + map.put(keyDeserializer.get(index), valueDeserializer.get(index)) + index += 1 + } + map + } + } + } else { + throw unsupportedCollectionType(tag.runtimeClass) + } + + case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) => + // We should try to make this work with MethodHandles. 
+ val Some(constructor) = + ScalaReflection.findConstructor(tag.runtimeClass, fields.map(_.enc.clsTag.runtimeClass)) + val deserializers = if (isTuple(tag.runtimeClass)) { + fields.zip(vectors).map { case (field, vector) => + deserializerFor(field.enc, vector) + } + } else { + val lookup = createFieldLookup(vectors) + fields.map { field => + deserializerFor(field.enc, lookup(field.name)) + } + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + constructor(deserializers.map(_.get(i).asInstanceOf[AnyRef])) + } + } + + case (r @ RowEncoder(fields), StructVectors(struct, vectors)) => + val lookup = createFieldLookup(vectors) + val deserializers = fields.toArray.map { field => + deserializerFor(field.enc, lookup(field.name)) + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + val values = deserializers.map(_.get(i)) + new GenericRowWithSchema(values, r.schema) + } + } + + case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => + val constructor = + methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) + val lookup = createFieldLookup(vectors) + val setters = fields.map { field => + val vector = lookup(field.name) + val deserializer = deserializerFor(field.enc, vector) + val setter = methodLookup.findVirtual( + tag.runtimeClass, + field.writeMethod.get, + MethodType.methodType(classOf[Unit], field.enc.clsTag.runtimeClass)) + (bean: Any, i: Int) => setter.invoke(bean, deserializer.get(i)) + } + new StructFieldSerializer[Any](struct) { + def value(i: Int): Any = { + val instance = constructor.invoke() + setters.foreach(_(instance, i)) + instance + } + } + + case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => + throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType) + + case _ => + throw new RuntimeException( + s"Unsupported Encoder($encoder)/Vector(${data.getClass}) combination.") + } + } + + private val methodLookup = MethodHandles.lookup() + + /** + * Resolve the companion object for a scala class. In our particular case the class we pass in + * is a Scala collection. We use the companion to create a builder for that collection. + */ + private[arrow] def resolveCompanion[T](tag: ClassTag[_]): T = { + val mirror = scala.reflect.runtime.currentMirror + val module = mirror.classSymbol(tag.runtimeClass).companion.asModule + mirror.reflectModule(module).instance.asInstanceOf[T] + } + + /** + * Create a function that creates a [[util.List]] instance. The int parameter of the creator + * function is a size hint. + * + * If the [[ClassTag]] `tag` points to an interface instead of a concrete class we try to use + * [[util.ArrayList]]. For concrete classes we try to use a constructor that takes a single + * [[Int]] argument, it is assumed this is a size hint. If no such constructor exists we + * fallback to a no-args constructor. + */ + private def resolveJavaListCreator(tag: ClassTag[_]): Int => JList[Any] = { + val cls = tag.runtimeClass + val modifiers = cls.getModifiers + if (Modifier.isInterface(modifiers) || Modifier.isAbstract(modifiers)) { + // Abstract class or interface; we try to use ArrayList. + if (!cls.isAssignableFrom(classOf[util.ArrayList[_]])) { + unsupportedCollectionType(cls) + } + (size: Int) => new util.ArrayList[Any](size) + } else { + try { + // Try to use a constructor that (hopefully) takes a size argument. 
+ val ctor = methodLookup.findConstructor( + tag.runtimeClass, + MethodType.methodType(classOf[Unit], Integer.TYPE)) + size => ctor.invoke(size).asInstanceOf[JList[Any]] + } catch { + case _: java.lang.NoSuchMethodException => + // Use a no-args constructor. + val ctor = + methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) + _ => ctor.invoke().asInstanceOf[JList[Any]] + } + } + } + + /** + * Create a function that creates a [[util.Map]] instance. + * + * If the [[ClassTag]] `tag` points to an interface instead of a concrete class we try to use + * [[util.HashMap]]. For concrete classes we try to use a no-args constructor. + */ + private def resolveJavaMapCreator(tag: ClassTag[_]): () => JMap[Any, Any] = { + val cls = tag.runtimeClass + val modifiers = cls.getModifiers + if (Modifier.isInterface(modifiers) || Modifier.isAbstract(modifiers)) { + // Abstract class or interface; we try to use HashMap. + if (!cls.isAssignableFrom(classOf[java.util.HashMap[_, _]])) { + unsupportedCollectionType(cls) + } + () => new util.HashMap[Any, Any]() + } else { + // Use a no-args constructor. + val ctor = + methodLookup.findConstructor(tag.runtimeClass, MethodType.methodType(classOf[Unit])) + () => ctor.invoke().asInstanceOf[JMap[Any, Any]] + } + } + + /** + * Create a function that can lookup one [[FieldVector vectors]] in `fields` by name. This + * lookup is case insensitive. If the schema contains fields with duplicate (with + * case-insensitive resolution) names an exception is thrown. The returned function will throw + * an exception when no column can be found for a name. + * + * A small note on the binding process in general. Over complete schemas are currently allowed, + * meaning that the data can have more column than the encoder. In this the over complete + * (unbound) columns are ignored. + */ + private def createFieldLookup(fields: Seq[FieldVector]): String => FieldVector = { + def toKey(k: String): String = k.toLowerCase(Locale.ROOT) + val lookup = mutable.Map.empty[String, FieldVector] + fields.foreach { field => + val key = toKey(field.getName) + val old = lookup.put(key, field) + if (old.isDefined) { + throw QueryCompilationErrors.ambiguousColumnOrFieldError( + field.getName :: Nil, + fields.count(f => toKey(f.getName) == key)) + } + } + name => { + lookup.getOrElse(toKey(name), throw QueryCompilationErrors.columnNotFoundError(name)) + } + } + + private def isTuple(cls: Class[_]): Boolean = cls.getName.startsWith("scala.Tuple") + + private def getString(v: VarCharVector, i: Int): String = { + // This is currently a bit heavy on allocations: + // - byte array created in VarCharVector.get + // - CharBuffer created CharSetEncoder + // - char array in String + // By using direct buffers and reusing the char buffer + // we could get rid of the first two allocations. 
+ Text.decode(v.get(i)) + } + + private def loadListIntoBuilder( + v: ListVector, + i: Int, + deserializer: Deserializer[Any], + builder: mutable.Builder[Any, _]): Unit = { + var index = v.getElementStartIndex(i) + val end = v.getElementEndIndex(i) + builder.sizeHint(end - index) + while (index < end) { + builder += deserializer.get(index) + index += 1 + } + } + + private def getArray(v: ListVector, i: Int, deserializer: Deserializer[Any])(implicit + tag: ClassTag[Any]): AnyRef = { + val builder = mutable.ArrayBuilder.make[Any] + loadListIntoBuilder(v, i, deserializer, builder) + builder.result() + } + + abstract class Deserializer[+E] { + def get(i: Int): E + } + + abstract class FieldDeserializer[E, V <: FieldVector](val vector: V) extends Deserializer[E] { + def value(i: Int): E + def isNull(i: Int): Boolean = vector.isNull(i) + override def get(i: Int): E = { + if (!isNull(i)) { + value(i) + } else { + null.asInstanceOf[E] + } + } + } + + abstract class StructFieldSerializer[E](v: StructVector) + extends FieldDeserializer[E, StructVector](v) { + override def isNull(i: Int): Boolean = vector != null && vector.isNull(i) + } +} + +class EmptyDeserializingIterator[E](val encoder: AgnosticEncoder[E]) + extends CloseableIterator[E] { + override def close(): Unit = () + override def hasNext: Boolean = false + override def next(): E = throw new NoSuchElementException() +} + +class ArrowDeserializingIterator[E]( + val encoder: AgnosticEncoder[E], + private[this] val reader: ArrowReader) + extends CloseableIterator[E] { + private[this] var index = 0 + private[this] val root = reader.getVectorSchemaRoot + private[this] val deserializer = ArrowDeserializers.deserializerFor(encoder, root) + + override def hasNext: Boolean = { + if (index >= root.getRowCount) { + if (reader.loadNextBatch()) { + index = 0 + } + } + index < root.getRowCount + } + + override def next(): E = { + if (!hasNext) { + throw new NoSuchElementException() + } + val result = deserializer.get(index) + index += 1 + result + } + + override def close(): Unit = reader.close() +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala index f6b140bae557b..ed27336985416 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala @@ -24,8 +24,11 @@ import org.apache.arrow.vector.complex.StructVector private[arrow] object ArrowEncoderUtils { object Classes { + val WRAPPED_ARRAY: Class[_] = classOf[scala.collection.mutable.WrappedArray[_]] val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]] + val MAP: Class[_] = classOf[scala.collection.Map[_, _]] val JLIST: Class[_] = classOf[java.util.List[_]] + val JMAP: Class[_] = classOf[java.util.Map[_, _]] } def isSubClass(cls: Class[_], tag: ClassTag[_]): Boolean = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala new file mode 100644 index 0000000000000..90963c831c252 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala @@ -0,0 +1,185 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client.arrow + +import java.io.{InputStream, IOException} +import java.nio.channels.Channels + +import org.apache.arrow.flatbuf.MessageHeader +import org.apache.arrow.memory.{ArrowBuf, BufferAllocator} +import org.apache.arrow.vector.ipc.{ArrowReader, ReadChannel} +import org.apache.arrow.vector.ipc.message.{ArrowDictionaryBatch, ArrowMessage, ArrowRecordBatch, MessageChannelReader, MessageResult, MessageSerializer} +import org.apache.arrow.vector.types.pojo.Schema + +/** + * An [[ArrowReader]] that concatenates multiple [[MessageIterator]]s into a single stream. Each + * iterator represents a single IPC stream. The concatenated streams all must have the same + * schema. If the schema is different an exception is thrown. + * + * In some cases we want to retain the messages (see `SparkResult`). Normally a stream reader + * closes its messages when it consumes them. In order to prevent that from happening in + * non-destructive mode we clone the messages before passing them to the reading logic. + */ +class ConcatenatingArrowStreamReader( + allocator: BufferAllocator, + input: Iterator[AbstractMessageIterator], + destructive: Boolean) + extends ArrowReader(allocator) { + + private[this] var totalBytesRead: Long = 0 + private[this] var current: AbstractMessageIterator = _ + + override protected def readSchema(): Schema = { + // readSchema() should only be called once during initialization. + assert(current == null) + if (!input.hasNext) { + // ArrowStreamReader throws the same exception. + throw new IOException("Unexpected end of input. Missing schema.") + } + current = input.next() + current.schema + } + + private def nextMessage(): ArrowMessage = { + // readSchema() should have been invoked at this point so 'current' should be initialized. + assert(current != null) + // Try to find a non-empty message iterator. + while (!current.hasNext && input.hasNext) { + totalBytesRead += current.bytesRead + current = input.next() + if (current.schema != getVectorSchemaRoot.getSchema) { + throw new IllegalStateException() + } + } + if (current.hasNext) { + current.next() + } else { + null + } + } + + override def loadNextBatch(): Boolean = { + // Keep looping until we load a non-empty batch or until we exhaust the input. 
+ var message = nextMessage() + while (message != null) { + message match { + case rb: ArrowRecordBatch => + loadRecordBatch(cloneIfNonDestructive(rb)) + if (getVectorSchemaRoot.getRowCount > 0) { + return true + } + case db: ArrowDictionaryBatch => + loadDictionary(cloneIfNonDestructive(db)) + } + message = nextMessage() + } + false + } + + private def cloneIfNonDestructive(batch: ArrowRecordBatch): ArrowRecordBatch = { + if (destructive) { + return batch + } + cloneRecordBatch(batch) + } + + private def cloneIfNonDestructive(batch: ArrowDictionaryBatch): ArrowDictionaryBatch = { + if (destructive) { + return batch + } + new ArrowDictionaryBatch( + batch.getDictionaryId, + cloneRecordBatch(batch.getDictionary), + batch.isDelta) + } + + private def cloneRecordBatch(batch: ArrowRecordBatch): ArrowRecordBatch = { + new ArrowRecordBatch( + batch.getLength, + batch.getNodes, + batch.getBuffers, + batch.getBodyCompression, + true, + true) + } + + override def bytesRead(): Long = { + if (current != null) { + totalBytesRead + current.bytesRead + } else { + 0 + } + } + + override def closeReadSource(): Unit = () +} + +trait AbstractMessageIterator extends Iterator[ArrowMessage] { + def schema: Schema + def bytesRead: Long +} + +/** + * Decode an Arrow IPC stream into individual messages. Please note that this iterator MUST have a + * valid IPC stream as its input, otherwise construction will fail. + */ +class MessageIterator(input: InputStream, allocator: BufferAllocator) + extends AbstractMessageIterator { + private[this] val in = new ReadChannel(Channels.newChannel(input)) + private[this] val reader = new MessageChannelReader(in, allocator) + private[this] var result: MessageResult = _ + + // Eagerly read the schema. + val schema: Schema = { + val result = reader.readNext() + if (result == null) { + throw new IOException("Unexpected end of input. 
Missing schema.") + } + MessageSerializer.deserializeSchema(result.getMessage) + } + + override def bytesRead: Long = reader.bytesRead() + + override def hasNext: Boolean = { + if (result == null) { + result = reader.readNext() + } + result != null + } + + override def next(): ArrowMessage = { + if (!hasNext) { + throw new NoSuchElementException() + } + val message = result.getMessage.headerType() match { + case MessageHeader.RecordBatch => + MessageSerializer.deserializeRecordBatch(result.getMessage, bodyBuffer(result)) + case MessageHeader.DictionaryBatch => + MessageSerializer.deserializeDictionaryBatch(result.getMessage, bodyBuffer(result)) + } + result = null + message + } + + private def bodyBuffer(result: MessageResult): ArrowBuf = { + var buffer = result.getBodyBuffer + if (buffer == null) { + buffer = allocator.getEmpty + } + buffer + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 73c04389c0597..07dd2a96bd8f7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { @@ -571,7 +570,8 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM (col("id") / lit(10.0d)).as("b"), col("id"), lit("world").as("d"), - (col("id") % 2).cast("int").as("a")) + // TODO SPARK-44449 make this int again when upcasting is in. + (col("id") % 2).cast("double").as("a")) private def validateMyTypeResult(result: Array[MyType]): Unit = { result.zipWithIndex.foreach { case (MyType(id, a, b), i) => @@ -818,10 +818,11 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM } test("toJSON") { + // TODO SPARK-44449 make this int again when upcasting is in. val expected = Array( - """{"b":0.0,"id":0,"d":"world","a":0}""", - """{"b":0.1,"id":1,"d":"world","a":1}""", - """{"b":0.2,"id":2,"d":"world","a":0}""") + """{"b":0.0,"id":0,"d":"world","a":0.0}""", + """{"b":0.1,"id":1,"d":"world","a":1.0}""", + """{"b":0.2,"id":2,"d":"world","a":0.0}""") val result = spark .range(3) .select(generateMyTypeColumns: _*) @@ -893,14 +894,12 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM test("Dataset result destructive iterator") { // Helper methods for accessing private field `idxToBatches` from SparkResult - val _idxToBatches = - PrivateMethod[mutable.Map[Int, ColumnarBatch]](Symbol("idxToBatches")) + val getResultMap = + PrivateMethod[mutable.Map[Int, Any]](Symbol("resultMap")) - def getColumnarBatches(result: SparkResult[_]): Seq[ColumnarBatch] = { - val idxToBatches = result invokePrivate _idxToBatches() - - // Sort by key to get stable results. 
- idxToBatches.toSeq.sortBy(_._1).map(_._2) + def assertResultsMapEmpty(result: SparkResult[_]): Unit = { + val resultMap = result invokePrivate getResultMap() + assert(resultMap.isEmpty) } val df = spark @@ -911,25 +910,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM try { // build and verify the destructive iterator val iterator = result.destructiveIterator - // batches is empty before traversing the result iterator - assert(getColumnarBatches(result).isEmpty) - var previousBatch: ColumnarBatch = null - val buffer = mutable.Buffer.empty[Long] + // resultMap Map is empty before traversing the result iterator + assertResultsMapEmpty(result) + val buffer = mutable.Set.empty[Long] while (iterator.hasNext) { - // always having 1 batch, since a columnar batch will be removed and closed after - // its data got consumed. - val batches = getColumnarBatches(result) - assert(batches.size === 1) - assert(batches.head != previousBatch) - previousBatch = batches.head - - buffer.append(iterator.next()) + // resultMap is empty during iteration because results get removed immediately on access. + assertResultsMapEmpty(result) + buffer += iterator.next() } - // Batches should be closed and removed after traversing all the records. - assert(getColumnarBatches(result).isEmpty) + // resultMap Map is empty afterward because all results have been removed. + assertResultsMapEmpty(result) - val expectedResult = Seq(6L, 7L, 8L) - assert(buffer.size === 3 && expectedResult.forall(buffer.contains)) + val expectedResult = Set(6L, 7L, 8L) + assert(buffer.size === 3 && expectedResult == buffer) } finally { result.close() } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index e15069f2d9e96..ab3e13da53178 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -68,10 +68,11 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("keyAs - keys") { + // TODO SPARK-44449 make this long again when upcasting is in. // It is okay to cast from Long to Double, but not Long to Int. val values = spark .range(10) - .groupByKey(v => v % 2) + .groupByKey(v => (v % 2).toDouble) .keyAs[Double] .keys .collectAsList() @@ -232,9 +233,10 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("agg, keyAs") { + // TODO SPARK-44449 make this long again when upcasting is in. val ds = spark .range(10) - .groupByKey(v => v % 2) + .groupByKey(v => (v % 2).toDouble) .keyAs[Double] .agg(count("*")) @@ -244,7 +246,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { test("typed aggregation: expr") { val session: SparkSession = spark import session.implicits._ - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. 
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long]), @@ -254,7 +257,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]), @@ -264,7 +268,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], count("*")), @@ -274,7 +279,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -289,7 +295,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -305,7 +312,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -322,7 +330,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. + val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -340,7 +349,8 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { } test("typed aggregation: expr, expr, expr, expr, expr, expr, expr, expr") { - val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + // TODO SPARK-44449 make this int again when upcasting is in. 
+ val ds = Seq(("a", 10L), ("a", 20L), ("b", 1L), ("b", 2L), ("c", 1L)).toDS() checkDatasetUnorderly( ds.groupByKey(_._1) @@ -473,9 +483,9 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) .toDF("key", "seq", "value") val grouped = ds.groupBy($"value").as[String, (String, Int, Int)] - val keys = grouped.keyAs[String].keys.sort($"value") - - checkDataset(keys, "1", "2", "10", "20") + // TODO SPARK-44449 make this string again when upcasting is in. + val keys = grouped.keyAs[Int].keys.sort($"value") + checkDataset(keys, 1, 2, 10, 20) } test("flatMapGroupsWithState") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 58758a1384031..800ce43a60df0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -208,8 +208,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } test("UDF Registration") { + // TODO SPARK-44449 make this long again when upcasting is in. val input = """ - |class A(x: Int) { def get = x * 100 } + |class A(x: Int) { def get: Long = x * 100 } |val myUdf = udf((x: Int) => new A(x).get) |spark.udf.register("dummyUdf", myUdf) |spark.sql("select dummyUdf(id) from range(5)").as[Long].collect() @@ -219,8 +220,9 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { } test("UDF closure registration") { + // TODO SPARK-44449 make this int again when upcasting is in. val input = """ - |class A(x: Int) { def get = x * 15 } + |class A(x: Int) { def get: Long = x * 15 } |spark.udf.register("directUdf", (x: Int) => new A(x).get) |spark.sql("select directUdf(id) from range(5)").as[Long].collect() """.stripMargin diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 0c327484e477d..16eec3eee3110 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -21,24 +21,19 @@ import java.util import java.util.{Collections, Objects} import scala.beans.BeanProperty -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.classTag -import scala.util.control.NonFatal -import com.google.protobuf.ByteString import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.connect.proto -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, CalendarIntervalEncoder, DateEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, PrimitiveDoubleEncoder, 
PrimitiveFloatEncoder, RowEncoder, StringEncoder, TimestampEncoder, UDTEncoder} import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} -import org.apache.spark.sql.connect.client.SparkResult import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum import org.apache.spark.sql.connect.client.util.ConnectFunSuite import org.apache.spark.sql.types.{ArrayType, DataType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType} @@ -96,15 +91,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } val resultIterator = - try { - deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator) - } catch { - case NonFatal(e) => - arrowIterator.close() - serializerAllocator.close() - deserializerAllocator.close() - throw e - } + ArrowDeserializers.deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator) new CloseableIterator[T] { override def close(): Unit = { arrowIterator.close() @@ -117,25 +104,6 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } } - // Temporary hack until we merge the deserializer. - private def deserializeFromArrow[E]( - batches: Iterator[Array[Byte]], - encoder: AgnosticEncoder[E], - allocator: BufferAllocator): CloseableIterator[E] = { - val responses = batches.map { batch => - val builder = proto.ExecutePlanResponse.newBuilder() - builder.getArrowBatchBuilder.setData(ByteString.copyFrom(batch)) - builder.build() - } - val result = new SparkResult[E](responses.asJava, allocator, encoder) - new CloseableIterator[E] { - private val itr = result.iterator - override def close(): Unit = itr.close() - override def hasNext: Boolean = itr.hasNext - override def next(): E = itr.next() - } - } - private def roundTripAndCheck[T]( encoder: AgnosticEncoder[T], toInputIterator: () => Iterator[Any], @@ -246,6 +214,15 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { assert(inspector.sizeInBytes > 0) } + test("deserializing empty iterator") { + withAllocator { allocator => + val iterator = + ArrowDeserializers.deserializeFromArrow(Iterator.empty, singleIntEncoder, allocator) + assert(iterator.isEmpty) + assert(allocator.getAllocatedMemory == 0) + } + } + test("single batch") { val inspector = new CountingBatchInspector roundTripAndCheckIdentical(singleIntEncoder, inspectBatch = inspector) { () => @@ -533,15 +510,22 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { val maybeNull = MaybeNull(11) Iterator.tabulate(100) { i => val bean = new JavaMapData - bean.setDummyToDoubleListMap(maybeNull { - val map = new util.HashMap[DummyBean, java.util.List[java.lang.Double]] - (0 until (i % 5)).foreach { j => - val dummy = new DummyBean - dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + bean.setMetricMap(maybeNull { + val map = new util.HashMap[String, util.List[java.lang.Double]] + (0 until (i % 20)).foreach { i => val values = Array.tabulate(i % 40) { j => Double.box(j.toDouble) } - map.put(dummy, maybeNull(util.Arrays.asList(values: _*))) + map.put("k" + i, maybeNull(util.Arrays.asList(values: _*))) + } + map + }) + bean.setDummyToStringMap(maybeNull { + val map = new util.HashMap[DummyBean, String] + (0 until (i % 5)).foreach { j => + val dummy = new DummyBean + dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + map.put(dummy, maybeNull("s" + i + "v" + j)) } map }) @@ -675,6 +659,57 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { 
.add("Ca", "array") .add("Cb", "binary"))) + test("bind to schema") { + // Binds to a wider schema. The narrow schema has fewer (nested) fields, has a slightly + // different field order, and uses different cased names in a couple of places. + withAllocator { allocator => + val input = Row( + 887, + "foo", + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f), + Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false))) + val expected = Row( + "foo", + Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)), + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte))) + val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator) + val result = + ArrowDeserializers.deserializeFromArrow(arrowBatches, narrowSchemaEncoder, allocator) + val actual = result.next() + assert(result.isEmpty) + assert(expected === actual) + result.close() + arrowBatches.close() + } + } + + test("unknown field") { + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator) + intercept[AnalysisException] { + ArrowDeserializers.deserializeFromArrow(arrowBatches, wideSchemaEncoder, allocator) + } + arrowBatches.close() + } + } + + test("duplicate fields") { + val duplicateSchemaEncoder = toRowEncoder( + new StructType() + .add("foO", "string") + .add("Foo", "string")) + val fooSchemaEncoder = toRowEncoder( + new StructType() + .add("foo", "string")) + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator) + intercept[AnalysisException] { + ArrowDeserializers.deserializeFromArrow(arrowBatches, fooSchemaEncoder, allocator) + } + arrowBatches.close() + } + } + /* ******************************************************************** * * Arrow serialization/deserialization specific errors * ******************************************************************** */ @@ -833,17 +868,23 @@ case class MapData(intStringMap: Map[Int, String], metricMap: Map[String, Array[ class JavaMapData { @scala.beans.BeanProperty - var dummyToDoubleListMap: java.util.Map[DummyBean, java.util.List[java.lang.Double]] = _ + var dummyToStringMap: java.util.Map[DummyBean, String] = _ + + @scala.beans.BeanProperty + var metricMap: java.util.HashMap[String, java.util.List[java.lang.Double]] = _ def canEqual(other: Any): Boolean = other.isInstanceOf[JavaMapData] override def equals(other: Any): Boolean = other match { case that: JavaMapData if that canEqual this => - dummyToDoubleListMap == that.dummyToDoubleListMap + dummyToStringMap == that.dummyToStringMap && + metricMap == that.metricMap case _ => false } - override def hashCode(): Int = Objects.hashCode(dummyToDoubleListMap) + override def hashCode(): Int = { + java.util.Arrays.deepHashCode(Array(dummyToStringMap, metricMap)) + } } class DummyBean {