From adc9ded0d8fe957b203c047e433381645fe944e9 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 29 Dec 2020 13:12:50 +0800 Subject: [PATCH] [SPARK-31937][SQL] Support processing array and map type using spark noserde mode --- .../sql/catalyst/CatalystTypeConverters.scala | 15 +- .../BaseScriptTransformationExec.scala | 152 ++++++---- .../spark/sql/execution/SparkInspectors.scala | 259 ++++++++++++++++++ .../BaseScriptTransformationSuite.scala | 98 +++++++ 4 files changed, 464 insertions(+), 60 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 907b5877b3ac0..2688e2f893ee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -174,8 +174,9 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) + case g: GenericArrayData => new GenericArrayData(g.array.map(elementConverter.toCatalyst)) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to an array of ${elementType.catalogString}") } } @@ -213,6 +214,9 @@ object CatalystTypeConverters { scalaValue match { case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + case map: ArrayBasedMapData => + ArrayBasedMapData(map.keyArray.array.zip(map.valueArray.array).toMap, + keyFunction, valueFunction) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + "cannot be converted to a map type with " @@ -263,6 +267,15 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) + case g: GenericInternalRow => + val ar = new Array[Any](structType.size) + val values = g.values + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(values(idx)) + idx += 1 + } + new GenericInternalRow(ar) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to ${structType.catalogString}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 74e5aa716ad67..6ec03d1c6c54c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} import java.nio.charset.StandardCharsets +import java.util.Map.Entry import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -33,10 +34,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, 
GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { @@ -47,7 +46,12 @@ trait BaseScriptTransformationExec extends UnaryExecNode { def ioschema: ScriptTransformationIOSchema protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { - input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + input.map { in: Expression => + in.dataType match { + case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in + case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone) + } + } } override def producedAttributes: AttributeSet = outputSet -- inputSet @@ -182,58 +186,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr => - val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) - attr.dataType match { - case StringType => wrapperConvertException(data => data, converter) - case BooleanType => wrapperConvertException(data => data.toBoolean, converter) - case ByteType => wrapperConvertException(data => data.toByte, converter) - case BinaryType => - wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) - case IntegerType => wrapperConvertException(data => data.toInt, converter) - case ShortType => wrapperConvertException(data => data.toShort, converter) - case LongType => wrapperConvertException(data => data.toLong, converter) - case FloatType => wrapperConvertException(data => data.toFloat, converter) - case DoubleType => wrapperConvertException(data => data.toDouble, converter) - case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) - case DateType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.daysToLocalDate).orNull, converter) - case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaDate).orNull, converter) - case TimestampType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.microsToInstant).orNull, converter) - case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaTimestamp).orNull, converter) - case CalendarIntervalType => wrapperConvertException( - data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), - converter) - case udt: UserDefinedType[_] => - wrapperConvertException(data => udt.deserialize(data), converter) - case dt => - throw new SparkException(s"${nodeName} without serde does not support " + - s"${dt.getClass.getSimpleName} as output data type") - } + SparkInspectors.unwrapper(attr.dataType, conf, ioschema, 
1) } - - // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null - private val wrapperConvertException: (String => Any, Any => Any) => String => Any = - (f: String => Any, converter: Any => Any) => - (data: String) => converter { - try { - f(data) - } catch { - case NonFatal(_) => null - } - } } abstract class BaseScriptTransformationWriterThread extends Thread with Logging { @@ -256,18 +210,23 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging protected def processRows(): Unit + val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt)) + protected def processRowsWithoutSerde(): Unit = { val len = inputSchema.length iter.foreach { row => + val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map { + case (value, wrapper) => wrapper(value) + } val data = if (len == 0) { ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") } else { val sb = new StringBuilder - sb.append(row.get(0, inputSchema(0))) + buildString(sb, values(0), inputSchema(0), 1) var i = 1 while (i < len) { sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - sb.append(row.get(i, inputSchema(i))) + buildString(sb, values(i), inputSchema(i), 1) i += 1 } sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) @@ -277,6 +236,50 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging } } + /** + * Convert data to string according to the data type. + * + * @param sb The StringBuilder to store the serialized data. + * @param obj The object for the current field. + * @param dataType The DataType for the current Object. + * @param level The current level of separator. + */ + private def buildString(sb: StringBuilder, obj: Any, dataType: DataType, level: Int): Unit = { + (obj, dataType) match { + case (list: java.util.List[_], ArrayType(typ, _)) => + val separator = ioSchema.getSeparator(level) + (0 until list.size).foreach { i => + if (i > 0) { + sb.append(separator) + } + buildString(sb, list.get(i), typ, level + 1) + } + case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) => + val separator = ioSchema.getSeparator(level) + val keyValueSeparator = ioSchema.getSeparator(level + 1) + val entries = map.entrySet().toArray() + (0 until entries.size).foreach { i => + if (i > 0) { + sb.append(separator) + } + val entry = entries(i).asInstanceOf[Entry[_, _]] + buildString(sb, entry.getKey, keyType, level + 2) + sb.append(keyValueSeparator) + buildString(sb, entry.getValue, valueType, level + 2) + } + case (arrayList: java.util.ArrayList[_], StructType(fields)) => + val separator = ioSchema.getSeparator(level) + (0 until arrayList.size).foreach { i => + if (i > 0) { + sb.append(separator) + } + buildString(sb, arrayList.get(i), fields(i).dataType, level + 1) + } + case (other, _) => + sb.append(other) + } + } + override def run(): Unit = Utils.logUncaughtExceptions { TaskContext.setTaskContext(taskContext) @@ -329,14 +332,45 @@ case class ScriptTransformationIOSchema( schemaLess: Boolean) extends Serializable { import ScriptTransformationIOSchema._ - val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) + val inputRowFormatMap = inputRowFormat.toMap.withDefault(k => defaultFormat(k)) + val outputRowFormatMap = outputRowFormat.toMap.withDefault(k => defaultFormat(k)) + + val separators = (getByte(inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 0.toByte) :: + 
getByte(inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"), 1.toByte) :: + getByte(inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"), 2.toByte) :: Nil) ++ + (4 to 8).map(_.toByte) + + def getByte(altValue: String, defaultVal: Byte): Byte = { + if (altValue != null && altValue.length > 0) { + try { + java.lang.Byte.parseByte(altValue) + } catch { + case _: NumberFormatException => + altValue.charAt(0).toByte + } + } else { + defaultVal + } + } + + def getSeparator(level: Int): Char = { + try { + separators(level).toChar + } catch { + case _: IndexOutOfBoundsException => + val msg = "Number of levels of nesting supported for Spark SQL script transform" + + " is " + (separators.length - 1) + " Unable to work with level " + level + throw new RuntimeException(msg) + } + } } object ScriptTransformationIOSchema { val defaultFormat = Map( ("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n") + ("TOK_TABLEROWFORMATLINES", "\n"), + ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"), + ("TOK_TABLEROWFORMATMAPKEYS", "\u0003") ) val defaultIOSchema = ScriptTransformationIOSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala new file mode 100644 index 0000000000000..597bde956da9b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import scala.util.control.NonFatal + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, IntervalUtils, MapData} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object SparkInspectors { + def wrapper(dataType: DataType): Any => Any = dataType match { + case ArrayType(tpe, _) => + val wp = wrapper(tpe) + withNullSafe { o => + val array = o.asInstanceOf[ArrayData] + val values = new java.util.ArrayList[Any](array.numElements()) + array.foreach(tpe, (_, e) => values.add(wp(e))) + values + } + case MapType(keyType, valueType, _) => + val mt = dataType.asInstanceOf[MapType] + val keyWrapper = wrapper(keyType) + val valueWrapper = wrapper(valueType) + withNullSafe { o => + val map = o.asInstanceOf[MapData] + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + map.foreach(mt.keyType, mt.valueType, (k, v) => + jmap.put(keyWrapper(k), valueWrapper(v))) + jmap + } + case StringType => getStringWritable + case IntegerType => getIntWritable + case DoubleType => getDoubleWritable + case BooleanType => getBooleanWritable + case LongType => getLongWritable + case FloatType => getFloatWritable + case ShortType => getShortWritable + case ByteType => getByteWritable + case NullType => (_: Any) => null + case BinaryType => getBinaryWritable + case DateType => getDateWritable + case TimestampType => getTimestampWritable + // TODO decimal precision? + case DecimalType() => getDecimalWritable + case StructType(fields) => + val structType = dataType.asInstanceOf[StructType] + val wrappers = fields.map(f => wrapper(f.dataType)) + withNullSafe { o => + val row = o.asInstanceOf[InternalRow] + val result = new java.util.ArrayList[AnyRef](wrappers.size) + wrappers.zipWithIndex.foreach { + case (wrapper, i) => + val tpe = structType(i).dataType + result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) + } + result + } + case _: UserDefinedType[_] => + val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType + wrapper(sqlType) + } + + private def withNullSafe(f: Any => Any): Any => Any = { + input => + if (input == null) { + null + } else { + f(input) + } + } + + private def getStringWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[UTF8String].toString + } + + private def getIntWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Int] + } + + private def getDoubleWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Double] + } + + private def getBooleanWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Boolean] + } + + private def getLongWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Long] + } + + private def getFloatWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Float] + } + + private def getShortWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Short] + } + + private def getByteWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Byte] + } + + private def getBinaryWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Array[Byte]] + } + + private def getDateWritable(value: Any): Any = + if 
(value == null) { + null + } else { + DateTimeUtils.toJavaDate(value.asInstanceOf[Int]) + } + + private def getTimestampWritable(value: Any): Any = + if (value == null) { + null + } else { + DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]) + } + + private def getDecimalWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Decimal] + } + + + def unwrapper( + dataType: DataType, + conf: SQLConf, + ioSchema: ScriptTransformationIOSchema, + level: Int): String => Any = { + val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + dataType match { + case StringType => wrapperConvertException(data => data, converter) + case BooleanType => wrapperConvertException(data => data.toBoolean, converter) + case ByteType => wrapperConvertException(data => data.toByte, converter) + case BinaryType => + wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) + case IntegerType => wrapperConvertException(data => data.toInt, converter) + case ShortType => wrapperConvertException(data => data.toShort, converter) + case LongType => wrapperConvertException(data => data.toLong, converter) + case FloatType => wrapperConvertException(data => data.toFloat, converter) + case DoubleType => wrapperConvertException(data => data.toDouble, converter) + case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull, converter) + case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull, converter) + case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull, converter) + case CalendarIntervalType => wrapperConvertException( + data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), + converter) + case udt: UserDefinedType[_] => + wrapperConvertException(data => udt.deserialize(data), converter) + case ArrayType(tpe, _) => + val separator = ioSchema.getSeparator(level) + val un = unwrapper(tpe, conf, ioSchema, level + 1) + wrapperConvertException(data => { + data.split(separator) + .map(un).toSeq + }, converter) + case MapType(keyType, valueType, _) => + val separator = ioSchema.getSeparator(level) + val keyValueSeparator = ioSchema.getSeparator(level + 1) + val keyUnwrapper = unwrapper(keyType, conf, ioSchema, level + 2) + val valueUnwrapper = unwrapper(valueType, conf, ioSchema, level + 2) + wrapperConvertException(data => { + val list = data.split(separator) + list.map { kv => + val kvList = kv.split(keyValueSeparator) + keyUnwrapper(kvList(0)) -> valueUnwrapper(kvList(1)) + }.toMap + }, converter) + case StructType(fields) => + val separator = ioSchema.getSeparator(level) + val unwrappers = fields.map(f => unwrapper(f.dataType, conf, ioSchema, level + 1)) + wrapperConvertException(data => { + val list = 
data.split(separator) + Row.fromSeq(list.zip(unwrappers).map { + case (data: String, unwrapper: (String => Any)) => unwrapper(data) + }) + }, converter) + case _ => wrapperConvertException(data => data, converter) + } + } + + // Keep consistent with Hive `LazySimpleSerDe`, when there is a type case error, return null + private val wrapperConvertException: (String => Any, Any => Any) => String => Any = + (f: String => Any, converter: Any => Any) => + (data: String) => converter { + try { + f(data) + } catch { + case NonFatal(_) => null + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 863657a7862a6..8071b0181af93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -440,6 +440,104 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } } + + test("SPARK-31936: Script transform support Array/MapType/StructType (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + (Array(0, 1, 2), Array(Array(0, 1), Array(2)), + Map("a" -> 1), Map("b" -> Array("a", "b"))), + (Array(3, 4, 5), Array(Array(3, 4), Array(5)), + Map("b" -> 2), Map("c" -> Array("c", "d"))), + (Array(6, 7, 8), Array(Array(6, 7), Array(8)), + Map("c" -> 3), Map("d" -> Array("e", "f"))) + ).toDF("a", "b", "c", "d") + .select('a, 'b, 'c, 'd, + struct('a, 'b).as("e"), + struct('a, 'd).as("f"), + struct(struct('a, 'b), struct('a, 'd)).as("g") + ) + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + df.col("f").expr, + df.col("g").expr), + script = "cat", + output = Seq( + AttributeReference("a", ArrayType(IntegerType))(), + AttributeReference("b", ArrayType(ArrayType(IntegerType)))(), + AttributeReference("c", MapType(StringType, IntegerType))(), + AttributeReference("d", MapType(StringType, ArrayType(StringType)))(), + AttributeReference("e", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", ArrayType(ArrayType(IntegerType))))))(), + AttributeReference("f", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", MapType(StringType, ArrayType(StringType))))))(), + AttributeReference("g", StructType( + Array(StructField("col1", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", ArrayType(ArrayType(IntegerType)))))), + StructField("col2", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", MapType(StringType, ArrayType(StringType)))))))))()), + child = child, + ioschema = defaultIOSchema + ), + df.select('a, 'b, 'c, 'd, 'e, 'f, 'g).collect()) + } + } + + test("SPARK-31936: Script transform support 7 level nested complex type (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + Array(1) + ).toDF("a").select('a, + array(array(array(array(array(array('a)))))).as("level_7"), + array(array(array(array(array(array(array('a))))))).as("level_8") + ) + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("level_7").expr), + script = "cat", + output = Seq( + AttributeReference("a", 
ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( + ArrayType(ArrayType(IntegerType))))))))()), + child = child, + ioschema = defaultIOSchema + ), + df.select('level_7).collect()) + + val e = intercept[RuntimeException] { + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("level_8").expr), + script = "cat", + output = Seq( + AttributeReference("a", ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( + ArrayType(ArrayType(IntegerType)))))))))()), + child = child, + ioschema = defaultIOSchema + ), + df.select('level_8).collect()) + }.getMessage + assert(e.contains("Number of levels of nesting supported for Spark SQL" + + " script transform is 7 Unable to work with level 8")) + } + } } case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode {
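
A minimal usage sketch of what this patch is meant to enable (not part of the diff above). Everything in it is illustrative: the spark-shell style SparkSession named `spark`, the view name `v`, and the column names `arr` and `m` are assumptions; only the separator behaviour described in the comments comes from ScriptTransformationIOSchema and BaseScriptTransformationWriterThread in the patch.

import spark.implicits._  // assumes a spark-shell session exposing `spark`

// Hypothetical input: one row with an array column and a map column.
val df = Seq((Array(1, 2, 3), Map("a" -> 1, "b" -> 2))).toDF("arr", "m")
df.createOrReplaceTempView("v")

// With the default no-serde row format the writer thread serializes the row roughly as
//   1<U+0002>2<U+0002>3 <TAB> a<U+0003>1<U+0002>b<U+0003>2 <LF>
// i.e. getSeparator(1) = '\u0002' between array elements and between map entries,
// getSeparator(2) = '\u0003' between a map key and its value, and '\t' between columns
// (map entry order may vary, since the wrapper builds a java.util.HashMap).
// SparkInspectors.unwrapper then parses the script's output back into Catalyst values.
spark.sql(
  """SELECT TRANSFORM(arr, m)
    |USING 'cat'
    |AS (arr array<int>, m map<string, int>)
    |FROM v
  """.stripMargin).show(truncate = false)

Because the separator list holds only eight entries (the field delimiter plus bytes 2 through 8), getSeparator supports at most seven levels of nesting, which is the limit exercised by the level_8 test above.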