From bc7bfbcbfe586e3093977fa60106dbf38741a2cd Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Wed, 21 Aug 2024 09:25:19 +0900
Subject: [PATCH] [SPARK-49274][CONNECT] Support java serialization based encoders

### What changes were proposed in this pull request?
This PR adds Encoders.javaSerialization to connect. It does this by creating an encoder that does not really encode data, but that transforms it into something that can be encoded. This is also useful for situations in which we want to define custom serialization logic. This PR is also a stepping stone for adding Kryo based serialization.

### Why are the changes needed?
This change increases parity between the connect and classic scala interfaces.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added tests for both Arrow and Expression encoders.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #47806 from hvanhovell/SPARK-49274.

Authored-by: Herman van Hovell
Signed-off-by: Hyukjin Kwon
---
 .../scala/org/apache/spark/sql/Encoders.scala | 29 ++++++++++++-
 .../client/arrow/ArrowEncoderSuite.scala      | 33 ++++++++++++---
 .../catalyst/encoders/AgnosticEncoder.scala   | 17 +++++++-
 .../spark/sql/catalyst/encoders/codecs.scala  | 42 +++++++++++++++++++
 .../catalyst/DeserializerBuildHelper.scala    | 13 ++++--
 .../sql/catalyst/SerializerBuildHelper.scala  | 14 +++++--
 .../encoders/ExpressionEncoderSuite.scala     | 14 +++++++
 .../client/arrow/ArrowDeserializer.scala      |  7 ++++
 .../client/arrow/ArrowSerializer.scala        | 10 ++++-
 9 files changed, 165 insertions(+), 14 deletions(-)
 create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
index 74f0133803137..ffd9975770066 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -16,10 +16,11 @@
  */
 package org.apache.spark.sql
 
+import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder => RowEncoderFactory}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, JavaSerializationCodec, RowEncoder => RowEncoderFactory}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.types.StructType
 
@@ -176,6 +177,32 @@
    */
   def row(schema: StructType): Encoder[Row] = RowEncoderFactory.encoderFor(schema)
 
+  /**
+   * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
+   * serialization. This encoder maps T into a single byte array (binary) field.
+   *
+   * T must be publicly accessible.
+   *
+   * @note
+   *   This is extremely inefficient and should only be used as the last resort.
+   * @since 4.0.0
+   */
+  def javaSerialization[T: ClassTag]: Encoder[T] = {
+    TransformingEncoder(implicitly[ClassTag[T]], BinaryEncoder, JavaSerializationCodec)
+  }
+
+  /**
+   * Creates an encoder that serializes objects of type T using generic Java serialization. This
+   * encoder maps T into a single byte array (binary) field.
+   *
+   * T must be publicly accessible.
+   *
+   * @note
+   *   This is extremely inefficient and should only be used as the last resort.
+   * @since 4.0.0
+   */
+  def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
+
   private def tupleEncoder[T](encoders: Encoder[_]*): Encoder[T] = {
     ProductEncoder.tuple(encoders.asInstanceOf[Seq[AgnosticEncoder[_]]]).asInstanceOf[Encoder[T]]
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 709e2cf0e84ea..70b471cf74b33 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -30,11 +30,11 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
 import org.apache.arrow.vector.VarBinaryVector
 import org.scalatest.BeforeAndAfterAll
 
-import org.apache.spark.SparkUnsupportedOperationException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.{sql, SparkUnsupportedOperationException}
+import org.apache.spark.sql.{AnalysisException, Encoders, Row}
 import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder}
 import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND
@@ -44,7 +44,7 @@ import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
 import org.apache.spark.sql.connect.client.CloseableIterator
 import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
 import org.apache.spark.sql.test.ConnectFunSuite
-import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StringType, StructType, UserDefinedType, YearMonthIntervalType}
 
 /**
  * Tests for encoding external data to and from arrow.
@@ -769,6 +769,24 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("java serialization") {
+    val encoder = sql.encoderFor(Encoders.javaSerialization[(Int, String)])
+    roundTripAndCheckIdentical(encoder) { () =>
+      Iterator.tabulate(10)(i => (i, "itr_" + i))
+    }
+  }
+
+  test("transforming encoder") {
+    val schema = new StructType()
+      .add("key", IntegerType)
+      .add("value", StringType)
+    val encoder =
+      TransformingEncoder(classTag[(Int, String)], toRowEncoder(schema), () => new TestCodec)
+    roundTripAndCheckIdentical(encoder) { () =>
+      Iterator.tabulate(10)(i => (i, "v" + i))
+    }
+  }
+
   /* ******************************************************************** *
    * Arrow deserialization upcasting
    * ******************************************************************** */
@@ -1136,3 +1154,8 @@ class UDTNotSupported extends UserDefinedType[UDTNotSupportedClass] {
     case i: Int => UDTNotSupportedClass(i)
   }
 }
