From ccc81fd85cd873bccb83a8baeb6c00070fe66e46 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 30 Mar 2023 13:27:33 -0400 Subject: [PATCH 01/16] Add direct arrow serialization --- connector/connect/client/jvm/pom.xml | 13 + .../org/apache/spark/sql/SparkSession.scala | 5 +- .../sql/connect/client/SparkResult.scala | 1 - .../client/arrow/ArrowEncoderUtils.scala | 53 ++ .../client/arrow/ArrowSerializer.scala | 529 +++++++++++ .../client/arrow/ArrowEncoderSuite.scala | 837 ++++++++++++++++++ .../sql/catalyst/JavaTypeInferenceSuite.scala | 9 +- 7 files changed, 1443 insertions(+), 4 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala create mode 100644 connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index 8543057d0c0d1..1be149803c9b0 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -120,6 +120,19 @@ + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + com.google.guava + guava + + + org.scalacheck scalacheck_${scala.binary.version} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 548545b969d5a..da8852a97d125 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder} import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} -import org.apache.spark.sql.connect.client.util.{Cleaner, ConvertToArrow} +import org.apache.spark.sql.connect.client.arrow.ArrowSerializer +import org.apache.spark.sql.connect.client.util.Cleaner import org.apache.spark.sql.types.StructType /** @@ -118,7 +119,7 @@ class SparkSession private[sql] ( .setSchema(encoder.schema.json) if (data.nonEmpty) { val timeZoneId = conf.get("spark.sql.session.timeZone") - val arrowData = ConvertToArrow(encoder, data, timeZoneId, allocator) + val arrowData = ArrowSerializer.serialize(data, encoder, allocator, timeZoneId) localRelationBuilder.setData(arrowData) } } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index 80db558918bba..39aed614e3f48 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -72,7 +72,6 @@ private[sql] class SparkResult[T]( } while (reader.loadNextBatch()) { val rowCount = root.getRowCount - assert(root.getRowCount == response.getArrowBatch.getRowCount) // HUH! 
if (rowCount > 0) { val vectors = root.getFieldVectors.asScala .map(v => new ArrowColumnVector(transferToNewVector(v))) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala new file mode 100644 index 0000000000000..d022d3005b5ff --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client.arrow + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.complex.StructVector + +private[arrow] object ArrowEncoderUtils { + object Classes { + val WRAPPED_ARRAY: Class[_] = classOf[scala.collection.mutable.WrappedArray[_]] + val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]] + val SEQ: Class[_] = classOf[scala.collection.Seq[_]] + val SET: Class[_] = classOf[scala.collection.Set[_]] + val MAP: Class[_] = classOf[scala.collection.Map[_, _]] + val JLIST: Class[_] = classOf[java.util.List[_]] + val JMAP: Class[_] = classOf[java.util.Map[_, _]] + } + + def isSubClass(cls: Class[_], tag: ClassTag[_]): Boolean = { + cls.isAssignableFrom(tag.runtimeClass) + } + + def unsupportedCollectionType(cls: Class[_]): Nothing = { + throw new RuntimeException(s"Unsupported collection type: $cls") + } +} + +trait CloseableIterator[E] extends Iterator[E] with AutoCloseable + +private[arrow] object StructVectors { + def unapply(v: AnyRef): Option[(StructVector, Seq[FieldVector])] = v match { + case root: VectorSchemaRoot => Option((null, root.getFieldVectors.asScala)) + case struct: StructVector => Option((struct, struct.getChildrenFromFields.asScala)) + case _ => None + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala new file mode 100644 index 0000000000000..038e6a49516cd --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala @@ -0,0 +1,529 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client.arrow + +import java.io.{ByteArrayOutputStream, OutputStream} +import java.lang.invoke.{MethodHandles, MethodType} +import java.math.{BigDecimal => JBigDecimal, BigInteger => JBigInteger} +import java.nio.channels.Channels +import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} +import java.util.{Map => JMap} + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import com.google.protobuf.ByteString +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, DurationVector, FieldVector, Float4Vector, Float8Vector, IntervalYearVector, IntVector, NullVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel} +import org.apache.arrow.vector.ipc.message.{IpcOption, MessageSerializer} +import org.apache.arrow.vector.util.Text + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.DefinedByConstructorParams +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.util.ArrowUtils + +/** + * Helper class for converting user objects into arrow batches. + */ +class ArrowSerializer[T]( + private[this] val enc: AgnosticEncoder[T], + private[this] val allocator: BufferAllocator, + private[this] val timeZoneId: String) { + private val (root, serializer) = ArrowSerializer.serializerFor(enc, allocator, timeZoneId) + private val vectors = root.getFieldVectors.asScala + private val unloader = new VectorUnloader(root) + private val schemaBytes = { + // Only serialize the schema once. + val bytes = new ByteArrayOutputStream() + MessageSerializer.serialize(newChannel(bytes), root.getSchema) + bytes.toByteArray + } + private var i: Int = 0 + + private def newChannel(output: OutputStream): WriteChannel = { + new WriteChannel(Channels.newChannel(output)) + } + + /** + * The size of the current batch. + * + * The size computed consist of the size of the schema and the size of the arrow buffers. The + * actual batch will be larger than that because of alignment, written IPC tokens, and the + * written record batch metadata. The size of the record batch metadata is proportional to the + * complexity of the schema. + */ + def sizeInBytes: Long = { + // We need to set the row count for getBufferSize to return the actual value. + root.setRowCount(i) + schemaBytes.length + vectors.map(_.getBufferSize).sum + } + + /** + * Append a record to the current batch. 
+ */ + def append(record: T): Unit = { + serializer.write(i, record) + i += 1 + } + + /** + * Write the schema and the current batch in Arrow IPC stream format to the [[OutputStream]]. + */ + def writeIpcStream(output: OutputStream): Unit = { + val channel = newChannel(output) + root.setRowCount(i) + val batch = unloader.getRecordBatch + try { + channel.write(schemaBytes) + MessageSerializer.serialize(channel, batch) + ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT) + } finally { + batch.close() + } + } + + /** + * Reset the serializer. + */ + def reset(): Unit = { + i = 0 + vectors.foreach(_.reset()) + } + + /** + * Close the serializer. + */ + def close(): Unit = root.close() +} + +object ArrowSerializer { + import ArrowEncoderUtils._ + + /** + * Create an [[Iterator]] that converts the input [[Iterator]] of type `T` into an [[Iterator]] + * of Arrow IPC Streams. + */ + def serialize[T]( + input: Iterator[T], + enc: AgnosticEncoder[T], + allocator: BufferAllocator, + maxRecordsPerBatch: Int, + maxBatchSize: Long, + timeZoneId: String, + batchSizeCheckInterval: Int = 128): CloseableIterator[Array[Byte]] = { + assert(maxRecordsPerBatch > 0) + assert(maxBatchSize > 0) + assert(batchSizeCheckInterval > 0) + new CloseableIterator[Array[Byte]] { + private val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) + private val bytes = new ByteArrayOutputStream + private var hasWrittenFirstBatch = false + + /** + * Periodical check to make sure we don't go over the size threshold by too much. + */ + private def sizeOk(i: Int): Boolean = { + if (i > 0 && i % batchSizeCheckInterval == 0) { + return serializer.sizeInBytes < maxBatchSize + } + true + } + + override def hasNext: Boolean = input.hasNext || !hasWrittenFirstBatch + + override def next(): Array[Byte] = { + if (!hasNext) { + throw new NoSuchElementException() + } + serializer.reset() + bytes.reset() + var i = 0 + while (i < maxRecordsPerBatch && input.hasNext && sizeOk(i)) { + serializer.append(input.next()) + i += 1 + } + serializer.writeIpcStream(bytes) + hasWrittenFirstBatch = true + bytes.toByteArray + } + + override def close(): Unit = serializer.close() + } + } + + def serialize[T]( + input: Iterator[T], + enc: AgnosticEncoder[T], + allocator: BufferAllocator, + timeZoneId: String): ByteString = { + val serializer = new ArrowSerializer[T](enc, allocator, timeZoneId) + serializer.reset() + input.foreach(serializer.append) + val output = ByteString.newOutput() + serializer.writeIpcStream(output) + output.toByteString + } + + /** + * Create a (root) [[Serializer]] for [[AgnosticEncoder]] `encoder`. + * + * The serializer returned by this method is NOT thread-safe. + */ + def serializerFor[T]( + encoder: AgnosticEncoder[T], + allocator: BufferAllocator, + timeZoneId: String): (VectorSchemaRoot, Serializer) = { + val arrowSchema = ArrowUtils.toArrowSchema(encoder.schema, timeZoneId) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val serializer = if (encoder.schema != encoder.dataType) { + assert(root.getSchema.getFields.size() == 1) + serializerFor(encoder, root.getVector(0)) + } else { + serializerFor(encoder, root) + } + root -> serializer + } + + // TODO throw better errors on class cast exceptions. 
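  // Editor's usage sketch (not part of this patch), placed before the per-encoder
  // serializerFor below: it shows how the chunked serialize overload above could be
  // driven from an AgnosticEncoder. The names `exampleUsage` and the (Long, String)
  // tuple encoder are illustrative assumptions, not code introduced by this change.
  private def exampleUsage(): Unit = {
    val allocator = new org.apache.arrow.memory.RootAllocator()
    val encoder =
      org.apache.spark.sql.catalyst.ScalaReflection.encoderFor[(Long, String)]
    // Each element produced by the iterator is a complete Arrow IPC stream.
    val batches = ArrowSerializer.serialize(
      input = Iterator.tabulate(10)(i => (i.toLong, "name-" + i)),
      enc = encoder,
      allocator = allocator,
      maxRecordsPerBatch = 1024,
      maxBatchSize = 64 * 1024,
      timeZoneId = "UTC")
    try {
      batches.foreach(batch => println(s"wrote an IPC stream of ${batch.length} bytes"))
    } finally {
      batches.close()
      allocator.close()
    }
  }
  // The ByteString overload above does the same work in a single batch; it is what
  // SparkSession.createDataset uses to populate a LocalRelation in this patch.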
+ private[arrow] def serializerFor[E](encoder: AgnosticEncoder[E], v: AnyRef): Serializer = { + (encoder, v) match { + case (PrimitiveBooleanEncoder | BoxedBooleanEncoder, v: BitVector) => + new FieldSerializer[Boolean, BitVector](v) { + override def set(index: Int, value: Boolean): Unit = + vector.setSafe(index, if (value) 1 else 0) + } + case (PrimitiveByteEncoder | BoxedByteEncoder, v: TinyIntVector) => + new FieldSerializer[Byte, TinyIntVector](v) { + override def set(index: Int, value: Byte): Unit = vector.setSafe(index, value) + } + case (PrimitiveShortEncoder | BoxedShortEncoder, v: SmallIntVector) => + new FieldSerializer[Short, SmallIntVector](v) { + override def set(index: Int, value: Short): Unit = vector.setSafe(index, value) + } + case (PrimitiveIntEncoder | BoxedIntEncoder, v: IntVector) => + new FieldSerializer[Int, IntVector](v) { + override def set(index: Int, value: Int): Unit = vector.setSafe(index, value) + } + case (PrimitiveLongEncoder | BoxedLongEncoder, v: BigIntVector) => + new FieldSerializer[Long, BigIntVector](v) { + override def set(index: Int, value: Long): Unit = vector.setSafe(index, value) + } + case (PrimitiveFloatEncoder | BoxedFloatEncoder, v: Float4Vector) => + new FieldSerializer[Float, Float4Vector](v) { + override def set(index: Int, value: Float): Unit = vector.setSafe(index, value) + } + case (PrimitiveDoubleEncoder | BoxedDoubleEncoder, v: Float8Vector) => + new FieldSerializer[Double, Float8Vector](v) { + override def set(index: Int, value: Double): Unit = vector.setSafe(index, value) + } + case (NullEncoder, v: NullVector) => + new FieldSerializer[Unit, NullVector](v) { + override def set(index: Int, value: Unit): Unit = vector.setNull(index) + } + case (StringEncoder, v: VarCharVector) => + new FieldSerializer[String, VarCharVector](v) { + override def set(index: Int, value: String): Unit = setString(v, index, value) + } + case (JavaEnumEncoder(_), v: VarCharVector) => + new FieldSerializer[Enum[_], VarCharVector](v) { + override def set(index: Int, value: Enum[_]): Unit = setString(v, index, value.name()) + } + case (ScalaEnumEncoder(_, _), v: VarCharVector) => + new FieldSerializer[Enumeration#Value, VarCharVector](v) { + override def set(index: Int, value: Enumeration#Value): Unit = + setString(v, index, value.toString) + } + case (BinaryEncoder, v: VarBinaryVector) => + new FieldSerializer[Array[Byte], VarBinaryVector](v) { + override def set(index: Int, value: Array[Byte]): Unit = vector.setSafe(index, value) + } + case (SparkDecimalEncoder(_), v: DecimalVector) => + new FieldSerializer[Decimal, DecimalVector](v) { + override def set(index: Int, value: Decimal): Unit = + setDecimal(vector, index, value.toJavaBigDecimal) + } + case (ScalaDecimalEncoder(_), v: DecimalVector) => + new FieldSerializer[BigDecimal, DecimalVector](v) { + override def set(index: Int, value: BigDecimal): Unit = + setDecimal(vector, index, value.bigDecimal) + } + case (JavaDecimalEncoder(_, false), v: DecimalVector) => + new FieldSerializer[JBigDecimal, DecimalVector](v) { + override def set(index: Int, value: JBigDecimal): Unit = + setDecimal(vector, index, value) + } + case (JavaDecimalEncoder(_, true), v: DecimalVector) => + new FieldSerializer[Any, DecimalVector](v) { + override def set(index: Int, value: Any): Unit = { + val decimal = value match { + case j: JBigDecimal => j + case d: BigDecimal => d.bigDecimal + case k: BigInt => new JBigDecimal(k.bigInteger) + case l: JBigInteger => new JBigDecimal(l) + case d: Decimal => d.toJavaBigDecimal + } + 
setDecimal(vector, index, decimal) + } + } + case (ScalaBigIntEncoder, v: DecimalVector) => + new FieldSerializer[BigInt, DecimalVector](v) { + override def set(index: Int, value: BigInt): Unit = + setDecimal(vector, index, new JBigDecimal(value.bigInteger)) + } + case (JavaBigIntEncoder, v: DecimalVector) => + new FieldSerializer[JBigInteger, DecimalVector](v) { + override def set(index: Int, value: JBigInteger): Unit = + setDecimal(vector, index, new JBigDecimal(value)) + } + case (DayTimeIntervalEncoder, v: DurationVector) => + new FieldSerializer[Duration, DurationVector](v) { + override def set(index: Int, value: Duration): Unit = + vector.setSafe(index, IntervalUtils.durationToMicros(value)) + } + case (YearMonthIntervalEncoder, v: IntervalYearVector) => + new FieldSerializer[Period, IntervalYearVector](v) { + override def set(index: Int, value: Period): Unit = + vector.setSafe(index, IntervalUtils.periodToMonths(value)) + } + case (DateEncoder(true) | LocalDateEncoder(true), v: DateDayVector) => + new FieldSerializer[Any, DateDayVector](v) { + override def set(index: Int, value: Any): Unit = + vector.setSafe(index, DateTimeUtils.anyToDays(value)) + } + case (DateEncoder(false), v: DateDayVector) => + new FieldSerializer[java.sql.Date, DateDayVector](v) { + override def set(index: Int, value: java.sql.Date): Unit = + vector.setSafe(index, DateTimeUtils.fromJavaDate(value)) + } + case (LocalDateEncoder(false), v: DateDayVector) => + new FieldSerializer[LocalDate, DateDayVector](v) { + override def set(index: Int, value: LocalDate): Unit = + vector.setSafe(index, DateTimeUtils.localDateToDays(value)) + } + case (TimestampEncoder(true) | InstantEncoder(true), v: TimeStampMicroTZVector) => + new FieldSerializer[Any, TimeStampMicroTZVector](v) { + override def set(index: Int, value: Any): Unit = + vector.setSafe(index, DateTimeUtils.anyToMicros(value)) + } + case (TimestampEncoder(false), v: TimeStampMicroTZVector) => + new FieldSerializer[java.sql.Timestamp, TimeStampMicroTZVector](v) { + override def set(index: Int, value: java.sql.Timestamp): Unit = + vector.setSafe(index, DateTimeUtils.fromJavaTimestamp(value)) + } + case (InstantEncoder(false), v: TimeStampMicroTZVector) => + new FieldSerializer[Instant, TimeStampMicroTZVector](v) { + override def set(index: Int, value: Instant): Unit = + vector.setSafe(index, DateTimeUtils.instantToMicros(value)) + } + case (LocalDateTimeEncoder, v: TimeStampMicroVector) => + new FieldSerializer[LocalDateTime, TimeStampMicroVector](v) { + override def set(index: Int, value: LocalDateTime): Unit = + vector.setSafe(index, DateTimeUtils.localDateTimeToMicros(value)) + } + + case (OptionEncoder(value), v) => + new Serializer { + private[this] val delegate: Serializer = serializerFor(value, v) + override def write(index: Int, value: Any): Unit = value match { + case Some(value) => delegate.write(index, value) + case _ => delegate.write(index, null) + } + } + + case (ArrayEncoder(element, _), v: ListVector) => + val elementSerializer = serializerFor(element, v.getDataVector) + val toIterator = { array: Any => + mutable.WrappedArray.make(array.asInstanceOf[AnyRef]).iterator + } + new ArraySerializer(v, toIterator, elementSerializer) + + case (IterableEncoder(tag, element, _, lenient), v: ListVector) => + val elementSerializer = serializerFor(element, v.getDataVector) + val toIterator: Any => Iterator[_] = if (lenient) { + { + case i: scala.collection.Iterable[_] => i.toIterator + case l: java.util.List[_] => l.iterator().asScala + case a: Array[_] => 
a.iterator + case o => unsupportedCollectionType(o.getClass) + } + } else if (isSubClass(Classes.ITERABLE, tag)) { v => + v.asInstanceOf[scala.collection.Iterable[_]].toIterator + } else if (isSubClass(Classes.JLIST, tag)) { v => + v.asInstanceOf[java.util.List[_]].iterator().asScala + } else { + unsupportedCollectionType(tag.runtimeClass) + } + new ArraySerializer(v, toIterator, elementSerializer) + + case (MapEncoder(tag, key, value, _), v: MapVector) => + val structVector = v.getDataVector.asInstanceOf[StructVector] + val extractor = if (isSubClass(classOf[scala.collection.Map[_, _]], tag)) { (v: Any) => + v.asInstanceOf[scala.collection.Map[_, _]].iterator + } else if (isSubClass(classOf[JMap[_, _]], tag)) { (v: Any) => + v.asInstanceOf[JMap[Any, Any]].asScala.iterator + } else { + unsupportedCollectionType(tag.runtimeClass) + } + val structSerializer = new StructSerializer( + structVector, + new StructFieldSerializer( + (v: Any) => v.asInstanceOf[(Any, Any)]._1, + serializerFor(key, structVector.getChild(MapVector.KEY_NAME))) :: + new StructFieldSerializer( + (v: Any) => v.asInstanceOf[(Any, Any)]._2, + serializerFor(value, structVector.getChild(MapVector.VALUE_NAME))) :: Nil) + new ArraySerializer(v, extractor, structSerializer) + + case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) => + if (isSubClass(classOf[Product], tag)) { + structSerializerFor(fields, struct, vectors) { (_, i) => p => + p.asInstanceOf[Product].productElement(i) + } + } else if (isSubClass(classOf[DefinedByConstructorParams], tag)) { + structSerializerFor(fields, struct, vectors) { (field, _) => + val getter = methodLookup.findVirtual( + tag.runtimeClass, + field.name, + MethodType.methodType(field.enc.clsTag.runtimeClass)) + o => getter.invoke(o) + } + } else { + unsupportedCollectionType(tag.runtimeClass) + } + + case (RowEncoder(fields), StructVectors(struct, vectors)) => + structSerializerFor(fields, struct, vectors) { (_, i) => r => r.asInstanceOf[Row].get(i) } + + case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) => + structSerializerFor(fields, struct, vectors) { (field, _) => + val getter = methodLookup.findVirtual( + tag.runtimeClass, + field.readMethod.get, + MethodType.methodType(field.enc.clsTag.runtimeClass)) + o => getter.invoke(o) + } + + case (CalendarIntervalEncoder | _: UDTEncoder[_], _) => + throw QueryExecutionErrors.unsupportedDataTypeError(encoder.dataType) + + case _ => + throw new RuntimeException(s"Unsupported Encoder($encoder)/Vector($v) combination.") + } + } + + private val methodLookup = MethodHandles.lookup() + + private def setString(vector: VarCharVector, index: Int, string: String): Unit = { + val bytes = Text.encode(string) + vector.setSafe(index, bytes, 0, bytes.limit()) + } + + private def setDecimal(vector: DecimalVector, index: Int, decimal: JBigDecimal): Unit = { + val scaledDecimal = if (vector.getScale != decimal.scale()) { + decimal.setScale(vector.getScale) + } else { + decimal + } + vector.setSafe(index, scaledDecimal) + } + + private def structSerializerFor( + fields: Seq[EncoderField], + struct: StructVector, + vectors: Seq[FieldVector])( + createGetter: (EncoderField, Int) => Any => Any): StructSerializer = { + require(fields.size == vectors.size) + val serializers = fields.zip(vectors).zipWithIndex.map { case ((field, vector), i) => + val serializer = serializerFor(field.enc, vector) + new StructFieldSerializer(createGetter(field, i), serializer) + } + new StructSerializer(struct, serializers) + } + + abstract class Serializer { + 
def write(index: Int, value: Any): Unit + } + + private abstract class FieldSerializer[E, V <: FieldVector](val vector: V) extends Serializer { + private[this] val nullable = vector.getField.isNullable + def set(index: Int, value: E): Unit + + override def write(index: Int, raw: Any): Unit = { + val value = raw.asInstanceOf[E] + if (value != null) { + set(index, value) + } else if (nullable) { + vector.setNull(index) + } else { + throw new NullPointerException() + } + } + } + + private class ArraySerializer( + v: ListVector, + toIterator: Any => Iterator[Any], + elementSerializer: Serializer) + extends FieldSerializer[Any, ListVector](v) { + override def set(index: Int, value: Any): Unit = { + val elementStartIndex = vector.startNewValue(index) + var elementIndex = elementStartIndex + val iterator = toIterator(value) + while (iterator.hasNext) { + elementSerializer.write(elementIndex, iterator.next()) + elementIndex += 1 + } + vector.endValue(index, elementIndex - elementStartIndex) + } + } + + private class StructFieldSerializer(val extractor: Any => Any, val serializer: Serializer) { + def write(index: Int, value: Any): Unit = serializer.write(index, extractor(value)) + def writeNull(index: Int): Unit = serializer.write(index, null) + } + + private class StructSerializer( + struct: StructVector, + fieldSerializers: Seq[StructFieldSerializer]) + extends Serializer { + private[this] val nullable = struct != null && struct.getField.isNullable + + override def write(index: Int, value: Any): Unit = { + if (value == null) { + if (!nullable) { + throw new NullPointerException() + } + if (struct != null) { + struct.setNull(index) + } + fieldSerializers.foreach(_.writeNull(index)) + } else { + if (struct != null) { + struct.setIndexDefined(index) + } + fieldSerializers.foreach(_.write(index, value)) + } + } + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala new file mode 100644 index 0000000000000..cf4affa1db49d --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -0,0 +1,837 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.connect.client.arrow + +import java.util +import java.util.{Collections, Objects} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.reflect.classTag +import scala.util.control.NonFatal + +import com.google.protobuf.ByteString +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector.VarBinaryVector +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.connect.proto +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, DummyBean, FooEnum, JavaTypeInference, PrimitiveData, ScalaReflection} +import org.apache.spark.sql.catalyst.FooEnum.FooEnum +import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, BoxedData, UDTForCaseClass} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, CalendarIntervalEncoder, DateEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, RowEncoder, StringEncoder, TimestampEncoder, UDTEncoder} +import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} +import org.apache.spark.sql.connect.client.SparkResult +import org.apache.spark.sql.connect.client.util.ConnectFunSuite +import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, Metadata, StructType} + +/** + * Tests for encoding external data to and from arrow. + */ +class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { + private val allocator = new RootAllocator() + + private def newAllocator(name: String): BufferAllocator = { + allocator.newChildAllocator(name, 0, allocator.getLimit) + } + + protected override def afterAll(): Unit = { + super.afterAll() + allocator.close() + } + + private def withAllocator[T](f: BufferAllocator => T): T = { + val allocator = newAllocator("allocator") + try f(allocator) + finally { + allocator.close() + } + } + + private def roundTrip[T]( + encoder: AgnosticEncoder[T], + iterator: Iterator[T], + maxRecordsPerBatch: Int = 4 * 1024, + maxBatchSize: Long = 16 * 1024, + batchSizeCheckInterval: Int = 128, + inspectBatch: Array[Byte] => Unit = null): CloseableIterator[T] = { + // Use different allocators so we can pinpoint memory leaks better. + val serializerAllocator = newAllocator("serialization") + val deserializerAllocator = newAllocator("deserialization") + + val arrowIterator = ArrowSerializer.serialize( + input = iterator, + enc = encoder, + allocator = serializerAllocator, + maxRecordsPerBatch = maxRecordsPerBatch, + maxBatchSize = maxBatchSize, + batchSizeCheckInterval = batchSizeCheckInterval, + timeZoneId = "UTC") + + val inspectedIterator = if (inspectBatch != null) { + arrowIterator.map { batch => + inspectBatch(batch) + batch + } + } else { + arrowIterator + } + + val resultIterator = + try { + deserializeFromArrow(inspectedIterator, encoder, deserializerAllocator) + } catch { + case NonFatal(e) => + arrowIterator.close() + serializerAllocator.close() + deserializerAllocator.close() + throw e + } + new CloseableIterator[T] { + override def close(): Unit = { + arrowIterator.close() + resultIterator.close() + serializerAllocator.close() + deserializerAllocator.close() + } + override def hasNext: Boolean = resultIterator.hasNext + override def next(): T = resultIterator.next() + } + } + + // Temporary hack until we merge the deserializer. 
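  // Editor's note (not part of this patch): the temporary helper below routes batches
  // through SparkResult. Each Array[Byte] emitted by ArrowSerializer.serialize is also a
  // self-contained Arrow IPC stream (schema, one record batch, end-of-stream marker), so a
  // hypothetical helper like `countRows` could inspect it directly with Arrow's reader.
  private def countRows(batch: Array[Byte], allocator: BufferAllocator): Long = {
    val reader = new org.apache.arrow.vector.ipc.ArrowStreamReader(
      new java.io.ByteArrayInputStream(batch), allocator)
    try {
      // The schema is read eagerly; record batches are loaded one at a time.
      val root = reader.getVectorSchemaRoot
      var rows = 0L
      while (reader.loadNextBatch()) {
        rows += root.getRowCount
      }
      rows
    } finally {
      reader.close()
    }
  }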
+ private def deserializeFromArrow[E]( + batches: Iterator[Array[Byte]], + encoder: AgnosticEncoder[E], + allocator: BufferAllocator): CloseableIterator[E] = { + val responses = batches.map { batch => + val builder = proto.ExecutePlanResponse.newBuilder() + builder.getArrowBatchBuilder.setData(ByteString.copyFrom(batch)) + builder.build() + } + val result = new SparkResult[E](responses.asJava, allocator, encoder) + new CloseableIterator[E] { + private val iterator = result.iterator + override def close(): Unit = iterator.close() + override def hasNext: Boolean = iterator.hasNext + override def next(): E = iterator.next() + } + } + + private def roundTripAndCheck[T]( + encoder: AgnosticEncoder[T], + toInputIterator: () => Iterator[Any], + toOutputIterator: () => Iterator[T], + maxRecordsPerBatch: Int = 4 * 1024, + maxBatchSize: Long = 16 * 1024, + batchSizeCheckInterval: Int = 128, + inspectBatch: Array[Byte] => Unit = null): Unit = { + val iterator = roundTrip( + encoder, + toInputIterator().asInstanceOf[Iterator[T]], // Erasure hack :) + maxRecordsPerBatch, + maxBatchSize, + batchSizeCheckInterval, + inspectBatch) + try { + compareIterators(toOutputIterator(), iterator) + } finally { + iterator.close() + } + } + + private def roundTripAndCheckIdentical[T]( + encoder: AgnosticEncoder[T], + maxRecordsPerBatch: Int = 4 * 1024, + maxBatchSize: Long = 16 * 1024, + batchSizeCheckInterval: Int = 128, + inspectBatch: Array[Byte] => Unit = null)(toIterator: () => Iterator[T]): Unit = { + roundTripAndCheck( + encoder, + toIterator, + toIterator, + maxRecordsPerBatch, + maxBatchSize, + batchSizeCheckInterval, + inspectBatch) + } + + private def serializeToArrow[T]( + input: Iterator[T], + encoder: AgnosticEncoder[T], + allocator: BufferAllocator): CloseableIterator[Array[Byte]] = { + ArrowSerializer.serialize( + input, + encoder, + allocator, + maxRecordsPerBatch = 1024, + maxBatchSize = 8 * 1024, + timeZoneId = "UTC") + } + + private def compareIterators[T](expected: Iterator[T], actual: Iterator[T]): Unit = { + expected.zipAll(actual, null, null).foreach { case (expected, actual) => + assert(expected != null) + assert(actual != null) + assert(actual == expected) + } + } + + private class CountingBatchInspector extends (Array[Byte] => Unit) { + private var _numBatches: Int = 0 + private var _sizeInBytes: Long = 0 + def numBatches: Int = _numBatches + def sizeInBytes: Long = _sizeInBytes + def sizeInBytesPerBatch: Long = sizeInBytes / numBatches + override def apply(batch: Array[Byte]): Unit = { + _numBatches += 1 + _sizeInBytes += batch.length + } + } + + private case class MaybeNull(interval: Int) { + assert(interval > 1) + private var invocations = 0 + def apply[T](value: T): T = { + val result = if (invocations % interval == 0) { + null.asInstanceOf[T] + } else { + value + } + invocations += 1 + result + } + } + + private def javaBigDecimal(i: Int): java.math.BigDecimal = { + javaBigDecimal(i, DecimalType.DEFAULT_SCALE) + } + + private def javaBigDecimal(i: Int, scale: Int): java.math.BigDecimal = { + java.math.BigDecimal.valueOf(i).setScale(scale) + } + + private val singleIntEncoder = RowEncoder( + EncoderField("i", BoxedIntEncoder, nullable = false, Metadata.empty) :: Nil) + + /* ******************************************************************** * + * Iterator behavior tests. 
+ * ******************************************************************** */ + + test("empty") { + val inspector = new CountingBatchInspector + roundTripAndCheckIdentical(singleIntEncoder, inspectBatch = inspector) { () => + Iterator.empty + } + // We always write a batch with a schema. + assert(inspector.numBatches == 1) + assert(inspector.sizeInBytes > 0) + } + + test("single batch") { + val inspector = new CountingBatchInspector + roundTripAndCheckIdentical(singleIntEncoder, inspectBatch = inspector) { () => + Iterator.tabulate(10)(i => Row(i)) + } + assert(inspector.numBatches == 1) + } + + test("multiple batches - split by record count") { + val inspector = new CountingBatchInspector + roundTripAndCheckIdentical( + singleIntEncoder, + inspectBatch = inspector, + maxBatchSize = 32 * 1024) { () => + Iterator.tabulate(1024 * 1024)(i => Row(i)) + } + assert(inspector.numBatches == 256) + } + + test("multiple batches - split by size") { + val dataGen = { () => + Iterator.tabulate(4 * 1024)(i => Row(i)) + } + + // Normal interval + val inspector1 = new CountingBatchInspector + roundTripAndCheckIdentical(singleIntEncoder, maxBatchSize = 1024, inspectBatch = inspector1)( + dataGen) + assert(inspector1.numBatches == 16) + assert(inspector1.sizeInBytesPerBatch >= 1024) + assert(inspector1.sizeInBytesPerBatch <= 1024 + 128 * 5) + + // Lowest possible interval + val inspector2 = new CountingBatchInspector + roundTripAndCheckIdentical( + singleIntEncoder, + maxBatchSize = 1024, + batchSizeCheckInterval = 1, + inspectBatch = inspector2)(dataGen) + assert(inspector2.numBatches == 20) + assert(inspector2.sizeInBytesPerBatch >= 1024) + assert(inspector2.sizeInBytesPerBatch <= 1024 + 128 * 2) + assert(inspector2.sizeInBytesPerBatch < inspector1.sizeInBytesPerBatch) + } + + /* ******************************************************************** * + * Encoder specification tests + * ******************************************************************** */ + // Lenient mode + // Errors + + test("primitive fields") { + val encoder = ScalaReflection.encoderFor[PrimitiveData] + roundTripAndCheckIdentical(encoder) { () => + Iterator.tabulate(10) { i => + PrimitiveData(i, i, i.toDouble, i.toFloat, i.toShort, i.toByte, i < 4) + } + } + } + + test("boxed primitive fields") { + val encoder = ScalaReflection.encoderFor[BoxedData] + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(3) + Iterator.tabulate(100) { i => + BoxedData( + intField = maybeNull(i), + longField = maybeNull(i), + doubleField = maybeNull(i.toDouble), + floatField = maybeNull(i.toFloat), + shortField = maybeNull(i.toShort), + byteField = maybeNull(i.toByte), + booleanField = maybeNull(i > 4)) + } + } + } + + test("special floating point numbers") { + val floatIterator = roundTrip( + PrimitiveFloatEncoder, + Iterator[Float](Float.NaN, Float.NegativeInfinity, Float.PositiveInfinity)) + assert(java.lang.Float.isNaN(floatIterator.next())) + assert(floatIterator.next() == Float.NegativeInfinity) + assert(floatIterator.next() == Float.PositiveInfinity) + assert(!floatIterator.hasNext) + floatIterator.close() + + val doubleIterator = roundTrip( + PrimitiveDoubleEncoder, + Iterator[Double](Double.NaN, Double.NegativeInfinity, Double.PositiveInfinity)) + assert(java.lang.Double.isNaN(doubleIterator.next())) + assert(doubleIterator.next() == Double.NegativeInfinity) + assert(doubleIterator.next() == Double.PositiveInfinity) + assert(!doubleIterator.hasNext) + doubleIterator.close() + } + + test("nullable fields") { + val encoder = 
ScalaReflection.encoderFor[NullableData] + val instant = java.time.Instant.now() + val now = java.time.LocalDateTime.now() + val today = java.time.LocalDate.now() + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(3) + Iterator.tabulate(100) { i => + NullableData( + string = maybeNull(if (i % 7 == 0) "" else "s" + i), + month = maybeNull(java.time.Month.of(1 + (i % 12))), + foo = maybeNull(FooEnum(i % FooEnum.maxId)), + decimal = maybeNull(Decimal(i)), + scalaBigDecimal = maybeNull(BigDecimal(javaBigDecimal(i + 1))), + javaBigDecimal = maybeNull(javaBigDecimal(i + 2)), + scalaBigInt = maybeNull(BigInt(i + 3)), + javaBigInteger = maybeNull(java.math.BigInteger.valueOf(i + 4)), + duration = maybeNull(java.time.Duration.ofDays(i)), + period = maybeNull(java.time.Period.ofMonths(i)), + date = maybeNull(java.sql.Date.valueOf(today.plusDays(i))), + localDate = maybeNull(today.minusDays(i)), + timestamp = maybeNull(java.sql.Timestamp.valueOf(now.plusSeconds(i))), + instant = maybeNull(instant.plusSeconds(i * 100)), + localDateTime = maybeNull(now.minusHours(i))) + } + } + } + + test("binary field") { + val encoder = ScalaReflection.encoderFor[BinaryData] + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(3) + Iterator.tabulate(100) { i => + BinaryData(maybeNull(Array.tabulate(i % 100)(_.toByte))) + } + } + } + + // Row and Scala class are already covered in other tests + test("javabean") { + val encoder = JavaTypeInference.encoderFor[DummyBean](classOf[DummyBean]) + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(6) + Iterator.tabulate(100) { i => + val bean = new DummyBean() + bean.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i))) + bean + } + } + } + + test("defined by constructor parameters") { + val encoder = ScalaReflection.encoderFor[NonProduct] + roundTripAndCheckIdentical(encoder) { () => + Iterator.tabulate(100) { i => + new NonProduct("k" + i, i.toDouble) + } + } + } + + test("option") { + val encoder = ScalaReflection.encoderFor[Option[String]] + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(6) + Iterator.tabulate(100) { i => + Option(maybeNull("v" + i)) + } + } + } + + test("arrays") { + val encoder = ScalaReflection.encoderFor[ArrayData] + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(5) + Iterator.tabulate(100) { i => + ArrayData( + maybeNull(Array.tabulate[Double](i % 9)(_.toDouble)), + maybeNull(Array.tabulate[String](i % 21)(i => maybeNull("s" + i))), + maybeNull(Array.tabulate[Array[Int]](i % 13) { i => + maybeNull { + Array.fill(i % 29)(i) + } + })) + } + } + } + + test("scala iterables") { + val encoder = ScalaReflection.encoderFor[ListData] + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(5) + Iterator.tabulate(100) { i => + ListData( + maybeNull(Seq.tabulate[String](i % 9)(i => maybeNull("s" + i))), + maybeNull(Seq.tabulate[Int](i % 10)(identity)), + maybeNull(Set(i.toLong, i.toLong - 1, i.toLong - 33)), + maybeNull(mutable.Queue.tabulate(5 + i % 6) { i => + Option(maybeNull(BigInt(i))) + })) + } + } + } + + test("java lists") { + def genJavaData[E](n: Int, collection: util.Collection[E])(f: Int => E): Unit = { + Iterator.tabulate(n)(f).foreach(collection.add) + } + val encoder = JavaTypeInference.encoderFor(classOf[JavaListData]) + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(7) + Iterator.tabulate(1) { i => + val bean = new JavaListData + bean.setListOfDecimal(maybeNull { + val list = new 
util.ArrayList[java.math.BigDecimal] + genJavaData(i % 7, list) { i => maybeNull(java.math.BigDecimal.valueOf(i * 33)) } + list + }) + bean.setListOfBigInt(maybeNull { + val list = new util.LinkedList[java.math.BigInteger] + genJavaData(10, list) { i => maybeNull(java.math.BigInteger.valueOf(i * 50)) } + list + }) + bean.setListOfStrings(maybeNull { + val list = new util.ArrayList[String] + genJavaData((i + 5) % 50, list) { i => maybeNull("v" + (i * 2)) } + list + }) + bean.setListOfBytes(maybeNull(Collections.singletonList(i.toByte))) + bean + } + } + } + + test("wrapped array") { + val encoder = ScalaReflection.encoderFor[mutable.WrappedArray[Int]] + val input = mutable.WrappedArray.make[Int](Array(1, 98, 7, 6)) + val iterator = roundTrip(encoder, Iterator.single(input)) + val Seq(result) = iterator.toSeq + assert(result == input) + assert(result.array.getClass == classOf[Array[Int]]) + iterator.close() + } + + test("wrapped array - empty") { + val schema = new StructType().add("names", "array") + val encoder = toRowEncoder(schema) + val iterator = roundTrip(encoder, Iterator.single(Row(Seq()))) + val Seq(Row(raw)) = iterator.toSeq + val seq = raw.asInstanceOf[mutable.WrappedArray[String]] + assert(seq.isEmpty) + assert(seq.array.getClass == classOf[Array[String]]) + iterator.close() + } + + test("maps") { + val encoder = ScalaReflection.encoderFor[MapData] + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(5) + Iterator.tabulate(100) { i => + MapData( + maybeNull( + Iterator + .tabulate(i % 9) { i => + i -> maybeNull("s" + i) + } + .toMap), + maybeNull( + Iterator + .tabulate(i % 10) { i => + ("s" + 1) -> maybeNull(Array.tabulate[Long]((i + 5) % 20)(_.toLong)) + } + .toMap)) + } + } + } + + test("java maps") { + val encoder = JavaTypeInference.encoderFor(classOf[JavaMapData]) + roundTripAndCheckIdentical(encoder) { () => + val maybeNull = MaybeNull(11) + Iterator.tabulate(100) { i => + val bean = new JavaMapData + bean.setDummyToDoubleListMap(maybeNull { + val map = new util.HashMap[DummyBean, java.util.List[java.lang.Double]] + (0 until (i % 5)).foreach { j => + val dummy = new DummyBean + dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + val values = Array.tabulate(i % 40) { j => + Double.box(j.toDouble) + } + map.put(dummy, maybeNull(util.Arrays.asList(values: _*))) + } + map + }) + bean + } + } + } + + test("map with null key") { + val encoder = ScalaReflection.encoderFor[Map[String, String]] + withAllocator { allocator => + val iterator = ArrowSerializer.serialize( + Iterator(Map((null.asInstanceOf[String], "kaboom?"))), + encoder, + allocator, + maxRecordsPerBatch = 128, + maxBatchSize = 1024, + timeZoneId = "UTC") + intercept[NullPointerException] { + iterator.next() + } + iterator.close() + } + } + + // TODO follow-up with more null tests here: + // - Null primitive + // - Non-nullable map value + // - Non-nullable structfield + // - Non-nullable array element. 
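  // Editor's sketch of one such follow-up null test (hedged, not part of this patch): a
  // null value for the non-nullable field of singleIntEncoder is expected to surface as a
  // NullPointerException, mirroring the "map with null key" test above.
  test("null value for non-nullable int field (illustrative sketch)") {
    withAllocator { allocator =>
      val iterator = ArrowSerializer.serialize(
        Iterator(Row(null)),
        singleIntEncoder,
        allocator,
        maxRecordsPerBatch = 128,
        maxBatchSize = 1024,
        timeZoneId = "UTC")
      intercept[NullPointerException] {
        iterator.next()
      }
      iterator.close()
    }
  }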
+ + test("lenient field serialization - date/localdate") { + val base = java.time.LocalDate.now() + val localDates = () => Iterator.tabulate(10)(i => base.plusDays(i * i * 60)) + val dates = () => localDates().map(java.sql.Date.valueOf) + val combo = () => localDates() ++ dates() + roundTripAndCheck(DateEncoder(true), dates, dates) + roundTripAndCheck(DateEncoder(true), localDates, dates) + roundTripAndCheck(DateEncoder(true), combo, () => dates() ++ dates()) + roundTripAndCheck(LocalDateEncoder(true), dates, localDates) + roundTripAndCheck(LocalDateEncoder(true), localDates, localDates) + roundTripAndCheck(LocalDateEncoder(true), combo, () => localDates() ++ localDates()) + } + + test("lenient field serialization - timestamp/instant") { + val base = java.time.Instant.now() + val instants = () => Iterator.tabulate(10)(i => base.plusSeconds(i * i * 60)) + val timestamps = () => instants().map(java.sql.Timestamp.from) + val combo = () => instants() ++ timestamps() + roundTripAndCheck(InstantEncoder(true), instants, instants) + roundTripAndCheck(InstantEncoder(true), timestamps, instants) + roundTripAndCheck(InstantEncoder(true), combo, () => instants() ++ instants()) + roundTripAndCheck(TimestampEncoder(true), instants, timestamps) + roundTripAndCheck(TimestampEncoder(true), timestamps, timestamps) + roundTripAndCheck(TimestampEncoder(true), combo, () => timestamps() ++ timestamps()) + } + + test("lenient field serialization - decimal") { + val base = javaBigDecimal(137, DecimalType.DEFAULT_SCALE) + val bigDecimals = () => + Iterator.tabulate(100) { i => + base.multiply(javaBigDecimal(i)).setScale(DecimalType.DEFAULT_SCALE) + } + val bigInts = () => bigDecimals().map(_.toBigInteger) + val scalaBigDecimals = () => bigDecimals().map(BigDecimal.apply) + val scalaBigInts = () => bigDecimals().map(v => BigInt(v.toBigInteger)) + val sparkDecimals = () => bigDecimals().map(Decimal.apply) + val encoder = JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = true) + roundTripAndCheck(encoder, bigDecimals, bigDecimals) + roundTripAndCheck(encoder, bigInts, bigDecimals) + roundTripAndCheck(encoder, scalaBigDecimals, bigDecimals) + roundTripAndCheck(encoder, scalaBigInts, bigDecimals) + roundTripAndCheck(encoder, sparkDecimals, bigDecimals) + roundTripAndCheck( + encoder, + () => bigDecimals() ++ bigInts() ++ scalaBigDecimals() ++ scalaBigInts() ++ sparkDecimals(), + () => Iterator.fill(5)(bigDecimals()).flatten) + } + + test("lenient field serialization - iterables") { + val encoder = IterableEncoder( + classTag[Seq[Int]], + BoxedIntEncoder, + containsNull = true, + lenientSerialization = true) + val elements = Seq(Array(1, 7, 8), Array.emptyIntArray, Array(88)) + val primitiveArrays = () => elements.iterator + val genericArrays = () => elements.iterator.map(v => v.map(Int.box)) + val lists = () => elements.iterator.map(v => java.util.Arrays.asList(v.map(Int.box): _*)) + val seqs = () => elements.iterator.map(_.toSeq) + roundTripAndCheck(encoder, seqs, seqs) + roundTripAndCheck(encoder, primitiveArrays, seqs) + roundTripAndCheck(encoder, genericArrays, seqs) + roundTripAndCheck(encoder, lists, seqs) + roundTripAndCheck( + encoder, + () => lists() ++ seqs() ++ genericArrays() ++ primitiveArrays(), + () => Iterator.fill(4)(seqs()).flatten) + } + + private val wideSchemaEncoder = toRowEncoder( + new StructType() + .add("a", "int") + .add("b", "string") + .add( + "c", + new StructType() + .add("ca", "array") + .add("cb", "binary") + .add("cc", "float")) + .add( + "d", + ArrayType( + new 
StructType() + .add("da", "decimal(20, 10)") + .add("db", "string") + .add("dc", "boolean")))) + + private val narrowSchemaEncoder = toRowEncoder( + new StructType() + .add("b", "string") + .add( + "d", + ArrayType( + new StructType() + .add("da", "decimal(20, 10)") + .add("dc", "boolean"))) + .add( + "C", + new StructType() + .add("Ca", "array") + .add("Cb", "binary"))) + + /* ******************************************************************** * + * Arrow serialization/deserialization specific errors + * ******************************************************************** */ + test("unsupported encoders") { + // CalendarIntervalEncoder + val data = null.asInstanceOf[AnyRef] + intercept[SparkUnsupportedOperationException]( + ArrowSerializer.serializerFor(CalendarIntervalEncoder, data)) + + // UDT + val udtEncoder = UDTEncoder(new UDTForCaseClass, classOf[UDTForCaseClass]) + intercept[SparkUnsupportedOperationException](ArrowSerializer.serializerFor(udtEncoder, data)) + } + + test("unsupported encoder/vector combinations") { + // Also add a test for the serializer... + withAllocator { allocator => + intercept[RuntimeException] { + ArrowSerializer.serializerFor(StringEncoder, new VarBinaryVector("bytes", allocator)) + } + } + } +} + +// TODO fix actual Null fields, e.g.: nullable: Null +case class NullableData( + string: String, + month: java.time.Month, + foo: FooEnum, + decimal: Decimal, + scalaBigDecimal: BigDecimal, + javaBigDecimal: java.math.BigDecimal, + scalaBigInt: BigInt, + javaBigInteger: java.math.BigInteger, + duration: java.time.Duration, + period: java.time.Period, + date: java.sql.Date, + localDate: java.time.LocalDate, + timestamp: java.sql.Timestamp, + instant: java.time.Instant, + localDateTime: java.time.LocalDateTime) + +case class BinaryData(binary: Array[Byte]) { + def canEqual(other: Any): Boolean = other.isInstanceOf[BinaryData] + + override def equals(other: Any): Boolean = other match { + case that: BinaryData if that.canEqual(this) => + java.util.Arrays.equals(binary, that.binary) + case _ => false + } + + override def hashCode(): Int = java.util.Arrays.hashCode(binary) +} + +class NonProduct(val name: String, val value: Double) extends DefinedByConstructorParams { + + def canEqual(other: Any): Boolean = other.isInstanceOf[NonProduct] + + override def equals(other: Any): Boolean = other match { + case that: NonProduct => + (that canEqual this) && + name == that.name && + value == that.value + case _ => false + } + + override def hashCode(): Int = { + val state = Seq(name, value) + state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + } +} + +case class ArrayData(doubles: Array[Double], strings: Array[String], nested: Array[Array[Int]]) { + def canEqual(other: Any): Boolean = other.isInstanceOf[ArrayData] + + override def equals(other: Any): Boolean = other match { + case that: ArrayData if that.canEqual(this) => + Objects.deepEquals(that.doubles, doubles) && + Objects.deepEquals(that.strings, strings) && + Objects.deepEquals(that.nested, nested) + case _ => false + } + + override def hashCode(): Int = { + val state = Seq(doubles, strings, nested) + state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + } +} + +case class ListData( + seqOfStrings: Seq[String], + seqOfInts: Seq[Int], + setOfLongs: Set[Long], + queueOfBigIntOptions: mutable.Queue[Option[BigInt]]) + +class JavaListData { + @scala.beans.BeanProperty + var listOfDecimal: java.util.ArrayList[java.math.BigDecimal] = _ + @scala.beans.BeanProperty + var listOfBigInt: 
java.util.LinkedList[java.math.BigInteger] = _ + @scala.beans.BeanProperty + var listOfStrings: java.util.AbstractList[String] = _ + @scala.beans.BeanProperty + var listOfBytes: java.util.List[java.lang.Byte] = _ + + def canEqual(other: Any): Boolean = other.isInstanceOf[JavaListData] + + override def equals(other: Any): Boolean = other match { + case that: JavaListData if that canEqual this => + Objects.equals(listOfDecimal, that.listOfDecimal) && + Objects.equals(listOfBigInt, that.listOfBigInt) && + Objects.equals(listOfStrings, that.listOfStrings) && + Objects.equals(listOfBytes, that.listOfBytes) + case _ => false + } + + override def hashCode(): Int = { + val state = Seq(listOfDecimal, listOfBigInt, listOfStrings, listOfBytes) + state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + override def toString: String = { + s"JavaListData(listOfDecimal=$listOfDecimal, " + + s"listOfBigInt=$listOfBigInt, " + + s"listOfStrings=$listOfStrings, " + + s"listOfBytes=$listOfBytes)" + } +} + +case class MapData(intStringMap: Map[Int, String], metricMap: Map[String, Array[Long]]) { + def canEqual(other: Any): Boolean = other.isInstanceOf[MapData] + + private def sameMetricMap(other: Map[String, Array[Long]]): Boolean = { + if (metricMap == null && other == null) { + true + } else if (metricMap == null || other == null || metricMap.keySet != other.keySet) { + false + } else { + metricMap.forall { case (key, values) => + java.util.Arrays.equals(values, other(key)) + } + } + } + + override def equals(other: Any): Boolean = other match { + case that: MapData if that canEqual this => + Objects.deepEquals(intStringMap, that.intStringMap) && + sameMetricMap(that.metricMap) + case _ => false + } + + override def hashCode(): Int = { + java.util.Arrays.deepHashCode(Array(intStringMap, metricMap)) + } +} + +class JavaMapData { + @scala.beans.BeanProperty + var dummyToDoubleListMap: java.util.Map[DummyBean, java.util.List[java.lang.Double]] = _ + + def canEqual(other: Any): Boolean = other.isInstanceOf[JavaMapData] + + override def equals(other: Any): Boolean = other match { + case that: JavaMapData if that canEqual this => + dummyToDoubleListMap == that.dummyToDoubleListMap + case _ => false + } + + override def hashCode(): Int = Objects.hashCode(dummyToDoubleListMap) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala index 35f5bf739bfce..90c61c402306e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger -import java.util.{LinkedList, List => JList, Map => JMap} +import java.util.{LinkedList, List => JList, Map => JMap, Objects} import scala.beans.{BeanProperty, BooleanBeanProperty} import scala.reflect.{classTag, ClassTag} @@ -30,6 +30,13 @@ import org.apache.spark.sql.types.{DecimalType, MapType, Metadata, StringType, S class DummyBean { @BeanProperty var bigInteger: BigInteger = _ + + override def hashCode(): Int = Objects.hashCode(bigInteger) + + override def equals(obj: Any): Boolean = obj match { + case bean: DummyBean => Objects.equals(bigInteger, bean.bigInteger) + case _ => false + } } class GenericCollectionBean { From 30cc7b3236a5caf45b323c67f1fecf3f631a7d1f Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 11 Jul 
2023 18:23:50 -0400 Subject: [PATCH 02/16] Undo Scala version change --- assembly/pom.xml | 4 ++-- common/kvstore/pom.xml | 4 ++-- common/network-common/pom.xml | 4 ++-- common/network-shuffle/pom.xml | 4 ++-- common/network-yarn/pom.xml | 4 ++-- common/sketch/pom.xml | 4 ++-- common/tags/pom.xml | 4 ++-- common/unsafe/pom.xml | 4 ++-- common/utils/pom.xml | 4 ++-- connector/avro/pom.xml | 8 ++++---- connector/connect/client/jvm/pom.xml | 4 ++-- connector/connect/common/pom.xml | 4 ++-- connector/connect/server/pom.xml | 8 ++++---- connector/docker-integration-tests/pom.xml | 4 ++-- connector/kafka-0-10-assembly/pom.xml | 4 ++-- connector/kafka-0-10-sql/pom.xml | 8 ++++---- connector/kafka-0-10-token-provider/pom.xml | 4 ++-- connector/kafka-0-10/pom.xml | 8 ++++---- connector/kinesis-asl-assembly/pom.xml | 4 ++-- connector/kinesis-asl/pom.xml | 4 ++-- connector/protobuf/pom.xml | 8 ++++---- connector/spark-ganglia-lgpl/pom.xml | 4 ++-- core/pom.xml | 8 ++++---- dev/mima | 8 ++++---- docs/_plugins/copy_api_dirs.rb | 14 +++++++------- examples/pom.xml | 4 ++-- graphx/pom.xml | 4 ++-- hadoop-cloud/pom.xml | 4 ++-- launcher/pom.xml | 4 ++-- mllib-local/pom.xml | 4 ++-- mllib/pom.xml | 8 ++++---- pom.xml | 12 ++++++------ repl/pom.xml | 4 ++-- resource-managers/kubernetes/core/pom.xml | 4 ++-- .../kubernetes/integration-tests/pom.xml | 4 ++-- resource-managers/mesos/pom.xml | 4 ++-- resource-managers/yarn/pom.xml | 4 ++-- sql/api/pom.xml | 4 ++-- sql/catalyst/pom.xml | 8 ++++---- sql/core/pom.xml | 8 ++++---- sql/hive-thriftserver/pom.xml | 8 ++++---- sql/hive/pom.xml | 8 ++++---- streaming/pom.xml | 8 ++++---- tools/pom.xml | 4 ++-- 44 files changed, 123 insertions(+), 123 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index d4d7a1db4a29e..09d6bd8a33f79 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-assembly_2.13 + spark-assembly_2.12 Spark Project Assembly https://spark.apache.org/ pom diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 69f9f186e0889..bef8303874b20 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-kvstore_2.13 + spark-kvstore_2.12 jar Spark Project Local DB https://spark.apache.org/ diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 9f90d12216e69..8a63e999c53cd 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-network-common_2.13 + spark-network-common_2.12 jar Spark Project Networking https://spark.apache.org/ diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 864f1cc2d3715..a8bde14a259f0 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-network-shuffle_2.13 + spark-network-shuffle_2.12 jar Spark Project Shuffle Streaming Service https://spark.apache.org/ diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index c19ac33afa5cd..671d5cb7e0178 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT 
../../pom.xml - spark-network-yarn_2.13 + spark-network-yarn_2.12 jar Spark Project YARN Shuffle Service https://spark.apache.org/ diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 6cf1a4fb83e56..4cc597519c3dd 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-sketch_2.13 + spark-sketch_2.12 jar Spark Project Sketch https://spark.apache.org/ diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 1eb8352e32df3..9a44c847d8a03 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-tags_2.13 + spark-tags_2.12 jar Spark Project Tags https://spark.apache.org/ diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 84e7b61553483..bdf82d9285e06 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-unsafe_2.13 + spark-unsafe_2.12 jar Spark Project Unsafe https://spark.apache.org/ diff --git a/common/utils/pom.xml b/common/utils/pom.xml index ee10a60618297..36cfceed931e0 100644 --- a/common/utils/pom.xml +++ b/common/utils/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-common-utils_2.13 + spark-common-utils_2.12 jar Spark Project Common Utils https://spark.apache.org/ diff --git a/connector/avro/pom.xml b/connector/avro/pom.xml index 7087fdbccd04d..597e3c2235f7a 100644 --- a/connector/avro/pom.xml +++ b/connector/avro/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-avro_2.13 + spark-avro_2.12 avro @@ -70,12 +70,12 @@ org.apache.spark spark-tags_${scala.binary.version} - + + --> org.tukaani xz diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index cef6f6d214e74..60e4ae78147ee 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../../../pom.xml - spark-connect-client-jvm_2.13 + spark-connect-client-jvm_2.12 jar Spark Project Connect Client https://spark.apache.org/ diff --git a/connector/connect/common/pom.xml b/connector/connect/common/pom.xml index 9b28aca5d0726..1890384b51db5 100644 --- a/connector/connect/common/pom.xml +++ b/connector/connect/common/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../../pom.xml - spark-connect-common_2.13 + spark-connect-common_2.12 jar Spark Project Connect Common https://spark.apache.org/ diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml index 01d6d86c54292..95b70c6b0f41d 100644 --- a/connector/connect/server/pom.xml +++ b/connector/connect/server/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../../pom.xml - spark-connect_2.13 + spark-connect_2.12 jar Spark Project Connect Server https://spark.apache.org/ @@ -152,12 +152,12 @@ - + + --> com.google.guava guava diff --git a/connector/docker-integration-tests/pom.xml b/connector/docker-integration-tests/pom.xml index e40269838522d..cc549487a8b57 100644 --- a/connector/docker-integration-tests/pom.xml +++ 
b/connector/docker-integration-tests/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-docker-integration-tests_2.13 + spark-docker-integration-tests_2.12 jar Spark Project Docker Integration Tests https://spark.apache.org/ diff --git a/connector/kafka-0-10-assembly/pom.xml b/connector/kafka-0-10-assembly/pom.xml index f339e8c2e4fd2..340974cc789bd 100644 --- a/connector/kafka-0-10-assembly/pom.xml +++ b/connector/kafka-0-10-assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-streaming-kafka-0-10-assembly_2.13 + spark-streaming-kafka-0-10-assembly_2.12 jar Spark Integration for Kafka 0.10 Assembly https://spark.apache.org/ diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml index f5a12b61c2bea..fdd1196cd446a 100644 --- a/connector/kafka-0-10-sql/pom.xml +++ b/connector/kafka-0-10-sql/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-sql-kafka-0-10_2.13 + spark-sql-kafka-0-10_2.12 sql-kafka-0-10 @@ -74,12 +74,12 @@ test-jar test - + + --> org.apache.kafka kafka-clients diff --git a/connector/kafka-0-10-token-provider/pom.xml b/connector/kafka-0-10-token-provider/pom.xml index b3c0889d9475a..3256130c50f3b 100644 --- a/connector/kafka-0-10-token-provider/pom.xml +++ b/connector/kafka-0-10-token-provider/pom.xml @@ -20,13 +20,13 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml org.apache.spark - spark-token-provider-kafka-0-10_2.13 + spark-token-provider-kafka-0-10_2.12 token-provider-kafka-0-10 diff --git a/connector/kafka-0-10/pom.xml b/connector/kafka-0-10/pom.xml index f1820bb595a2d..706eb2dd2c399 100644 --- a/connector/kafka-0-10/pom.xml +++ b/connector/kafka-0-10/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-streaming-kafka-0-10_2.13 + spark-streaming-kafka-0-10_2.12 streaming-kafka-0-10 @@ -59,12 +59,12 @@ test-jar test - + + --> org.apache.kafka kafka-clients diff --git a/connector/kinesis-asl-assembly/pom.xml b/connector/kinesis-asl-assembly/pom.xml index 2cba2668f049a..cd5c0393f6f84 100644 --- a/connector/kinesis-asl-assembly/pom.xml +++ b/connector/kinesis-asl-assembly/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-streaming-kinesis-asl-assembly_2.13 + spark-streaming-kinesis-asl-assembly_2.12 jar Spark Project Kinesis Assembly https://spark.apache.org/ diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml index af9cd4b7ec96e..c70a073e73407 100644 --- a/connector/kinesis-asl/pom.xml +++ b/connector/kinesis-asl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-streaming-kinesis-asl_2.13 + spark-streaming-kinesis-asl_2.12 jar Spark Kinesis Integration diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index db92af75a5728..3d6bbea7d41c5 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-protobuf_2.13 + spark-protobuf_2.12 protobuf @@ -70,12 +70,12 @@ org.apache.spark spark-tags_${scala.binary.version} - + + --> com.google.protobuf protobuf-java 
diff --git a/connector/spark-ganglia-lgpl/pom.xml b/connector/spark-ganglia-lgpl/pom.xml index 00f4769fd60ad..c0dcde1355849 100644 --- a/connector/spark-ganglia-lgpl/pom.xml +++ b/connector/spark-ganglia-lgpl/pom.xml @@ -19,13 +19,13 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-ganglia-lgpl_2.13 + spark-ganglia-lgpl_2.12 jar Spark Ganglia Integration diff --git a/core/pom.xml b/core/pom.xml index 79bf8a2163554..6519b46d96e31 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-core_2.13 + spark-core_2.12 jar Spark Project Core https://spark.apache.org/ @@ -35,12 +35,12 @@ - + + --> org.apache.avro avro diff --git a/dev/mima b/dev/mima index 32c3718e4ccca..4a9e343b0a78f 100755 --- a/dev/mima +++ b/dev/mima @@ -24,9 +24,9 @@ set -e FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" -SPARK_PROFILES=${1:-"-Pscala-2.13 -Pmesos -Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"} -TOOLS_CLASSPATH="$(build/sbt -Pscala-2.13 -DcopyDependencies=false "export tools/fullClasspath" | grep jar | tail -n1)" -OLD_DEPS_CLASSPATH="$(build/sbt -Pscala-2.13 -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | grep jar | tail -n1)" +SPARK_PROFILES=${1:-"-Pmesos -Pkubernetes -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"} +TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | grep jar | tail -n1)" +OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | grep jar | tail -n1)" rm -f .generated-mima* @@ -42,7 +42,7 @@ $JAVA_CMD \ -cp "$TOOLS_CLASSPATH:$OLD_DEPS_CLASSPATH" \ org.apache.spark.tools.GenerateMIMAIgnore -echo -e "q\n" | build/sbt -Pscala-2.13 -mem 5120 -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving" +echo -e "q\n" | build/sbt -mem 5120 -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving" ret_val=$? if [ $ret_val != 0 ]; then diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 9cb073ef1e00c..28d5e0d82c93a 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -26,8 +26,8 @@ curr_dir = pwd cd("..") - puts "Running 'build/sbt -Pscala-2.13 -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." - system("build/sbt -Pscala-2.13 -Pkinesis-asl clean compile unidoc") || raise("Unidoc generation failed") + puts "Running 'build/sbt -Pkinesis-asl clean compile unidoc' from " + pwd + "; this may take a few minutes..." + system("build/sbt -Pkinesis-asl clean compile unidoc") || raise("Unidoc generation failed") puts "Moving back into docs dir." cd("docs") @@ -37,7 +37,7 @@ # Copy over the unified ScalaDoc for all projects to api/scala. # This directory will be copied over to _site when `jekyll` command is run. - source = "../target/scala-2.13/unidoc" + source = "../target/scala-2.12/unidoc" dest = "api/scala" puts "Making directory " + dest @@ -119,8 +119,8 @@ puts "Moving to project root and building API docs." cd("..") - puts "Running 'build/sbt -Pscala-2.13 clean package -Phive' from " + pwd + "; this may take a few minutes..." - system("build/sbt -Pscala-2.13 clean package -Phive") || raise("PySpark doc generation failed") + puts "Running 'build/sbt clean package -Phive' from " + pwd + "; this may take a few minutes..." 
+ system("build/sbt clean package -Phive") || raise("PySpark doc generation failed") puts "Moving back into docs dir." cd("docs") @@ -165,8 +165,8 @@ puts "Moving to project root and building API docs." cd("..") - puts "Running 'build/sbt -Pscala-2.13 clean package -Phive' from " + pwd + "; this may take a few minutes..." - system("build/sbt -Pscala-2.13 clean package -Phive") || raise("SQL doc generation failed") + puts "Running 'build/sbt clean package -Phive' from " + pwd + "; this may take a few minutes..." + system("build/sbt clean package -Phive") || raise("SQL doc generation failed") puts "Moving back into docs dir." cd("docs") diff --git a/examples/pom.xml b/examples/pom.xml index 57e41724bdca4..e8f22b995fded 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-examples_2.13 + spark-examples_2.12 jar Spark Project Examples https://spark.apache.org/ diff --git a/graphx/pom.xml b/graphx/pom.xml index 5d01dd06c0ecb..48baeb9a87560 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-graphx_2.13 + spark-graphx_2.12 graphx diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 21c1e0fee1ddf..02e7675df286c 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-hadoop-cloud_2.13 + spark-hadoop-cloud_2.12 jar Spark Project Hadoop Cloud Integration diff --git a/launcher/pom.xml b/launcher/pom.xml index 0bc3ae20ee183..aba7ee82d53cf 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-launcher_2.13 + spark-launcher_2.12 jar Spark Project Launcher https://spark.apache.org/ diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 83ca643f43bce..00c16a8b6a544 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-mllib-local_2.13 + spark-mllib-local_2.12 mllib-local diff --git a/mllib/pom.xml b/mllib/pom.xml index 07290124273f2..73af83c758688 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-mllib_2.13 + spark-mllib_2.12 mllib @@ -91,12 +91,12 @@ test-jar test - + + --> org.scalanlp breeze_${scala.binary.version} diff --git a/pom.xml b/pom.xml index 2a917b46d8520..96375ea904dd8 100644 --- a/pom.xml +++ b/pom.xml @@ -25,7 +25,7 @@ 18 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT pom Spark Project Parent POM @@ -168,8 +168,8 @@ 3.2.2 4.4 - 2.13.11 - 2.13 + 2.12.18 + 2.12 2.2.0 4.8.0 @@ -438,13 +438,13 @@ ${project.version} test-jar - + + --> com.twitter chill_${scala.binary.version} @@ -1089,7 +1089,7 @@ org.scala-lang.modules - scala-xml_2.13 + scala-xml_2.12 diff --git a/repl/pom.xml b/repl/pom.xml index 74ac775100cb8..8c0f9f989c170 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-repl_2.13 + spark-repl_2.12 jar Spark Project REPL https://spark.apache.org/ diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml 
index 72c7f1f12f42d..9dab5496184e2 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../../pom.xml - spark-kubernetes_2.13 + spark-kubernetes_2.12 jar Spark Project Kubernetes diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 3e25e7053707a..02894f82eec9d 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../../pom.xml - spark-kubernetes-integration-tests_2.13 + spark-kubernetes-integration-tests_2.12 kubernetes-integration-tests diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 267d6c7d84f21..7510ecac3e7fc 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-mesos_2.13 + spark-mesos_2.12 jar Spark Project Mesos diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 2cda552a9c47c..dcc7bcdd1af38 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -19,12 +19,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-yarn_2.13 + spark-yarn_2.12 jar Spark Project YARN diff --git a/sql/api/pom.xml b/sql/api/pom.xml index db3bfaeeca0a9..41a5b85d4c670 100644 --- a/sql/api/pom.xml +++ b/sql/api/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-sql-api_2.13 + spark-sql-api_2.12 jar Spark Project SQL API https://spark.apache.org/ diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 80b7c99ddc139..9dbc8d625d079 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-catalyst_2.13 + spark-catalyst_2.12 jar Spark Project Catalyst https://spark.apache.org/ @@ -92,12 +92,12 @@ spark-sketch_${scala.binary.version} ${project.version} - + + --> org.scalacheck scalacheck_${scala.binary.version} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 5d4f7572d0022..7f4c2a4cfa54d 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-sql_2.13 + spark-sql_2.12 jar Spark Project SQL https://spark.apache.org/ @@ -89,12 +89,12 @@ test - + + --> org.apache.orc orc-core diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 4bbb92d1376a0..ad7fc0d2ac4bd 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - spark-hive-thriftserver_2.13 + spark-hive-thriftserver_2.12 jar Spark Project Hive Thrift Server https://spark.apache.org/ @@ -61,12 +61,12 @@ test-jar test - + + --> com.google.guava guava diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index b267830a3ad5f..16d915c233ee4 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -21,12 +21,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../../pom.xml - 
spark-hive_2.13 + spark-hive_2.12 jar Spark Project Hive https://spark.apache.org/ @@ -79,12 +79,12 @@ test-jar test - + + --> ${hive.group} hive-common diff --git a/streaming/pom.xml b/streaming/pom.xml index a36370a1e8b61..bebfd3abcce39 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -20,12 +20,12 @@ 4.0.0 org.apache.spark - spark-parent_2.13 + spark-parent_2.12 3.5.0-SNAPSHOT ../pom.xml - spark-streaming_2.13 + spark-streaming_2.12 streaming @@ -50,12 +50,12 @@ org.apache.spark spark-tags_${scala.binary.version} - + + --> @@ -224,6 +225,24 @@ + + org.codehaus.mojo + build-helper-maven-plugin + + + add-sources + generate-sources + + add-source + + + + src/main/scala-${scala.binary.version} + + + + + \ No newline at end of file diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index c7e9d22682f86..1cdc2035de60b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -92,7 +92,13 @@ private[sql] class SparkResult[T]( arrowSchema = reader.schema stop |= stopOnArrowSchema } else if (arrowSchema != reader.schema) { - // Uh oh... + throw new IllegalStateException( + s"""Schema Mismatch between expected and received schema: + |=== Expected Schema === + |$arrowSchema + |=== Received Schema === + |${reader.schema} + |""".stripMargin) } if (structType == null) { // If the schema is not available yet, fallback to the arrow schema. diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 91589e0945aed..154866d699a34 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{Decimal, StructType} +import org.apache.spark.sql.types.Decimal /** * Helper class for converting arrow batches into user objects. 
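
The hunk that follows changes the public return type of ArrowDeserializers.deserializeFromArrow from the bespoke TypedDeserializingIterator[T] to the plain CloseableIterator[T]. For orientation, a minimal caller of that API could look as sketched below. This is illustrative only and not part of the patch: it assumes the signature shown in the hunk underneath, and PrimitiveIntEncoder is used purely as an example encoder.

    // Illustrative sketch, not part of this patch. Assumes the signature shown in the
    // following hunk: deserializeFromArrow(Iterator[Array[Byte]], AgnosticEncoder[T],
    // BufferAllocator): CloseableIterator[T].
    import org.apache.arrow.memory.RootAllocator
    import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveIntEncoder
    import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers

    object DeserializeSketch {
      def main(args: Array[String]): Unit = {
        val allocator = new RootAllocator()
        // Batches would normally come from the serializer or a Spark Connect result;
        // an empty iterator keeps the sketch self-contained.
        val batches: Iterator[Array[Byte]] = Iterator.empty
        val rows = ArrowDeserializers.deserializeFromArrow(batches, PrimitiveIntEncoder, allocator)
        try {
          rows.foreach(println) // CloseableIterator[Int] behaves like a regular Iterator
        } finally {
          rows.close()          // releases the Arrow buffers backing the iterator
          allocator.close()
        }
      }
    }
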
@@ -54,7 +54,7 @@ object ArrowDeserializers { def deserializeFromArrow[T]( input: Iterator[Array[Byte]], encoder: AgnosticEncoder[T], - allocator: BufferAllocator): TypedDeserializingIterator[T] = { + allocator: BufferAllocator): CloseableIterator[T] = { try { val reader = new ConcatenatingArrowStreamReader( allocator, @@ -496,13 +496,8 @@ object ArrowDeserializers { } } -trait TypedDeserializingIterator[E] extends CloseableIterator[E] { - def encoder: AgnosticEncoder[E] - def schema: StructType = encoder.schema -} - -class EmptyDeserializingIterator[E](override val encoder: AgnosticEncoder[E]) - extends TypedDeserializingIterator[E] { +class EmptyDeserializingIterator[E](val encoder: AgnosticEncoder[E]) + extends CloseableIterator[E] { override def close(): Unit = () override def hasNext: Boolean = false override def next(): E = throw new NoSuchElementException() @@ -511,7 +506,7 @@ class EmptyDeserializingIterator[E](override val encoder: AgnosticEncoder[E]) class ArrowDeserializingIterator[E]( val encoder: AgnosticEncoder[E], private[this] val reader: ArrowReader) - extends TypedDeserializingIterator[E] { + extends CloseableIterator[E] { private[this] var index = 0 private[this] val root = reader.getVectorSchemaRoot private[this] val deserializer = ArrowDeserializers.deserializerFor(encoder, root) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 73c401c26cd5b..7e7eab036aed3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -29,7 +29,7 @@ import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, CalendarIntervalEncoder, DateEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, RowEncoder, StringEncoder, TimestampEncoder, UDTEncoder} @@ -214,6 +214,17 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { assert(inspector.sizeInBytes > 0) } + test("deserializing empty iterator") { + withAllocator { allocator => + val iterator = ArrowDeserializers.deserializeFromArrow( + Iterator.empty, + singleIntEncoder, + allocator) + assert(iterator.isEmpty) + assert(allocator.getAllocatedMemory == 0) + } + } + test("single batch") { val inspector = new CountingBatchInspector roundTripAndCheckIdentical(singleIntEncoder, inspectBatch = inspector) { () => @@ -501,15 +512,22 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { val maybeNull = MaybeNull(11) Iterator.tabulate(100) { i => val bean = new JavaMapData - bean.setDummyToDoubleListMap(maybeNull { - val map = new util.HashMap[DummyBean, java.util.List[java.lang.Double]] - (0 until (i % 5)).foreach { j => - val dummy = new DummyBean - dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + bean.setMetricMap(maybeNull { 
+ val map = new util.HashMap[String, util.List[java.lang.Double]] + (0 until (i % 20)).foreach { i => val values = Array.tabulate(i % 40) { j => Double.box(j.toDouble) } - map.put(dummy, maybeNull(util.Arrays.asList(values: _*))) + map.put("k" + i, maybeNull(util.Arrays.asList(values: _*))) + } + map + }) + bean.setDummyToStringMap(maybeNull { + val map = new util.HashMap[DummyBean, String] + (0 until (i % 5)).foreach { j => + val dummy = new DummyBean + dummy.setBigInteger(maybeNull(java.math.BigInteger.valueOf(i * j))) + map.put(dummy, maybeNull("s" + i + "v" + j)) } map }) @@ -643,6 +661,63 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { .add("Ca", "array") .add("Cb", "binary"))) + test("bind to schema") { + // Binds to a wider schema. The narrow schema has fewer (nested) fields, has a slightly + // different field order, and uses different cased names in a couple of places. + withAllocator { allocator => + val input = Row( + 887, + "foo", + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte), 5f), + Seq(Row(null, "a", false), Row(javaBigDecimal(57853, 10), "b", false))) + val expected = Row( + "foo", + Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)), + Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte))) + val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator) + val result = ArrowDeserializers.deserializeFromArrow( + arrowBatches, + narrowSchemaEncoder, + allocator) + val actual = result.next() + assert(result.isEmpty) + assert(expected === actual) + result.close() + arrowBatches.close() + } + } + + test("unknown field") { + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator) + intercept[AnalysisException] { + ArrowDeserializers.deserializeFromArrow( + arrowBatches, + wideSchemaEncoder, + allocator) + } + arrowBatches.close() + } + } + + test("duplicate fields") { + val duplicateSchemaEncoder = toRowEncoder(new StructType() + .add("foO", "string") + .add("Foo", "string")) + val fooSchemaEncoder = toRowEncoder(new StructType() + .add("foo", "string")) + withAllocator { allocator => + val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator) + intercept[AnalysisException] { + ArrowDeserializers.deserializeFromArrow( + arrowBatches, + fooSchemaEncoder, + allocator) + } + arrowBatches.close() + } + } + /* ******************************************************************** * * Arrow serialization/deserialization specific errors * ******************************************************************** */ @@ -801,17 +876,23 @@ case class MapData(intStringMap: Map[Int, String], metricMap: Map[String, Array[ class JavaMapData { @scala.beans.BeanProperty - var dummyToDoubleListMap: java.util.Map[DummyBean, java.util.List[java.lang.Double]] = _ + var dummyToStringMap: java.util.Map[DummyBean, String] = _ + + @scala.beans.BeanProperty + var metricMap: java.util.HashMap[String, java.util.List[java.lang.Double]] = _ def canEqual(other: Any): Boolean = other.isInstanceOf[JavaMapData] override def equals(other: Any): Boolean = other match { case that: JavaMapData if that canEqual this => - dummyToDoubleListMap == that.dummyToDoubleListMap + dummyToStringMap == that.dummyToStringMap && + metricMap == that.metricMap case _ => false } - override def hashCode(): Int = Objects.hashCode(dummyToDoubleListMap) + override def hashCode(): Int = { + java.util.Arrays.deepHashCode(Array(dummyToStringMap, metricMap)) + } } class DummyBean { From 
6d140a4f6c8d91678d382e726938a439069463a2 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 19 Jul 2023 03:45:37 -0400 Subject: [PATCH 16/16] Style --- .../client/arrow/ArrowEncoderSuite.scala | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 7e7eab036aed3..16eec3eee3110 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -216,10 +216,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { test("deserializing empty iterator") { withAllocator { allocator => - val iterator = ArrowDeserializers.deserializeFromArrow( - Iterator.empty, - singleIntEncoder, - allocator) + val iterator = + ArrowDeserializers.deserializeFromArrow(Iterator.empty, singleIntEncoder, allocator) assert(iterator.isEmpty) assert(allocator.getAllocatedMemory == 0) } @@ -675,10 +673,8 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { Seq(Row(null, false), Row(javaBigDecimal(57853, 10), false)), Row(Seq(1, 7, 5), Array[Byte](8.toByte, 756.toByte))) val arrowBatches = serializeToArrow(Iterator.single(input), wideSchemaEncoder, allocator) - val result = ArrowDeserializers.deserializeFromArrow( - arrowBatches, - narrowSchemaEncoder, - allocator) + val result = + ArrowDeserializers.deserializeFromArrow(arrowBatches, narrowSchemaEncoder, allocator) val actual = result.next() assert(result.isEmpty) assert(expected === actual) @@ -691,28 +687,24 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { withAllocator { allocator => val arrowBatches = serializeToArrow(Iterator.empty, narrowSchemaEncoder, allocator) intercept[AnalysisException] { - ArrowDeserializers.deserializeFromArrow( - arrowBatches, - wideSchemaEncoder, - allocator) + ArrowDeserializers.deserializeFromArrow(arrowBatches, wideSchemaEncoder, allocator) } arrowBatches.close() } } test("duplicate fields") { - val duplicateSchemaEncoder = toRowEncoder(new StructType() - .add("foO", "string") - .add("Foo", "string")) - val fooSchemaEncoder = toRowEncoder(new StructType() - .add("foo", "string")) + val duplicateSchemaEncoder = toRowEncoder( + new StructType() + .add("foO", "string") + .add("Foo", "string")) + val fooSchemaEncoder = toRowEncoder( + new StructType() + .add("foo", "string")) withAllocator { allocator => val arrowBatches = serializeToArrow(Iterator.empty, duplicateSchemaEncoder, allocator) intercept[AnalysisException] { - ArrowDeserializers.deserializeFromArrow( - arrowBatches, - fooSchemaEncoder, - allocator) + ArrowDeserializers.deserializeFromArrow(arrowBatches, fooSchemaEncoder, allocator) } arrowBatches.close() } @@ -886,7 +878,7 @@ class JavaMapData { override def equals(other: Any): Boolean = other match { case that: JavaMapData if that canEqual this => dummyToStringMap == that.dummyToStringMap && - metricMap == that.metricMap + metricMap == that.metricMap case _ => false }
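
The "bind to schema", "unknown field" and "duplicate fields" tests restyled in this patch exercise the central binding rule: data serialized with a wider row encoder can be deserialized with a narrower one as long as every requested field exists, while missing or ambiguous fields fail with an AnalysisException. A condensed sketch of that round trip, written as it would sit inside this suite (it reuses the suite-local helpers toRowEncoder, serializeToArrow and withAllocator visible in the hunks above; the field names are invented and the test itself is not part of the patch), looks roughly like this:

    // Condensed, illustrative variant of the "bind to schema" test; not part of this patch.
    test("narrow read of a wide batch (sketch)") {
      val wideEncoder = toRowEncoder(new StructType().add("id", "int").add("name", "string"))
      val narrowEncoder = toRowEncoder(new StructType().add("name", "string"))
      withAllocator { allocator =>
        // Serialize a single row that carries both fields...
        val batches = serializeToArrow(Iterator.single(Row(1, "foo")), wideEncoder, allocator)
        // ...and read it back with an encoder that only knows about "name".
        val result = ArrowDeserializers.deserializeFromArrow(batches, narrowEncoder, allocator)
        assert(result.next() === Row("foo"))
        assert(result.isEmpty)
        result.close()
        batches.close()
      }
    }
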