+
+class TestCodec extends Codec[(Int, String), Row] {
+  override def encode(in: (Int, String)): Row = Row(in._1, in._2)
+  override def decode(out: Row): (Int, String) = (out.getInt(0), out.getString(1))
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 9133abce88adc..639b23f714149 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -247,5 +247,20 @@
     ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
   val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
     JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
 
-}
+  /**
+   * Encoder that transforms external data into a representation that can be further processed by
+   * another encoder. This is fallback for scenarios where objects can't be represented using
+   * standard encoders, an example of this is where we use a different (opaque) serialization
+   * format (i.e. java serialization, kryo serialization, or protobuf).
+   */
+  case class TransformingEncoder[I, O](
+      clsTag: ClassTag[I],
+      transformed: AgnosticEncoder[O],
+      codecProvider: () => Codec[_ >: I, O]) extends AgnosticEncoder[I] {
+    override def isPrimitive: Boolean = transformed.isPrimitive
+    override def dataType: DataType = transformed.dataType
+    override def schema: StructType = transformed.schema
+    override def isStruct: Boolean = transformed.isStruct
+  }
+}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
new file mode 100644
index 0000000000000..46862ebbccdfd
--- /dev/null
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.util.SparkSerDeUtils
+
+/**
+ * Codec for doing conversions between two representations.
+ *
+ * @tparam I input type (typically the external representation of the data.
+ * @tparam O output type (typically the internal representation of the data.
+ */
+trait Codec[I, O] {
+  def encode(in: I): O
+  def decode(out: O): I
+}
+
+/**
+ * A codec that uses Java Serialization as its output format.
+ */
+class JavaSerializationCodec[I] extends Codec[I, Array[Byte]] {
+  override def encode(in: I): Array[Byte] = SparkSerDeUtils.serialize(in)
+  override def decode(out: Array[Byte]): I = SparkSerDeUtils.deserialize(out)
+}
+
+object JavaSerializationCodec extends (() => Codec[Any, Array[Byte]]) {
+  override def apply(): Codec[Any, Array[Byte]] = new JavaSerializationCodec[Any]
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 0b88d5a4130e3..40b49506b58aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder}
-import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, MapKeys, MapValues, UpCast}
+import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.types._
@@ -410,6 +410,13 @@
       val newInstance = NewInstance(cls, Nil, ObjectType(cls), propagateNull = false)
       val result = InitializeJavaBean(newInstance, setters.toMap)
       exprs.If(IsNull(path), exprs.Literal.create(null, ObjectType(cls)), result)
+
+    case TransformingEncoder(tag, encoder, provider) =>
+      Invoke(
+        Literal.create(provider(), ObjectType(classOf[Codec[_, _]])),
+        "decode",
+        ObjectType(tag.runtimeClass),
+        createDeserializer(encoder, path, walkedTypePath) :: Nil)
   }
 
   private def deserializeArray(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index cd087514f4be3..38bf0651d6f1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -21,10 +21,10 @@ import scala.language.existentials
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, Codec}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, GenericArrayData, IntervalUtils}
 import org.apache.spark.sql.internal.SQLConf
@@ -397,6 +397,14 @@
         f.name -> createSerializer(f.enc, fieldValue)
       }
       createSerializerForObject(input, serializedFields)
+
+    case TransformingEncoder(_, encoder, codecProvider) =>
+      val encoded = Invoke(
+        Literal(codecProvider(), ObjectType(classOf[Codec[_, _]])),
+        "encode",
+        externalDataTypeFor(encoder),
+        input :: Nil)
+      createSerializer(encoder, encoded)
   }
 
   private def serializerForArray(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index f46c02326e8b1..7b8d8be6bbeeb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 import java.util.Arrays
 
 import scala.collection.mutable.ArrayBuffer
+import scala.reflect.classTag
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.{SPARK_DOC_ROOT, SparkArithmeticException, SparkRuntimeException, SparkUnsupportedOperationException}
@@ -29,6 +30,7 @@ import org.apache.spark.sql.{Encoder, Encoders, Row}
 import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, TransformingEncoder}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NaNvl}
 import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
@@ -550,6 +552,18 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
   encodeDecodeTest(FooClassWithEnum(1, FooEnum.E1), "case class with Int and scala Enum")
   encodeDecodeTest(FooEnum.E1, "scala Enum")
 
+  test("transforming encoder") {
+    val encoder = ExpressionEncoder(TransformingEncoder(
+      classTag[(Long, Long)],
+      BinaryEncoder,
+      JavaSerializationCodec))
+      .resolveAndBind()
+    assert(encoder.schema == new StructType().add("value", BinaryType))
+    val toRow = encoder.createSerializer()
+    val fromRow = encoder.createDeserializer()
+    assert(fromRow(toRow((11, 14))) == (11, 14))
+  }
+
   // Scala / Java big decimals ----------------------------------------------------------
   encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 17d8444574f61..f3abaddb0110b 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -359,6 +359,13 @@
         }
       }
 
+      case (TransformingEncoder(_, encoder, provider), v) =>
+        new Deserializer[Any] {
+          private[this] val codec = provider()
+          private[this] val deserializer = deserializerFor(encoder, v, timeZoneId)
+          override def get(i: Int): Any = codec.decode(deserializer.get(i))
+        }
+
       case (CalendarIntervalEncoder | VariantEncoder | _: UDTEncoder[_], _) =>
         throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)
 
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 4b7b394235458..f8a5c63ac3abe 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -35,7 +35,7 @@ import org.apache.arrow.vector.util.Text
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.DefinedByConstructorParams
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
 import org.apache.spark.sql.connect.client.CloseableIterator
@@ -442,6 +442,14 @@
           o => getter.invoke(o)
         }
 
+      case (TransformingEncoder(_, encoder, provider), v) =>
+        new Serializer {
+          private[this] val codec = provider().asInstanceOf[Codec[Any, Any]]
+          private[this] val delegate: Serializer = serializerFor(encoder, v)
+          override def write(index: Int, value: Any): Unit =
+            delegate.write(index, codec.encode(value))
+        }
+
       case (CalendarIntervalEncoder | VariantEncoder | _: UDTEncoder[_], _) =>
         throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)