From adc9ded0d8fe957b203c047e433381645fe944e9 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 29 Dec 2020 13:12:50 +0800 Subject: [PATCH 01/22] [SPARK-31937][SQL] Support processing array and map type using spark noserde mode --- .../sql/catalyst/CatalystTypeConverters.scala | 15 +- .../BaseScriptTransformationExec.scala | 152 ++++++---- .../spark/sql/execution/SparkInspectors.scala | 259 ++++++++++++++++++ .../BaseScriptTransformationSuite.scala | 98 +++++++ 4 files changed, 464 insertions(+), 60 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 907b5877b3ac0..2688e2f893ee7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -174,8 +174,9 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) + case g: GenericArrayData => new GenericArrayData(g.array.map(elementConverter.toCatalyst)) case other => throw new IllegalArgumentException( - s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"AAAThe value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to an array of ${elementType.catalogString}") } } @@ -213,6 +214,9 @@ object CatalystTypeConverters { scalaValue match { case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) + case map: ArrayBasedMapData => + ArrayBasedMapData(map.keyArray.array.zip(map.valueArray.array).toMap, + keyFunction, valueFunction) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + "cannot be converted to a map type with " @@ -263,6 +267,15 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) + case g: GenericInternalRow => + val ar = new Array[Any](structType.size) + val values = g.values + var idx = 0 + while (idx < structType.size) { + ar(idx) = converters(idx).toCatalyst(values(idx)) + idx += 1 + } + new GenericInternalRow(ar) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to ${structType.catalogString}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 74e5aa716ad67..6ec03d1c6c54c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} import java.nio.charset.StandardCharsets +import java.util.Map.Entry import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -33,10 +34,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, 
GenericInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { @@ -47,7 +46,12 @@ trait BaseScriptTransformationExec extends UnaryExecNode { def ioschema: ScriptTransformationIOSchema protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { - input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + input.map { in: Expression => + in.dataType match { + case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in + case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone) + } + } } override def producedAttributes: AttributeSet = outputSet -- inputSet @@ -182,58 +186,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr => - val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) - attr.dataType match { - case StringType => wrapperConvertException(data => data, converter) - case BooleanType => wrapperConvertException(data => data.toBoolean, converter) - case ByteType => wrapperConvertException(data => data.toByte, converter) - case BinaryType => - wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) - case IntegerType => wrapperConvertException(data => data.toInt, converter) - case ShortType => wrapperConvertException(data => data.toShort, converter) - case LongType => wrapperConvertException(data => data.toLong, converter) - case FloatType => wrapperConvertException(data => data.toFloat, converter) - case DoubleType => wrapperConvertException(data => data.toDouble, converter) - case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) - case DateType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.daysToLocalDate).orNull, converter) - case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaDate).orNull, converter) - case TimestampType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.microsToInstant).orNull, converter) - case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaTimestamp).orNull, converter) - case CalendarIntervalType => wrapperConvertException( - data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), - converter) - case udt: UserDefinedType[_] => - wrapperConvertException(data => udt.deserialize(data), converter) - case dt => - throw new SparkException(s"${nodeName} without serde does not support " + - s"${dt.getClass.getSimpleName} as output data type") - } + SparkInspectors.unwrapper(attr.dataType, conf, ioschema, 
1) } - - // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null - private val wrapperConvertException: (String => Any, Any => Any) => String => Any = - (f: String => Any, converter: Any => Any) => - (data: String) => converter { - try { - f(data) - } catch { - case NonFatal(_) => null - } - } } abstract class BaseScriptTransformationWriterThread extends Thread with Logging { @@ -256,18 +210,23 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging protected def processRows(): Unit + val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt)) + protected def processRowsWithoutSerde(): Unit = { val len = inputSchema.length iter.foreach { row => + val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map { + case (value, wrapper) => wrapper(value) + } val data = if (len == 0) { ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") } else { val sb = new StringBuilder - sb.append(row.get(0, inputSchema(0))) + buildString(sb, values(0), inputSchema(0), 1) var i = 1 while (i < len) { sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - sb.append(row.get(i, inputSchema(i))) + buildString(sb, values(i), inputSchema(i), 1) i += 1 } sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) @@ -277,6 +236,50 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging } } + /** + * Convert data to string according to the data type. + * + * @param sb The StringBuilder to store the serialized data. + * @param obj The object for the current field. + * @param dataType The DataType for the current Object. + * @param level The current level of separator. + */ + private def buildString(sb: StringBuilder, obj: Any, dataType: DataType, level: Int): Unit = { + (obj, dataType) match { + case (list: java.util.List[_], ArrayType(typ, _)) => + val separator = ioSchema.getSeparator(level) + (0 until list.size).foreach { i => + if (i > 0) { + sb.append(separator) + } + buildString(sb, list.get(i), typ, level + 1) + } + case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) => + val separator = ioSchema.getSeparator(level) + val keyValueSeparator = ioSchema.getSeparator(level + 1) + val entries = map.entrySet().toArray() + (0 until entries.size).foreach { i => + if (i > 0) { + sb.append(separator) + } + val entry = entries(i).asInstanceOf[Entry[_, _]] + buildString(sb, entry.getKey, keyType, level + 2) + sb.append(keyValueSeparator) + buildString(sb, entry.getValue, valueType, level + 2) + } + case (arrayList: java.util.ArrayList[_], StructType(fields)) => + val separator = ioSchema.getSeparator(level) + (0 until arrayList.size).foreach { i => + if (i > 0) { + sb.append(separator) + } + buildString(sb, arrayList.get(i), fields(i).dataType, level + 1) + } + case (other, _) => + sb.append(other) + } + } + override def run(): Unit = Utils.logUncaughtExceptions { TaskContext.setTaskContext(taskContext) @@ -329,14 +332,45 @@ case class ScriptTransformationIOSchema( schemaLess: Boolean) extends Serializable { import ScriptTransformationIOSchema._ - val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) - val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) + val inputRowFormatMap = inputRowFormat.toMap.withDefault(k => defaultFormat(k)) + val outputRowFormatMap = outputRowFormat.toMap.withDefault(k => defaultFormat(k)) + + val separators = (getByte(inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 0.toByte) :: + 
getByte(inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"), 1.toByte) :: + getByte(inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"), 2.toByte) :: Nil) ++ + (4 to 8).map(_.toByte) + + def getByte(altValue: String, defaultVal: Byte): Byte = { + if (altValue != null && altValue.length > 0) { + try { + java.lang.Byte.parseByte(altValue) + } catch { + case _: NumberFormatException => + altValue.charAt(0).toByte + } + } else { + defaultVal + } + } + + def getSeparator(level: Int): Char = { + try { + separators(level).toChar + } catch { + case _: IndexOutOfBoundsException => + val msg = "Number of levels of nesting supported for Spark SQL script transform" + + " is " + (separators.length - 1) + " Unable to work with level " + level + throw new RuntimeException(msg) + } + } } object ScriptTransformationIOSchema { val defaultFormat = Map( ("TOK_TABLEROWFORMATFIELD", "\t"), - ("TOK_TABLEROWFORMATLINES", "\n") + ("TOK_TABLEROWFORMATLINES", "\n"), + ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"), + ("TOK_TABLEROWFORMATMAPKEYS", "\u0003") ) val defaultIOSchema = ScriptTransformationIOSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala new file mode 100644 index 0000000000000..597bde956da9b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import scala.util.control.NonFatal + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, IntervalUtils, MapData} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object SparkInspectors { + def wrapper(dataType: DataType): Any => Any = dataType match { + case ArrayType(tpe, _) => + val wp = wrapper(tpe) + withNullSafe { o => + val array = o.asInstanceOf[ArrayData] + val values = new java.util.ArrayList[Any](array.numElements()) + array.foreach(tpe, (_, e) => values.add(wp(e))) + values + } + case MapType(keyType, valueType, _) => + val mt = dataType.asInstanceOf[MapType] + val keyWrapper = wrapper(keyType) + val valueWrapper = wrapper(valueType) + withNullSafe { o => + val map = o.asInstanceOf[MapData] + val jmap = new java.util.HashMap[Any, Any](map.numElements()) + map.foreach(mt.keyType, mt.valueType, (k, v) => + jmap.put(keyWrapper(k), valueWrapper(v))) + jmap + } + case StringType => getStringWritable + case IntegerType => getIntWritable + case DoubleType => getDoubleWritable + case BooleanType => getBooleanWritable + case LongType => getLongWritable + case FloatType => getFloatWritable + case ShortType => getShortWritable + case ByteType => getByteWritable + case NullType => (_: Any) => null + case BinaryType => getBinaryWritable + case DateType => getDateWritable + case TimestampType => getTimestampWritable + // TODO decimal precision? + case DecimalType() => getDecimalWritable + case StructType(fields) => + val structType = dataType.asInstanceOf[StructType] + val wrappers = fields.map(f => wrapper(f.dataType)) + withNullSafe { o => + val row = o.asInstanceOf[InternalRow] + val result = new java.util.ArrayList[AnyRef](wrappers.size) + wrappers.zipWithIndex.foreach { + case (wrapper, i) => + val tpe = structType(i).dataType + result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) + } + result + } + case _: UserDefinedType[_] => + val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType + wrapper(sqlType) + } + + private def withNullSafe(f: Any => Any): Any => Any = { + input => + if (input == null) { + null + } else { + f(input) + } + } + + private def getStringWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[UTF8String].toString + } + + private def getIntWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Int] + } + + private def getDoubleWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Double] + } + + private def getBooleanWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Boolean] + } + + private def getLongWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Long] + } + + private def getFloatWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Float] + } + + private def getShortWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Short] + } + + private def getByteWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Byte] + } + + private def getBinaryWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Array[Byte]] + } + + private def getDateWritable(value: Any): Any = + if 
(value == null) { + null + } else { + DateTimeUtils.toJavaDate(value.asInstanceOf[Int]) + } + + private def getTimestampWritable(value: Any): Any = + if (value == null) { + null + } else { + DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]) + } + + private def getDecimalWritable(value: Any): Any = + if (value == null) { + null + } else { + value.asInstanceOf[Decimal] + } + + + def unwrapper( + dataType: DataType, + conf: SQLConf, + ioSchema: ScriptTransformationIOSchema, + level: Int): String => Any = { + val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + dataType match { + case StringType => wrapperConvertException(data => data, converter) + case BooleanType => wrapperConvertException(data => data.toBoolean, converter) + case ByteType => wrapperConvertException(data => data.toByte, converter) + case BinaryType => + wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) + case IntegerType => wrapperConvertException(data => data.toInt, converter) + case ShortType => wrapperConvertException(data => data.toShort, converter) + case LongType => wrapperConvertException(data => data.toLong, converter) + case FloatType => wrapperConvertException(data => data.toFloat, converter) + case DoubleType => wrapperConvertException(data => data.toDouble, converter) + case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull, converter) + case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull, converter) + case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull, converter) + case CalendarIntervalType => wrapperConvertException( + data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), + converter) + case udt: UserDefinedType[_] => + wrapperConvertException(data => udt.deserialize(data), converter) + case ArrayType(tpe, _) => + val separator = ioSchema.getSeparator(level) + val un = unwrapper(tpe, conf, ioSchema, level + 1) + wrapperConvertException(data => { + data.split(separator) + .map(un).toSeq + }, converter) + case MapType(keyType, valueType, _) => + val separator = ioSchema.getSeparator(level) + val keyValueSeparator = ioSchema.getSeparator(level + 1) + val keyUnwrapper = unwrapper(keyType, conf, ioSchema, level + 2) + val valueUnwrapper = unwrapper(valueType, conf, ioSchema, level + 2) + wrapperConvertException(data => { + val list = data.split(separator) + list.map { kv => + val kvList = kv.split(keyValueSeparator) + keyUnwrapper(kvList(0)) -> valueUnwrapper(kvList(1)) + }.toMap + }, converter) + case StructType(fields) => + val separator = ioSchema.getSeparator(level) + val unwrappers = fields.map(f => unwrapper(f.dataType, conf, ioSchema, level + 1)) + wrapperConvertException(data => { + val list = 
data.split(separator) + Row.fromSeq(list.zip(unwrappers).map { + case (data: String, unwrapper: (String => Any)) => unwrapper(data) + }) + }, converter) + case _ => wrapperConvertException(data => data, converter) + } + } + + // Keep consistent with Hive `LazySimpleSerDe`, when there is a type case error, return null + private val wrapperConvertException: (String => Any, Any => Any) => String => Any = + (f: String => Any, converter: Any => Any) => + (data: String) => converter { + try { + f(data) + } catch { + case NonFatal(_) => null + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 863657a7862a6..8071b0181af93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -440,6 +440,104 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } } + + test("SPARK-31936: Script transform support Array/MapType/StructType (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + (Array(0, 1, 2), Array(Array(0, 1), Array(2)), + Map("a" -> 1), Map("b" -> Array("a", "b"))), + (Array(3, 4, 5), Array(Array(3, 4), Array(5)), + Map("b" -> 2), Map("c" -> Array("c", "d"))), + (Array(6, 7, 8), Array(Array(6, 7), Array(8)), + Map("c" -> 3), Map("d" -> Array("e", "f"))) + ).toDF("a", "b", "c", "d") + .select('a, 'b, 'c, 'd, + struct('a, 'b).as("e"), + struct('a, 'd).as("f"), + struct(struct('a, 'b), struct('a, 'd)).as("g") + ) + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + df.col("f").expr, + df.col("g").expr), + script = "cat", + output = Seq( + AttributeReference("a", ArrayType(IntegerType))(), + AttributeReference("b", ArrayType(ArrayType(IntegerType)))(), + AttributeReference("c", MapType(StringType, IntegerType))(), + AttributeReference("d", MapType(StringType, ArrayType(StringType)))(), + AttributeReference("e", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", ArrayType(ArrayType(IntegerType))))))(), + AttributeReference("f", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", MapType(StringType, ArrayType(StringType))))))(), + AttributeReference("g", StructType( + Array(StructField("col1", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", ArrayType(ArrayType(IntegerType)))))), + StructField("col2", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", MapType(StringType, ArrayType(StringType)))))))))()), + child = child, + ioschema = defaultIOSchema + ), + df.select('a, 'b, 'c, 'd, 'e, 'f, 'g).collect()) + } + } + + test("SPARK-31936: Script transform support 7 level nested complex type (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + Array(1) + ).toDF("a").select('a, + array(array(array(array(array(array('a)))))).as("level_7"), + array(array(array(array(array(array(array('a))))))).as("level_8") + ) + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("level_7").expr), + script = "cat", + output = Seq( + AttributeReference("a", 
ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( + ArrayType(ArrayType(IntegerType))))))))()), + child = child, + ioschema = defaultIOSchema + ), + df.select('level_7).collect()) + + val e = intercept[RuntimeException] { + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("level_8").expr), + script = "cat", + output = Seq( + AttributeReference("a", ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( + ArrayType(ArrayType(IntegerType)))))))))()), + child = child, + ioschema = defaultIOSchema + ), + df.select('level_8).collect()) + }.getMessage + assert(e.contains("Number of levels of nesting supported for Spark SQL" + + " script transform is 7 Unable to work with level 8")) + } + } } case class ExceptionInjectingOperator(child: SparkPlan) extends UnaryExecNode { From 6a7438bf6574d35ed841a7301f50003b4fb12341 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 29 Dec 2020 13:38:41 +0800 Subject: [PATCH 02/22] Update CatalystTypeConverters.scala --- .../org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 2688e2f893ee7..5e3a7d0aa0b5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -176,7 +176,7 @@ object CatalystTypeConverters { new GenericArrayData(convertedIterable.toArray) case g: GenericArrayData => new GenericArrayData(g.array.map(elementConverter.toCatalyst)) case other => throw new IllegalArgumentException( - s"AAAThe value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to an array of ${elementType.catalogString}") } } From d3b9cec8d2c2b46d12760b58872346dceb389223 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 29 Dec 2020 16:42:14 +0800 Subject: [PATCH 03/22] fix failed UT --- .../BaseScriptTransformationSuite.scala | 2 +- .../SparkScriptTransformationSuite.scala | 40 ------------------- 2 files changed, 1 insertion(+), 41 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 8071b0181af93..7dbccb55dab5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -304,7 +304,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = defaultIOSchema ), - df.select('a, 'b.cast("string"), 'c.cast("string"), 'd.cast("string"), 'e).collect()) + df.select('a, 'b, 'c, 'd, 'e).collect()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index 6ff7c5d6d2f3a..a9eac506b0da6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -59,44 +59,4 @@ class 
SparkScriptTransformationSuite extends BaseScriptTransformationSuite with assert(e.contains("TRANSFORM with serde is only supported in hive mode")) } } - - test("SPARK-32106: TRANSFORM doesn't support ArrayType/MapType/StructType " + - "as output data type (no serde)") { - assume(TestUtils.testCommandAvailable("/bin/bash")) - // check for ArrayType - val e1 = intercept[SparkException] { - sql( - """ - |SELECT TRANSFORM(a) - |USING 'cat' AS (a array) - |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) - """.stripMargin).collect() - }.getMessage - assert(e1.contains("SparkScriptTransformation without serde does not support" + - " ArrayType as output data type")) - - // check for MapType - val e2 = intercept[SparkException] { - sql( - """ - |SELECT TRANSFORM(b) - |USING 'cat' AS (b map) - |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) - """.stripMargin).collect() - }.getMessage - assert(e2.contains("SparkScriptTransformation without serde does not support" + - " MapType as output data type")) - - // check for StructType - val e3 = intercept[SparkException] { - sql( - """ - |SELECT TRANSFORM(c) - |USING 'cat' AS (c struct) - |FROM VALUES (array(1, 1), map('1', 1), struct(1, 'a')) t(a, b, c) - """.stripMargin).collect() - }.getMessage - assert(e3.contains("SparkScriptTransformation without serde does not support" + - " StructType as output data type")) - } } From fdd52256b82266719a32d6156c054a42366f1abc Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 29 Dec 2020 17:18:09 +0800 Subject: [PATCH 04/22] Update SparkScriptTransformationSuite.scala --- .../spark/sql/execution/SparkScriptTransformationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala index a9eac506b0da6..e5aa3bfacd9ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.{SparkException, TestUtils} +import org.apache.spark.TestUtils import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.test.SharedSparkSession From aa16c8f3808975e0c2bb43be01bc87d0885756f3 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 29 Dec 2020 20:35:45 +0800 Subject: [PATCH 05/22] Update BaseScriptTransformationSuite.scala --- .../sql/execution/BaseScriptTransformationSuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 7dbccb55dab5d..2e45fd7242c32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -297,9 +297,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU script = "cat", output = Seq( AttributeReference("a", CalendarIntervalType)(), - AttributeReference("b", StringType)(), - AttributeReference("c", StringType)(), - AttributeReference("d", StringType)(), + 
AttributeReference("b", ArrayType(IntegerType))(), + AttributeReference("c", MapType(StringType, IntegerType))(), + AttributeReference("d", StructType( + Array(StructField("col1", IntegerType), + StructField("col2", IntegerType))))(), AttributeReference("e", new SimpleTupleUDT)()), child = child, ioschema = defaultIOSchema From 092c927186363e07dd5c3636858249a952b440e2 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 29 Dec 2020 22:28:35 +0800 Subject: [PATCH 06/22] Update BaseScriptTransformationExec.scala --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 6ec03d1c6c54c..b924a4ac3b856 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -48,7 +48,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { input.map { in: Expression => in.dataType match { - case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in + case _: ArrayType | _: MapType | _: StructType => in case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone) } } From 28ad7faccbe10fc75d063a3141d50613d665d94b Mon Sep 17 00:00:00 2001 From: angerszhu Date: Tue, 29 Dec 2020 22:34:19 +0800 Subject: [PATCH 07/22] Update BaseScriptTransformationSuite.scala --- .../BaseScriptTransformationSuite.scala | 145 +++++++++--------- 1 file changed, 74 insertions(+), 71 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 4d6faae983514..e12dce18f01df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -448,58 +448,83 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } - test("SPARK-31936: Script transform support Array/MapType/StructType (no serde)") { - assume(TestUtils.testCommandAvailable("python")) - withTempView("v") { - val df = Seq( - (Array(0, 1, 2), Array(Array(0, 1), Array(2)), - Map("a" -> 1), Map("b" -> Array("a", "b"))), - (Array(3, 4, 5), Array(Array(3, 4), Array(5)), - Map("b" -> 2), Map("c" -> Array("c", "d"))), - (Array(6, 7, 8), Array(Array(6, 7), Array(8)), - Map("c" -> 3), Map("d" -> Array("e", "f"))) - ).toDF("a", "b", "c", "d") - .select('a, 'b, 'c, 'd, - struct('a, 'b).as("e"), - struct('a, 'd).as("f"), - struct(struct('a, 'b), struct('a, 'd)).as("g") - ) + test("SPARK-33930: Script Transform default FIELD DELIMIT should be \u0001 (no serde)") { + withTempView("v") { + val df = Seq( + (1, 2, 3), + (2, 3, 4), + (3, 4, 5) + ).toDF("a", "b", "c") + df.createTempView("v") - checkAnswer( - df, - (child: SparkPlan) => createScriptTransformationExec( - input = Seq( - df.col("a").expr, - df.col("b").expr, - df.col("c").expr, - df.col("d").expr, - df.col("e").expr, - df.col("f").expr, - df.col("g").expr), - script = "cat", - output = Seq( - AttributeReference("a", ArrayType(IntegerType))(), - AttributeReference("b", ArrayType(ArrayType(IntegerType)))(), - AttributeReference("c", 
MapType(StringType, IntegerType))(), - AttributeReference("d", MapType(StringType, ArrayType(StringType)))(), - AttributeReference("e", StructType( - Array(StructField("col1", ArrayType(IntegerType)), - StructField("col2", ArrayType(ArrayType(IntegerType))))))(), - AttributeReference("f", StructType( + checkAnswer( + sql( + s""" + |SELECT TRANSFORM(a, b, c) + | ROW FORMAT DELIMITED + | USING 'cat' AS (a) + | ROW FORMAT DELIMITED + | FIELDS TERMINATED BY '&' + |FROM v + """.stripMargin), identity, + Row("1\u00012\u00013") :: + Row("2\u00013\u00014") :: + Row("3\u00014\u00015") :: Nil) + } + } + + test("SPARK-31936: Script transform support ArrayType/MapType/StructType (no serde)") { + assume(TestUtils.testCommandAvailable("python")) + withTempView("v") { + val df = Seq( + (Array(0, 1, 2), Array(Array(0, 1), Array(2)), + Map("a" -> 1), Map("b" -> Array("a", "b"))), + (Array(3, 4, 5), Array(Array(3, 4), Array(5)), + Map("b" -> 2), Map("c" -> Array("c", "d"))), + (Array(6, 7, 8), Array(Array(6, 7), Array(8)), + Map("c" -> 3), Map("d" -> Array("e", "f"))) + ).toDF("a", "b", "c", "d") + .select('a, 'b, 'c, 'd, + struct('a, 'b).as("e"), + struct('a, 'd).as("f"), + struct(struct('a, 'b), struct('a, 'd)).as("g") + ) + + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + df.col("e").expr, + df.col("f").expr, + df.col("g").expr), + script = "cat", + output = Seq( + AttributeReference("a", ArrayType(IntegerType))(), + AttributeReference("b", ArrayType(ArrayType(IntegerType)))(), + AttributeReference("c", MapType(StringType, IntegerType))(), + AttributeReference("d", MapType(StringType, ArrayType(StringType)))(), + AttributeReference("e", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", ArrayType(ArrayType(IntegerType))))))(), + AttributeReference("f", StructType( + Array(StructField("col1", ArrayType(IntegerType)), + StructField("col2", MapType(StringType, ArrayType(StringType))))))(), + AttributeReference("g", StructType( + Array(StructField("col1", StructType( Array(StructField("col1", ArrayType(IntegerType)), - StructField("col2", MapType(StringType, ArrayType(StringType))))))(), - AttributeReference("g", StructType( - Array(StructField("col1", StructType( + StructField("col2", ArrayType(ArrayType(IntegerType)))))), + StructField("col2", StructType( Array(StructField("col1", ArrayType(IntegerType)), - StructField("col2", ArrayType(ArrayType(IntegerType)))))), - StructField("col2", StructType( - Array(StructField("col1", ArrayType(IntegerType)), - StructField("col2", MapType(StringType, ArrayType(StringType)))))))))()), - child = child, - ioschema = defaultIOSchema - ), - df.select('a, 'b, 'c, 'd, 'e, 'f, 'g).collect()) - } + StructField("col2", MapType(StringType, ArrayType(StringType)))))))))()), + child = child, + ioschema = defaultIOSchema + ), + df.select('a, 'b, 'c, 'd, 'e, 'f, 'g).collect()) + } } test("SPARK-31936: Script transform support 7 level nested complex type (no serde)") { @@ -543,28 +568,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU }.getMessage assert(e.contains("Number of levels of nesting supported for Spark SQL" + " script transform is 7 Unable to work with level 8")) - test("SPARK-33930: Script Transform default FIELD DELIMIT should be \u0001 (no serde)") { - withTempView("v") { - val df = Seq( - (1, 2, 3), - (2, 3, 4), - (3, 4, 5) - ).toDF("a", "b", "c") - df.createTempView("v") - 
- checkAnswer( - sql( - s""" - |SELECT TRANSFORM(a, b, c) - | ROW FORMAT DELIMITED - | USING 'cat' AS (a) - | ROW FORMAT DELIMITED - | FIELDS TERMINATED BY '&' - |FROM v - """.stripMargin), identity, - Row("1\u00012\u00013") :: - Row("2\u00013\u00014") :: - Row("3\u00014\u00015") :: Nil) } } } From 33d8b5ba444bc5bcc12197aedaee54e37cd62dd4 Mon Sep 17 00:00:00 2001 From: angerszhu Date: Mon, 4 Jan 2021 15:04:04 +0800 Subject: [PATCH 08/22] Update BaseScriptTransformationExec.scala --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index dbc43a103f69c..23e4ca02b0e3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -371,8 +371,8 @@ case class ScriptTransformationIOSchema( object ScriptTransformationIOSchema { val defaultFormat = Map( - ("TOK_TABLEROWFORMATFIELD", "\u0001"), ("TOK_TABLEROWFORMATLINES", "\n"), + ("TOK_TABLEROWFORMATFIELD", "\u0001"), ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"), ("TOK_TABLEROWFORMATMAPKEYS", "\u0003") ) From 63f07ebb984b390f91bda72436bdea7f757c6150 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Thu, 4 Feb 2021 16:49:54 +0800 Subject: [PATCH 09/22] follow comment --- .../BaseScriptTransformationExec.scala | 154 +++++------ .../spark/sql/execution/SparkInspectors.scala | 259 ------------------ .../BaseScriptTransformationSuite.scala | 50 ++-- 3 files changed, 88 insertions(+), 375 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 23e4ca02b0e3a..2249d7509c73a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import java.io._ import java.nio.charset.StandardCharsets -import java.util.Map.Entry import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -31,11 +30,13 @@ import org.apache.spark.{SparkException, SparkFiles, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, JsonToStructs, Literal, StructsToJson, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { @@ -48,7 +49,8 @@ 
trait BaseScriptTransformationExec extends UnaryExecNode { protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { input.map { in: Expression => in.dataType match { - case _: ArrayType | _: MapType | _: StructType => in + case _: ArrayType | _: MapType | _: StructType => + new StructsToJson(in).withTimeZone(conf.sessionLocalTimeZone) case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone) } } @@ -190,8 +192,62 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr => - SparkInspectors.unwrapper(attr.dataType, conf, ioschema, 1) + val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) + attr.dataType match { + case StringType => wrapperConvertException(data => data, converter) + case BooleanType => wrapperConvertException(data => data.toBoolean, converter) + case ByteType => wrapperConvertException(data => data.toByte, converter) + case BinaryType => + wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) + case IntegerType => wrapperConvertException(data => data.toInt, converter) + case ShortType => wrapperConvertException(data => data.toShort, converter) + case LongType => wrapperConvertException(data => data.toLong, converter) + case FloatType => wrapperConvertException(data => data.toFloat, converter) + case DoubleType => wrapperConvertException(data => data.toDouble, converter) + case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull, converter) + case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull, converter) + case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull, converter) + case CalendarIntervalType => wrapperConvertException( + data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), + converter) + case _: ArrayType | _: MapType | _: StructType => wrapperConvertException(data => { + JsonToStructs(attr.dataType, Map.empty[String, String], + Literal(data), Some(conf.sessionLocalTimeZone)).eval() + }, converter) + case udt: UserDefinedType[_] => + wrapperConvertException(data => udt.deserialize(data), converter) + case dt => + throw new SparkException(s"${nodeName} without serde does not support " + + s"${dt.getClass.getSimpleName} as output data type") + } } + + // Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null + private val wrapperConvertException: (String => Any, Any => Any) => String => Any = + (f: String => Any, converter: Any => Any) => + (data: String) => converter { + try { + f(data) + } catch { + case NonFatal(_) => null + } + } } abstract class BaseScriptTransformationWriterThread extends Thread with Logging { @@ -214,23 
+270,18 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging protected def processRows(): Unit - val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt)) - protected def processRowsWithoutSerde(): Unit = { val len = inputSchema.length iter.foreach { row => - val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map { - case (value, wrapper) => wrapper(value) - } val data = if (len == 0) { ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES") } else { val sb = new StringBuilder - buildString(sb, values(0), inputSchema(0), 1) + sb.append(row.get(0, inputSchema(0))) var i = 1 while (i < len) { sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - buildString(sb, values(i), inputSchema(i), 1) + sb.append(row.get(i, inputSchema(i))) i += 1 } sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")) @@ -240,50 +291,6 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging } } - /** - * Convert data to string according to the data type. - * - * @param sb The StringBuilder to store the serialized data. - * @param obj The object for the current field. - * @param dataType The DataType for the current Object. - * @param level The current level of separator. - */ - private def buildString(sb: StringBuilder, obj: Any, dataType: DataType, level: Int): Unit = { - (obj, dataType) match { - case (list: java.util.List[_], ArrayType(typ, _)) => - val separator = ioSchema.getSeparator(level) - (0 until list.size).foreach { i => - if (i > 0) { - sb.append(separator) - } - buildString(sb, list.get(i), typ, level + 1) - } - case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) => - val separator = ioSchema.getSeparator(level) - val keyValueSeparator = ioSchema.getSeparator(level + 1) - val entries = map.entrySet().toArray() - (0 until entries.size).foreach { i => - if (i > 0) { - sb.append(separator) - } - val entry = entries(i).asInstanceOf[Entry[_, _]] - buildString(sb, entry.getKey, keyType, level + 2) - sb.append(keyValueSeparator) - buildString(sb, entry.getValue, valueType, level + 2) - } - case (arrayList: java.util.ArrayList[_], StructType(fields)) => - val separator = ioSchema.getSeparator(level) - (0 until arrayList.size).foreach { i => - if (i > 0) { - sb.append(separator) - } - buildString(sb, arrayList.get(i), fields(i).dataType, level + 1) - } - case (other, _) => - sb.append(other) - } - } - override def run(): Unit = Utils.logUncaughtExceptions { TaskContext.setTaskContext(taskContext) @@ -336,45 +343,14 @@ case class ScriptTransformationIOSchema( schemaLess: Boolean) extends Serializable { import ScriptTransformationIOSchema._ - val inputRowFormatMap = inputRowFormat.toMap.withDefault(k => defaultFormat(k)) - val outputRowFormatMap = outputRowFormat.toMap.withDefault(k => defaultFormat(k)) - - val separators = (getByte(inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 0.toByte) :: - getByte(inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"), 1.toByte) :: - getByte(inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"), 2.toByte) :: Nil) ++ - (4 to 8).map(_.toByte) - - def getByte(altValue: String, defaultVal: Byte): Byte = { - if (altValue != null && altValue.length > 0) { - try { - java.lang.Byte.parseByte(altValue) - } catch { - case _: NumberFormatException => - altValue.charAt(0).toByte - } - } else { - defaultVal - } - } - - def getSeparator(level: Int): Char = { - try { - separators(level).toChar - } catch { - case _: IndexOutOfBoundsException => - val msg = "Number of levels of nesting 
supported for Spark SQL script transform" + - " is " + (separators.length - 1) + " Unable to work with level " + level - throw new RuntimeException(msg) - } - } + val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) + val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) } object ScriptTransformationIOSchema { val defaultFormat = Map( - ("TOK_TABLEROWFORMATLINES", "\n"), ("TOK_TABLEROWFORMATFIELD", "\u0001"), - ("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"), - ("TOK_TABLEROWFORMATMAPKEYS", "\u0003") + ("TOK_TABLEROWFORMATLINES", "\n") ) val defaultIOSchema = ScriptTransformationIOSchema( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala deleted file mode 100644 index 597bde956da9b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkInspectors.scala +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.control.NonFatal - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, IntervalUtils, MapData} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -object SparkInspectors { - def wrapper(dataType: DataType): Any => Any = dataType match { - case ArrayType(tpe, _) => - val wp = wrapper(tpe) - withNullSafe { o => - val array = o.asInstanceOf[ArrayData] - val values = new java.util.ArrayList[Any](array.numElements()) - array.foreach(tpe, (_, e) => values.add(wp(e))) - values - } - case MapType(keyType, valueType, _) => - val mt = dataType.asInstanceOf[MapType] - val keyWrapper = wrapper(keyType) - val valueWrapper = wrapper(valueType) - withNullSafe { o => - val map = o.asInstanceOf[MapData] - val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(mt.keyType, mt.valueType, (k, v) => - jmap.put(keyWrapper(k), valueWrapper(v))) - jmap - } - case StringType => getStringWritable - case IntegerType => getIntWritable - case DoubleType => getDoubleWritable - case BooleanType => getBooleanWritable - case LongType => getLongWritable - case FloatType => getFloatWritable - case ShortType => getShortWritable - case ByteType => getByteWritable - case NullType => (_: Any) => null - case BinaryType => getBinaryWritable - case DateType => getDateWritable - case TimestampType => getTimestampWritable - // TODO decimal precision? 
- case DecimalType() => getDecimalWritable - case StructType(fields) => - val structType = dataType.asInstanceOf[StructType] - val wrappers = fields.map(f => wrapper(f.dataType)) - withNullSafe { o => - val row = o.asInstanceOf[InternalRow] - val result = new java.util.ArrayList[AnyRef](wrappers.size) - wrappers.zipWithIndex.foreach { - case (wrapper, i) => - val tpe = structType(i).dataType - result.add(wrapper(row.get(i, tpe)).asInstanceOf[AnyRef]) - } - result - } - case _: UserDefinedType[_] => - val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType - wrapper(sqlType) - } - - private def withNullSafe(f: Any => Any): Any => Any = { - input => - if (input == null) { - null - } else { - f(input) - } - } - - private def getStringWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[UTF8String].toString - } - - private def getIntWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Int] - } - - private def getDoubleWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Double] - } - - private def getBooleanWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Boolean] - } - - private def getLongWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Long] - } - - private def getFloatWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Float] - } - - private def getShortWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Short] - } - - private def getByteWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Byte] - } - - private def getBinaryWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Array[Byte]] - } - - private def getDateWritable(value: Any): Any = - if (value == null) { - null - } else { - DateTimeUtils.toJavaDate(value.asInstanceOf[Int]) - } - - private def getTimestampWritable(value: Any): Any = - if (value == null) { - null - } else { - DateTimeUtils.toJavaTimestamp(value.asInstanceOf[Long]) - } - - private def getDecimalWritable(value: Any): Any = - if (value == null) { - null - } else { - value.asInstanceOf[Decimal] - } - - - def unwrapper( - dataType: DataType, - conf: SQLConf, - ioSchema: ScriptTransformationIOSchema, - level: Int): String => Any = { - val converter = CatalystTypeConverters.createToCatalystConverter(dataType) - dataType match { - case StringType => wrapperConvertException(data => data, converter) - case BooleanType => wrapperConvertException(data => data.toBoolean, converter) - case ByteType => wrapperConvertException(data => data.toByte, converter) - case BinaryType => - wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) - case IntegerType => wrapperConvertException(data => data.toInt, converter) - case ShortType => wrapperConvertException(data => data.toShort, converter) - case LongType => wrapperConvertException(data => data.toLong, converter) - case FloatType => wrapperConvertException(data => data.toFloat, converter) - case DoubleType => wrapperConvertException(data => data.toDouble, converter) - case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) - case DateType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - 
.map(DateTimeUtils.daysToLocalDate).orNull, converter) - case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaDate).orNull, converter) - case TimestampType if conf.datetimeJava8ApiEnabled => - wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.microsToInstant).orNull, converter) - case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( - UTF8String.fromString(data), - DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) - .map(DateTimeUtils.toJavaTimestamp).orNull, converter) - case CalendarIntervalType => wrapperConvertException( - data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), - converter) - case udt: UserDefinedType[_] => - wrapperConvertException(data => udt.deserialize(data), converter) - case ArrayType(tpe, _) => - val separator = ioSchema.getSeparator(level) - val un = unwrapper(tpe, conf, ioSchema, level + 1) - wrapperConvertException(data => { - data.split(separator) - .map(un).toSeq - }, converter) - case MapType(keyType, valueType, _) => - val separator = ioSchema.getSeparator(level) - val keyValueSeparator = ioSchema.getSeparator(level + 1) - val keyUnwrapper = unwrapper(keyType, conf, ioSchema, level + 2) - val valueUnwrapper = unwrapper(valueType, conf, ioSchema, level + 2) - wrapperConvertException(data => { - val list = data.split(separator) - list.map { kv => - val kvList = kv.split(keyValueSeparator) - keyUnwrapper(kvList(0)) -> valueUnwrapper(kvList(1)) - }.toMap - }, converter) - case StructType(fields) => - val separator = ioSchema.getSeparator(level) - val unwrappers = fields.map(f => unwrapper(f.dataType, conf, ioSchema, level + 1)) - wrapperConvertException(data => { - val list = data.split(separator) - Row.fromSeq(list.zip(unwrappers).map { - case (data: String, unwrapper: (String => Any)) => unwrapper(data) - }) - }, converter) - case _ => wrapperConvertException(data => data, converter) - } - } - - // Keep consistent with Hive `LazySimpleSerDe`, when there is a type case error, return null - private val wrapperConvertException: (String => Any, Any => Any) => String => Any = - (f: String => Any, converter: Any => Any) => - (data: String) => converter { - try { - f(data) - } catch { - case NonFatal(_) => null - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 3f34ea1a91351..306c381e6e8cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -305,8 +305,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU AttributeReference("b", ArrayType(IntegerType))(), AttributeReference("c", MapType(StringType, IntegerType))(), AttributeReference("d", StructType( - Array(StructField("col1", IntegerType), - StructField("col2", IntegerType))))(), + Array(StructField("_1", IntegerType), + StructField("_2", IntegerType))))(), AttributeReference("e", new SimpleTupleUDT)()), child = child, ioschema = defaultIOSchema @@ -508,18 +508,18 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU AttributeReference("c", 
MapType(StringType, IntegerType))(), AttributeReference("d", MapType(StringType, ArrayType(StringType)))(), AttributeReference("e", StructType( - Array(StructField("col1", ArrayType(IntegerType)), - StructField("col2", ArrayType(ArrayType(IntegerType))))))(), + Array(StructField("a", ArrayType(IntegerType)), + StructField("b", ArrayType(ArrayType(IntegerType))))))(), AttributeReference("f", StructType( - Array(StructField("col1", ArrayType(IntegerType)), - StructField("col2", MapType(StringType, ArrayType(StringType))))))(), + Array(StructField("a", ArrayType(IntegerType)), + StructField("d", MapType(StringType, ArrayType(StringType))))))(), AttributeReference("g", StructType( Array(StructField("col1", StructType( - Array(StructField("col1", ArrayType(IntegerType)), - StructField("col2", ArrayType(ArrayType(IntegerType)))))), + Array(StructField("a", ArrayType(IntegerType)), + StructField("b", ArrayType(ArrayType(IntegerType)))))), StructField("col2", StructType( - Array(StructField("col1", ArrayType(IntegerType)), - StructField("col2", MapType(StringType, ArrayType(StringType)))))))))()), + Array(StructField("a", ArrayType(IntegerType)), + StructField("d", MapType(StringType, ArrayType(StringType)))))))))()), child = child, ioschema = defaultIOSchema ), @@ -551,23 +551,19 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ), df.select('level_7).collect()) - val e = intercept[RuntimeException] { - checkAnswer( - df, - (child: SparkPlan) => createScriptTransformationExec( - input = Seq( - df.col("level_8").expr), - script = "cat", - output = Seq( - AttributeReference("a", ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( - ArrayType(ArrayType(IntegerType)))))))))()), - child = child, - ioschema = defaultIOSchema - ), - df.select('level_8).collect()) - }.getMessage - assert(e.contains("Number of levels of nesting supported for Spark SQL" + - " script transform is 7 Unable to work with level 8")) + checkAnswer( + df, + (child: SparkPlan) => createScriptTransformationExec( + input = Seq( + df.col("level_8").expr), + script = "cat", + output = Seq( + AttributeReference("a", ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( + ArrayType(ArrayType(IntegerType)))))))))()), + child = child, + ioschema = defaultIOSchema + ), + df.select('level_8).collect()) } } From b631b70dc2e6ac0b60d10a04aa2cfb967c8e37bb Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Thu, 4 Feb 2021 16:50:53 +0800 Subject: [PATCH 10/22] Update BaseScriptTransformationExec.scala --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 2249d7509c73a..d25450e760d0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -226,10 +226,9 @@ trait BaseScriptTransformationExec extends UnaryExecNode { case CalendarIntervalType => wrapperConvertException( data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), converter) - case _: ArrayType | _: MapType | _: StructType => wrapperConvertException(data => { + case _: ArrayType | _: MapType | _: StructType => wrapperConvertException(data => JsonToStructs(attr.dataType, Map.empty[String, String], - Literal(data), 
Some(conf.sessionLocalTimeZone)).eval() - }, converter) + Literal(data), Some(conf.sessionLocalTimeZone)).eval(), converter) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) case dt => From b7e7f9200a3e59257f305d929e852ebddeb96506 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Fri, 5 Feb 2021 13:07:26 +0800 Subject: [PATCH 11/22] follow comment --- .../spark/sql/catalyst/CatalystTypeConverters.scala | 13 ------------- .../execution/BaseScriptTransformationExec.scala | 2 +- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 5e3a7d0aa0b5d..907b5877b3ac0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -174,7 +174,6 @@ object CatalystTypeConverters { convertedIterable += elementConverter.toCatalyst(item) } new GenericArrayData(convertedIterable.toArray) - case g: GenericArrayData => new GenericArrayData(g.array.map(elementConverter.toCatalyst)) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to an array of ${elementType.catalogString}") @@ -214,9 +213,6 @@ object CatalystTypeConverters { scalaValue match { case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction) case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction) - case map: ArrayBasedMapData => - ArrayBasedMapData(map.keyArray.array.zip(map.valueArray.array).toMap, - keyFunction, valueFunction) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + "cannot be converted to a map type with " @@ -267,15 +263,6 @@ object CatalystTypeConverters { idx += 1 } new GenericInternalRow(ar) - case g: GenericInternalRow => - val ar = new Array[Any](structType.size) - val values = g.values - var idx = 0 - while (idx < structType.size) { - ar(idx) = converters(idx).toCatalyst(values(idx)) - idx += 1 - } - new GenericInternalRow(ar) case other => throw new IllegalArgumentException( s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) " + s"cannot be converted to ${structType.catalogString}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index d25450e760d0b..8121efa9c52a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -228,7 +228,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { converter) case _: ArrayType | _: MapType | _: StructType => wrapperConvertException(data => JsonToStructs(attr.dataType, Map.empty[String, String], - Literal(data), Some(conf.sessionLocalTimeZone)).eval(), converter) + Literal(data), Some(conf.sessionLocalTimeZone)).eval(), any => any) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) case dt => From 8dec5a170c57d276ade4e359e3d899a03a9c0214 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Fri, 5 Feb 2021 14:15:04 +0800 
Subject: [PATCH 12/22] follow comment --- .../BaseScriptTransformationExec.scala | 6 +- .../BaseScriptTransformationSuite.scala | 72 +++++++++++++------ 2 files changed, 52 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 8121efa9c52a3..02d682a019aa5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -47,7 +47,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { def ioschema: ScriptTransformationIOSchema protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { - input.map { in: Expression => + input.map { in => in.dataType match { case _: ArrayType | _: MapType | _: StructType => new StructsToJson(in).withTimeZone(conf.sessionLocalTimeZone) @@ -226,8 +226,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { case CalendarIntervalType => wrapperConvertException( data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), converter) - case _: ArrayType | _: MapType | _: StructType => wrapperConvertException(data => - JsonToStructs(attr.dataType, Map.empty[String, String], + case _: ArrayType | _: MapType | _: StructType => + wrapperConvertException(data => JsonToStructs(attr.dataType, Map.empty[String, String], Literal(data), Some(conf.sessionLocalTimeZone)).eval(), any => any) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 306c381e6e8cd..332de97b7913a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -527,43 +527,69 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } - test("SPARK-31936: Script transform support 7 level nested complex type (no serde)") { + test("SPARK-31936: Script transform support nested complex type (no serde)") { assume(TestUtils.testCommandAvailable("python")) withTempView("v") { val df = Seq( - Array(1) - ).toDF("a").select('a, - array(array(array(array(array(array('a)))))).as("level_7"), - array(array(array(array(array(array(array('a))))))).as("level_8") + (Array(Array(Array(Array(Array(Array(1, 2, 3)))))), + Array(Array(Array(Array(Array(Array(1, 2, 3))), Array(Array(Array(1, 2, 3)))))), + Map("a" -> Map("c" -> Map("d" -> Array(1, 2, 3))), + "b" -> Map("c" -> Map("d" -> Array(1, 2, 3)))) + ) + ).toDF("a", "b", "c").select('a, 'b, 'c, + struct('a, 'b, 'c).as("d") + ).select('a, 'b, 'c, 'd, + struct('c, 'd).as("e") ) checkAnswer( df, (child: SparkPlan) => createScriptTransformationExec( input = Seq( - df.col("level_7").expr), - script = "cat", - output = Seq( - AttributeReference("a", ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( - ArrayType(ArrayType(IntegerType))))))))()), - child = child, - ioschema = defaultIOSchema - ), - df.select('level_7).collect()) - - checkAnswer( - df, - (child: SparkPlan) => createScriptTransformationExec( - input = Seq( - df.col("level_8").expr), + df.col("a").expr, + df.col("b").expr, + df.col("c").expr, + df.col("d").expr, + 
df.col("e").expr), script = "cat", output = Seq( - AttributeReference("a", ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( - ArrayType(ArrayType(IntegerType)))))))))()), + AttributeReference("a", + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType)))))))(), + AttributeReference("b", + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType)))))))(), + AttributeReference("c", + MapType(StringType, MapType(StringType, + MapType(StringType, ArrayType(IntegerType)))))(), + AttributeReference("d", + StructType(Array( + StructField("a", + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))), + StructField("b", + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))), + StructField("c", + MapType(StringType, MapType(StringType, + MapType(StringType, ArrayType(IntegerType))))))))(), + AttributeReference("e", + StructType(Array( + StructField("c", + MapType(StringType, MapType(StringType, + MapType(StringType, ArrayType(IntegerType))))), + StructField("d", + StructType(Array( + StructField("a", + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( + ArrayType(IntegerType))))))), + StructField("b", + ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( + ArrayType(IntegerType))))))), + StructField("c", + MapType(StringType, MapType(StringType, + MapType(StringType, ArrayType(IntegerType)))))))))))() + ), child = child, ioschema = defaultIOSchema ), - df.select('level_8).collect()) + df.select('a, 'b, 'c, 'd, 'e).collect()) } } From 529d54d242b696b650cd36476d19cdc7f0b2bc94 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Sat, 6 Feb 2021 22:41:02 +0800 Subject: [PATCH 13/22] Update BaseScriptTransformationExec.scala --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 02d682a019aa5..8820d2df1c2c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -50,7 +50,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { input.map { in => in.dataType match { case _: ArrayType | _: MapType | _: StructType => - new StructsToJson(in).withTimeZone(conf.sessionLocalTimeZone) + new StructsToJson(ioschema.inputSerdeProps.toMap, in) + .withTimeZone(conf.sessionLocalTimeZone) case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone) } } @@ -227,7 +228,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode { data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), converter) case _: ArrayType | _: MapType | _: StructType => - wrapperConvertException(data => JsonToStructs(attr.dataType, Map.empty[String, String], + wrapperConvertException(data => JsonToStructs(attr.dataType, + ioschema.outputSerdeProps.toMap, Literal(data), Some(conf.sessionLocalTimeZone)).eval(), any => any) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) From 4f0e78f258bc327d25e0b6a1404974d7c883b229 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Sat, 6 Feb 2021 23:24:05 +0800 Subject: [PATCH 14/22] Avoid construct JsonToStructs repeated --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 
deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 8820d2df1c2c6..63b3c8e635a69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -228,9 +228,10 @@ trait BaseScriptTransformationExec extends UnaryExecNode { data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), converter) case _: ArrayType | _: MapType | _: StructType => - wrapperConvertException(data => JsonToStructs(attr.dataType, - ioschema.outputSerdeProps.toMap, - Literal(data), Some(conf.sessionLocalTimeZone)).eval(), any => any) + val complexTypeFactory = JsonToStructs(attr.dataType, + ioschema.outputSerdeProps.toMap, Literal(null), Some(conf.sessionLocalTimeZone)) + wrapperConvertException(data => + complexTypeFactory.nullSafeEval(UTF8String.fromString(data)), any => any) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) case dt => From ed8c54ca96154e41db419591dc860d7c714d0aa6 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Sat, 6 Feb 2021 23:24:15 +0800 Subject: [PATCH 15/22] remove unused UT --- .../BaseScriptTransformationSuite.scala | 66 ------------------- 1 file changed, 66 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index 332de97b7913a..d81b672d64829 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -527,72 +527,6 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU } } - test("SPARK-31936: Script transform support nested complex type (no serde)") { - assume(TestUtils.testCommandAvailable("python")) - withTempView("v") { - val df = Seq( - (Array(Array(Array(Array(Array(Array(1, 2, 3)))))), - Array(Array(Array(Array(Array(Array(1, 2, 3))), Array(Array(Array(1, 2, 3)))))), - Map("a" -> Map("c" -> Map("d" -> Array(1, 2, 3))), - "b" -> Map("c" -> Map("d" -> Array(1, 2, 3)))) - ) - ).toDF("a", "b", "c").select('a, 'b, 'c, - struct('a, 'b, 'c).as("d") - ).select('a, 'b, 'c, 'd, - struct('c, 'd).as("e") - ) - - checkAnswer( - df, - (child: SparkPlan) => createScriptTransformationExec( - input = Seq( - df.col("a").expr, - df.col("b").expr, - df.col("c").expr, - df.col("d").expr, - df.col("e").expr), - script = "cat", - output = Seq( - AttributeReference("a", - ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType)))))))(), - AttributeReference("b", - ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType)))))))(), - AttributeReference("c", - MapType(StringType, MapType(StringType, - MapType(StringType, ArrayType(IntegerType)))))(), - AttributeReference("d", - StructType(Array( - StructField("a", - ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))), - StructField("b", - ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(ArrayType(IntegerType))))))), - StructField("c", - MapType(StringType, MapType(StringType, - MapType(StringType, ArrayType(IntegerType))))))))(), - AttributeReference("e", - StructType(Array( - StructField("c", - MapType(StringType, 
MapType(StringType, - MapType(StringType, ArrayType(IntegerType))))), - StructField("d", - StructType(Array( - StructField("a", - ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( - ArrayType(IntegerType))))))), - StructField("b", - ArrayType(ArrayType(ArrayType(ArrayType(ArrayType( - ArrayType(IntegerType))))))), - StructField("c", - MapType(StringType, MapType(StringType, - MapType(StringType, ArrayType(IntegerType)))))))))))() - ), - child = child, - ioschema = defaultIOSchema - ), - df.select('a, 'b, 'c, 'd, 'e).collect()) - } - } - test("SPARK-33934: Add SparkFile's root dir to env property PATH") { assume(TestUtils.testCommandAvailable("python")) val scriptFilePath = copyAndGetResourceFile("test_script.py", ".py").getAbsoluteFile From 520f4b847218fd7c0d4f22f455e5569fda6fa9f0 Mon Sep 17 00:00:00 2001 From: AngersZhuuuu Date: Fri, 16 Apr 2021 10:25:47 +0800 Subject: [PATCH 16/22] Update sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala Co-authored-by: Hyukjin Kwon --- .../spark/sql/execution/BaseScriptTransformationExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 63b3c8e635a69..0c7ff7c292dca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -229,7 +229,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { converter) case _: ArrayType | _: MapType | _: StructType => val complexTypeFactory = JsonToStructs(attr.dataType, - ioschema.outputSerdeProps.toMap, Literal(null), Some(conf.sessionLocalTimeZone)) + ioschema.outputSerdeProps.toMap, Literal(null), Some(conf.sessionLocalTimeZone)) wrapperConvertException(data => complexTypeFactory.nullSafeEval(UTF8String.fromString(data)), any => any) case udt: UserDefinedType[_] => From b5a42684eae2af1d271456a97a434258b4882962 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Sun, 18 Apr 2021 09:47:23 +0800 Subject: [PATCH 17/22] [SPARK-35097][SQL] Add column name to SparkUpgradeException about ancient datetime --- .../spark/sql/avro/AvroDeserializer.scala | 11 ++--- .../sql/errors/QueryExecutionErrors.scala | 9 ++++- .../orc/OrcColumnarBatchReader.java | 4 +- .../parquet/VectorizedColumnReader.java | 40 ++++++++++++------- .../parquet/VectorizedPlainValuesReader.java | 4 +- .../parquet/VectorizedRleValuesReader.java | 5 ++- .../vectorized/OffHeapColumnVector.java | 10 ++--- .../vectorized/OnHeapColumnVector.java | 10 ++--- .../vectorized/WritableColumnVector.java | 27 ++++++++----- .../datasources/DataSourceUtils.scala | 15 ++++--- .../parquet/ParquetRowConverter.scala | 14 +++---- .../connector/JavaColumnarDataSourceV2.java | 4 +- .../sql/SparkSessionExtensionSuite.scala | 6 +-- .../sql/connector/DataSourceV2Suite.scala | 4 +- .../compression/BooleanBitSetSuite.scala | 4 +- .../compression/DictionaryEncodingSuite.scala | 2 +- .../compression/IntegralDeltaSuite.scala | 2 +- .../PassThroughEncodingSuite.scala | 2 +- .../compression/RunLengthEncodingSuite.scala | 2 +- .../vectorized/ColumnVectorSuite.scala | 8 ++-- .../vectorized/ColumnarBatchBenchmark.scala | 16 ++++---- .../vectorized/ColumnarBatchSuite.scala | 4 +- 22 files changed, 115 insertions(+), 88 deletions(-) diff --git 
a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index a19a7b0d0edd1..6597c56efbede 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -62,10 +62,10 @@ private[sql] class AvroDeserializer( private lazy val decimalConversions = new DecimalConversion() private val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( - datetimeRebaseMode, "Avro") + datetimeRebaseMode, "Avro")(_, _) private val timestampRebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - datetimeRebaseMode, "Avro") + datetimeRebaseMode, "Avro")(_, _) private val converter: Any => Option[Any] = try { rootCatalystType match { @@ -126,7 +126,8 @@ private[sql] class AvroDeserializer( updater.setInt(ordinal, value.asInstanceOf[Int]) case (INT, DateType) => (updater, ordinal, value) => - updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) + updater.setInt(ordinal, + dateRebaseFunc(avroType.getName, catalystType)(value.asInstanceOf[Int])) case (LONG, LongType) => (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) @@ -137,10 +138,10 @@ private[sql] class AvroDeserializer( case null | _: TimestampMillis => (updater, ordinal, value) => val millis = value.asInstanceOf[Long] val micros = DateTimeUtils.millisToMicros(millis) - updater.setLong(ordinal, timestampRebaseFunc(micros)) + updater.setLong(ordinal, timestampRebaseFunc(avroType.getName, catalystType)(micros)) case _: TimestampMicros => (updater, ordinal, value) => val micros = value.asInstanceOf[Long] - updater.setLong(ordinal, timestampRebaseFunc(micros)) + updater.setLong(ordinal, timestampRebaseFunc(avroType.getName, catalystType)(micros)) case other => throw new IncompatibleSchemaException(errorPrefix + s"Avro logical type $other cannot be converted to SQL type ${TimestampType.sql}.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index c5a608e38da56..8b2827ae14c5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -394,11 +394,16 @@ object QueryExecutionErrors { } def sparkUpgradeInReadingDatesError( - format: String, config: String, option: String): SparkUpgradeException = { + colName: String, + dataType: DataType, + format: String, + config: String, + option: String): SparkUpgradeException = { new SparkUpgradeException("3.0", s""" |reading dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z from $format - |files can be ambiguous, as the files may be written by Spark 2.x or legacy versions of + |files can be ambiguous when reading column `${colName}` of data type `${dataType}`, + |as the files may be written by Spark 2.x or legacy versions of |Hive, which uses a legacy hybrid calendar that is different from Spark 3.0+'s Proleptic |Gregorian calendar. See more details in SPARK-31404. 
You can set the SQL config |'$config' or the datasource option '$option' to 'LEGACY' to rebase the datetime values diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index 40ed0b2454c12..e316ec0c830f4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -167,7 +167,7 @@ public void initBatch( for (int i = 0; i < requiredFields.length; i++) { DataType dt = requiredFields[i].dataType(); if (requestedPartitionColIds[i] != -1) { - OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, requiredFields[i].name(),dt); ColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); partitionCol.setIsConstant(); orcVectorWrappers[i] = partitionCol; @@ -175,7 +175,7 @@ public void initBatch( int colId = requestedDataColIds[i]; // Initialize the missing columns once. if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, requiredFields[i].name(), dt); missingCol.putNulls(0, capacity); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 52620b0740851..36fe7d2252a65 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -190,10 +190,13 @@ private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName return isSupported; } - static int rebaseDays(int julianDays, final boolean failIfRebase) { + static int rebaseDays( + int julianDays, + final boolean failIfRebase, + WritableColumnVector c) { if (failIfRebase) { if (julianDays < RebaseDateTime.lastSwitchJulianDay()) { - throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); + throw DataSourceUtils.newRebaseExceptionInRead(c.colName, c.dataType(), "Parquet"); } else { return julianDays; } @@ -205,10 +208,11 @@ static int rebaseDays(int julianDays, final boolean failIfRebase) { private static long rebaseTimestamp( long julianMicros, final boolean failIfRebase, + WritableColumnVector c, final String format) { if (failIfRebase) { if (julianMicros < RebaseDateTime.lastSwitchJulianTs()) { - throw DataSourceUtils.newRebaseExceptionInRead(format); + throw DataSourceUtils.newRebaseExceptionInRead(c.colName, c.dataType(), format); } else { return julianMicros; } @@ -217,12 +221,18 @@ private static long rebaseTimestamp( } } - static long rebaseMicros(long julianMicros, final boolean failIfRebase) { - return rebaseTimestamp(julianMicros, failIfRebase, "Parquet"); + static long rebaseMicros( + long julianMicros, + final boolean failIfRebase, + WritableColumnVector c) { + return rebaseTimestamp(julianMicros, failIfRebase, c, "Parquet"); } - static long rebaseInt96(long julianMicros, final boolean failIfRebase) { - return rebaseTimestamp(julianMicros, failIfRebase, "Parquet INT96"); + static long rebaseInt96( + long 
julianMicros, + final boolean failIfRebase, + WritableColumnVector c) { + return rebaseTimestamp(julianMicros, failIfRebase, c, "Parquet INT96"); } /** @@ -387,7 +397,7 @@ private void decodeDictionaryIds( for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { int julianDays = dictionary.decodeToInt(dictionaryIds.getDictId(i)); - column.putInt(i, rebaseDays(julianDays, failIfRebase)); + column.putInt(i, rebaseDays(julianDays, failIfRebase, column)); } } } else { @@ -432,7 +442,7 @@ private void decodeDictionaryIds( if (!column.isNullAt(i)) { long julianMillis = dictionary.decodeToLong(dictionaryIds.getDictId(i)); long julianMicros = DateTimeUtils.millisToMicros(julianMillis); - column.putLong(i, rebaseMicros(julianMicros, failIfRebase)); + column.putLong(i, rebaseMicros(julianMicros, failIfRebase, column)); } } } @@ -441,7 +451,7 @@ private void decodeDictionaryIds( for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { long julianMicros = dictionary.decodeToLong(dictionaryIds.getDictId(i)); - column.putLong(i, rebaseMicros(julianMicros, failIfRebase)); + column.putLong(i, rebaseMicros(julianMicros, failIfRebase, column)); } } } else { @@ -480,7 +490,7 @@ private void decodeDictionaryIds( if (!column.isNullAt(i)) { Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(v); - long gregorianMicros = rebaseInt96(julianMicros, failIfRebase); + long gregorianMicros = rebaseInt96(julianMicros, failIfRebase, column); column.putLong(i, gregorianMicros); } } @@ -500,7 +510,7 @@ private void decodeDictionaryIds( if (!column.isNullAt(i)) { Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(v); - long gregorianMicros = rebaseInt96(julianMicros, failIfRebase); + long gregorianMicros = rebaseInt96(julianMicros, failIfRebase, column); long adjTime = DateTimeUtils.convertTz(gregorianMicros, convertTz, UTC); column.putLong(i, adjTime); } @@ -640,7 +650,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) thro for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { long julianMicros = DateTimeUtils.millisToMicros(dataColumn.readLong()); - column.putLong(rowId + i, rebaseMicros(julianMicros, failIfRebase)); + column.putLong(rowId + i, rebaseMicros(julianMicros, failIfRebase, column)); } else { column.putNull(rowId + i); } @@ -698,7 +708,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) th if (defColumn.readInteger() == maxDefLevel) { // Read 12 bytes for INT96 long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12)); - long gregorianMicros = rebaseInt96(julianMicros, failIfRebase); + long gregorianMicros = rebaseInt96(julianMicros, failIfRebase, column); column.putLong(rowId + i, gregorianMicros); } else { column.putNull(rowId + i); @@ -722,7 +732,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) th if (defColumn.readInteger() == maxDefLevel) { // Read 12 bytes for INT96 long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12)); - long gregorianMicros = rebaseInt96(julianMicros, failIfRebase); + long gregorianMicros = rebaseInt96(julianMicros, failIfRebase, column); long adjTime = DateTimeUtils.convertTz(gregorianMicros, convertTz, UTC); column.putLong(rowId + i, adjTime); } else { diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 6a0038dbdc44c..95a4dafb7ad8c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -107,7 +107,7 @@ public final void readIntegersWithRebase( } if (rebase) { if (failIfRebase) { - throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); + throw DataSourceUtils.newRebaseExceptionInRead(c.colName, c.dataType(), "Parquet"); } else { for (int i = 0; i < total; i += 1) { c.putInt(rowId + i, RebaseDateTime.rebaseJulianToGregorianDays(buffer.getInt())); @@ -164,7 +164,7 @@ public final void readLongsWithRebase( } if (rebase) { if (failIfRebase) { - throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); + throw DataSourceUtils.newRebaseExceptionInRead(c.colName, c.dataType(), "Parquet"); } else { for (int i = 0; i < total; i += 1) { c.putLong(rowId + i, RebaseDateTime.rebaseJulianToGregorianMicros(buffer.getLong())); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 125506d4d5013..7297861e5dfce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -264,7 +264,8 @@ public void readIntegersWithRebase( for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { int julianDays = data.readInteger(); - c.putInt(rowId + i, VectorizedColumnReader.rebaseDays(julianDays, failIfRebase)); + c.putInt(rowId + i, + VectorizedColumnReader.rebaseDays(julianDays, failIfRebase, c)); } else { c.putNull(rowId + i); } @@ -492,7 +493,7 @@ public void readLongsWithRebase( for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { long julianMicros = data.readLong(); - c.putLong(rowId + i, VectorizedColumnReader.rebaseMicros(julianMicros, failIfRebase)); + c.putLong(rowId + i, VectorizedColumnReader.rebaseMicros(julianMicros, failIfRebase, c)); } else { c.putNull(rowId + i); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 7da5a287710eb..80bd2ee590c70 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -50,7 +50,7 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructType sch public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) { OffHeapColumnVector[] vectors = new OffHeapColumnVector[fields.length]; for (int i = 0; i < fields.length; i++) { - vectors[i] = new OffHeapColumnVector(capacity, fields[i].dataType()); + vectors[i] = new OffHeapColumnVector(capacity, fields[i].name(), fields[i].dataType()); } return vectors; } @@ -64,8 +64,8 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] private long lengthData; private long offsetData; - 
public OffHeapColumnVector(int capacity, DataType type) { - super(capacity, type); + public OffHeapColumnVector(int capacity, String colName, DataType type) { + super(capacity, colName, type); nulls = 0; data = 0; @@ -566,7 +566,7 @@ protected void reserveInternal(int newCapacity) { } @Override - protected OffHeapColumnVector reserveNewColumn(int capacity, DataType type) { - return new OffHeapColumnVector(capacity, type); + protected OffHeapColumnVector reserveNewColumn(int capacity, String colName, DataType type) { + return new OffHeapColumnVector(capacity, colName, type); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 5a7d6cc20971b..f7f28f095f11b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -50,7 +50,7 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructType sche public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) { OnHeapColumnVector[] vectors = new OnHeapColumnVector[fields.length]; for (int i = 0; i < fields.length; i++) { - vectors[i] = new OnHeapColumnVector(capacity, fields[i].dataType()); + vectors[i] = new OnHeapColumnVector(capacity, fields[i].name(), fields[i].dataType()); } return vectors; } @@ -73,8 +73,8 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f private int[] arrayLengths; private int[] arrayOffsets; - public OnHeapColumnVector(int capacity, DataType type) { - super(capacity, type); + public OnHeapColumnVector(int capacity, String colName, DataType type) { + super(capacity, colName, type); reserveInternal(capacity); reset(); @@ -580,7 +580,7 @@ protected void reserveInternal(int newCapacity) { } @Override - protected OnHeapColumnVector reserveNewColumn(int capacity, DataType type) { - return new OnHeapColumnVector(capacity, type); + protected OnHeapColumnVector reserveNewColumn(int capacity, String colName, DataType type) { + return new OnHeapColumnVector(capacity, colName, type); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 8c0f1e1257503..64ee74dba0d44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -165,7 +165,7 @@ public void setDictionary(Dictionary dictionary) { */ public WritableColumnVector reserveDictionaryIds(int capacity) { if (dictionaryIds == null) { - dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType); + dictionaryIds = reserveNewColumn(capacity, colName, DataTypes.IntegerType); } else { dictionaryIds.reset(); dictionaryIds.reserve(capacity); @@ -677,6 +677,11 @@ public WritableColumnVector arrayData() { */ public final void setIsConstant() { isConstant = true; } + /** + * Column name of this column. + */ + public String colName; + /** * Maximum number of rows that can be stored in this column. */ @@ -717,7 +722,7 @@ public WritableColumnVector arrayData() { /** * Reserve a new column. 
*/ - protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type); + protected abstract WritableColumnVector reserveNewColumn(int capacity, String colName, DataType type); protected boolean isArray() { return type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType || @@ -728,8 +733,9 @@ protected boolean isArray() { * Sets up the common state and also handles creating the child columns if this is a nested * type. */ - protected WritableColumnVector(int capacity, DataType type) { + protected WritableColumnVector(int capacity, String colName, DataType type) { super(type); + this.colName = colName; this.capacity = capacity; if (isArray()) { @@ -742,24 +748,25 @@ protected WritableColumnVector(int capacity, DataType type) { childCapacity *= DEFAULT_ARRAY_LENGTH; } this.childColumns = new WritableColumnVector[1]; - this.childColumns[0] = reserveNewColumn(childCapacity, childType); + this.childColumns[0] = reserveNewColumn(childCapacity, colName + ".elem", childType); } else if (type instanceof StructType) { StructType st = (StructType)type; this.childColumns = new WritableColumnVector[st.fields().length]; for (int i = 0; i < childColumns.length; ++i) { - this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); + this.childColumns[i] = reserveNewColumn(capacity, colName + "." + st.fields()[i].name(), + st.fields()[i].dataType()); } } else if (type instanceof MapType) { MapType mapType = (MapType) type; this.childColumns = new WritableColumnVector[2]; - this.childColumns[0] = reserveNewColumn(capacity, mapType.keyType()); - this.childColumns[1] = reserveNewColumn(capacity, mapType.valueType()); + this.childColumns[0] = reserveNewColumn(capacity, colName + ".key", mapType.keyType()); + this.childColumns[1] = reserveNewColumn(capacity, colName + ".value", mapType.valueType()); } else if (type instanceof CalendarIntervalType) { // Three columns. Months as int. Days as Int. Microseconds as Long. 
this.childColumns = new WritableColumnVector[3]; - this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); - this.childColumns[1] = reserveNewColumn(capacity, DataTypes.IntegerType); - this.childColumns[2] = reserveNewColumn(capacity, DataTypes.LongType); + this.childColumns[0] = reserveNewColumn(capacity, colName + ".months", DataTypes.IntegerType); + this.childColumns[1] = reserveNewColumn(capacity, colName + ".days", DataTypes.IntegerType); + this.childColumns[2] = reserveNewColumn(capacity, colName + ".microseconds", DataTypes.LongType); } else { this.childColumns = null; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 2b10e4efd9ab8..2dc680a0e5691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -131,7 +131,10 @@ object DataSourceUtils { }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) } - def newRebaseExceptionInRead(format: String): SparkUpgradeException = { + def newRebaseExceptionInRead( + colName: String, + dataType: DataType, + format: String): SparkUpgradeException = { val (config, option) = format match { case "Parquet INT96" => (SQLConf.PARQUET_INT96_REBASE_MODE_IN_READ.key, ParquetOptions.INT96_REBASE_MODE) @@ -141,7 +144,7 @@ object DataSourceUtils { (SQLConf.AVRO_REBASE_MODE_IN_READ.key, "datetimeRebaseMode") case _ => throw QueryExecutionErrors.unrecognizedFileFormatError(format) } - QueryExecutionErrors.sparkUpgradeInReadingDatesError(format, config, option) + QueryExecutionErrors.sparkUpgradeInReadingDatesError(colName, dataType, format, config, option) } def newRebaseExceptionInWrite(format: String): SparkUpgradeException = { @@ -156,10 +159,10 @@ object DataSourceUtils { def creteDateRebaseFuncInRead( rebaseMode: LegacyBehaviorPolicy.Value, - format: String): Int => Int = rebaseMode match { + format: String)(colName: String, dataType: DataType): Int => Int = rebaseMode match { case LegacyBehaviorPolicy.EXCEPTION => days: Int => if (days < RebaseDateTime.lastSwitchJulianDay) { - throw DataSourceUtils.newRebaseExceptionInRead(format) + throw DataSourceUtils.newRebaseExceptionInRead(colName, dataType, format) } days case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays @@ -180,10 +183,10 @@ object DataSourceUtils { def creteTimestampRebaseFuncInRead( rebaseMode: LegacyBehaviorPolicy.Value, - format: String): Long => Long = rebaseMode match { + format: String)(colName: String, dataType: DataType): Long => Long = rebaseMode match { case LegacyBehaviorPolicy.EXCEPTION => micros: Long => if (micros < RebaseDateTime.lastSwitchJulianTs) { - throw DataSourceUtils.newRebaseExceptionInRead(format) + throw DataSourceUtils.newRebaseExceptionInRead(colName, dataType, format) } micros case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianMicros diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 0a1cca7ed0f3f..d02e162c2b612 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ 
-189,13 +189,13 @@ private[parquet] class ParquetRowConverter( def currentRecord: InternalRow = currentRow private val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( - datetimeRebaseMode, "Parquet") + datetimeRebaseMode, "Parquet")(_, _) private val timestampRebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - datetimeRebaseMode, "Parquet") + datetimeRebaseMode, "Parquet")(_, _) private val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - int96RebaseMode, "Parquet INT96") + int96RebaseMode, "Parquet INT96")(_, _) // Converters for each field. private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { @@ -332,7 +332,7 @@ private[parquet] class ParquetRowConverter( case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { - updater.setLong(timestampRebaseFunc(value)) + updater.setLong(timestampRebaseFunc(parquetType.getName, catalystType)(value)) } } @@ -340,7 +340,7 @@ private[parquet] class ParquetRowConverter( new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { val micros = DateTimeUtils.millisToMicros(value) - updater.setLong(timestampRebaseFunc(micros)) + updater.setLong(timestampRebaseFunc(parquetType.getName, catalystType)(micros)) } } @@ -350,7 +350,7 @@ private[parquet] class ParquetRowConverter( // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { val julianMicros = ParquetRowConverter.binaryToSQLTimestamp(value) - val gregorianMicros = int96RebaseFunc(julianMicros) + val gregorianMicros = int96RebaseFunc(parquetType.getName, catalystType)(julianMicros) val adjTime = convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) .getOrElse(gregorianMicros) updater.setLong(adjTime) @@ -360,7 +360,7 @@ private[parquet] class ParquetRowConverter( case DateType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { - updater.set(dateRebaseFunc(value)) + updater.set(dateRebaseFunc(parquetType.getName, catalystType)(value)) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java index 2f10c84c999f9..3bdc669dbf311 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java @@ -77,8 +77,8 @@ public PartitionReader createReader(InputPartition partition) { @Override public PartitionReader createColumnarReader(InputPartition partition) { JavaRangeInputPartition p = (JavaRangeInputPartition) partition; - OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, "", DataTypes.IntegerType); + OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, "", DataTypes.IntegerType); ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index d4a6d84ce2b30..951ee81ac7de2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -675,7 +675,7 @@ class BrokenColumnarAdd( } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) { val l = lhs.asInstanceOf[ColumnVector] val r = rhs.asInstanceOf[ColumnVector] - val result = new OnHeapColumnVector(batch.numRows(), dataType) + val result = new OnHeapColumnVector(batch.numRows(), "", dataType) ret = result for (i <- 0 until batch.numRows()) { @@ -684,7 +684,7 @@ class BrokenColumnarAdd( } else if (rhs.isInstanceOf[ColumnVector]) { val l = lhs.asInstanceOf[Long] val r = rhs.asInstanceOf[ColumnVector] - val result = new OnHeapColumnVector(batch.numRows(), dataType) + val result = new OnHeapColumnVector(batch.numRows(), "", dataType) ret = result for (i <- 0 until batch.numRows()) { @@ -693,7 +693,7 @@ class BrokenColumnarAdd( } else if (lhs.isInstanceOf[ColumnVector]) { val l = lhs.asInstanceOf[ColumnVector] val r = rhs.asInstanceOf[Long] - val result = new OnHeapColumnVector(batch.numRows(), dataType) + val result = new OnHeapColumnVector(batch.numRows(), "", dataType) ret = result for (i <- 0 until batch.numRows()) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 49a1078800552..562596c9fbf8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -682,8 +682,8 @@ object ColumnarReaderFactory extends PartitionReaderFactory { override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { val RangeInputPartition(start, end) = partition new PartitionReader[ColumnarBatch] { - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, "", IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, "", IntegerType) private lazy val batch = new ColumnarBatch(Array(i, j)) private var current = start diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index 111a620df8c24..790cda9bfa5b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -105,7 +105,7 @@ class BooleanBitSetSuite extends SparkFunSuite { assertResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) - val columnVector = new OnHeapColumnVector(values.length, BooleanType) + val columnVector = new OnHeapColumnVector(values.length, "", BooleanType) decoder.decompress(columnVector, values.length) if (values.nonEmpty) { @@ -175,7 +175,7 @@ class BooleanBitSetSuite extends SparkFunSuite { assertResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) - val columnVector = new OnHeapColumnVector(numRows, BooleanType) + val columnVector = new OnHeapColumnVector(numRows, "", BooleanType) decoder.decompress(columnVector, numRows) (0 until numRows).foreach { rowNum => diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index 61e4cc068fa80..3fc556557fa0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -142,7 +142,7 @@ class DictionaryEncodingSuite extends SparkFunSuite { assertResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = DictionaryEncoding.decoder(buffer, columnType) - val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + val columnVector = new OnHeapColumnVector(inputSeq.length, "", columnType.dataType) decoder.decompress(columnVector, inputSeq.length) if (inputSeq.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index b5630488b3667..af1c5a34e35b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -136,7 +136,7 @@ class IntegralDeltaSuite extends SparkFunSuite { assertResult(scheme.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = scheme.decoder(buffer, columnType) - val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + val columnVector = new OnHeapColumnVector(input.length, "", columnType.dataType) decoder.decompress(columnVector, input.length) if (input.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala index c6fe64d1058ab..081d4dfaae4f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala @@ -117,7 +117,7 @@ class PassThroughSuite extends SparkFunSuite { assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = PassThrough.decoder(buffer, columnType) - val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + val columnVector = new OnHeapColumnVector(input.length, "", columnType.dataType) decoder.decompress(columnVector, input.length) if (input.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 29dbc13b59c6b..110dc6681200b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -126,7 +126,7 @@ class RunLengthEncodingSuite extends SparkFunSuite { assertResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = RunLengthEncoding.decoder(buffer, columnType) - val columnVector = new OnHeapColumnVector(inputSeq.length, 
columnType.dataType) + val columnVector = new OnHeapColumnVector(inputSeq.length, "", columnType.dataType) decoder.decompress(columnVector, inputSeq.length) if (inputSeq.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 247efd5554a8f..97be5e5e1221a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -38,8 +38,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { size: Int, dt: DataType)( block: WritableColumnVector => Unit): Unit = { - withVector(new OnHeapColumnVector(size, dt))(block) - withVector(new OffHeapColumnVector(size, dt))(block) + withVector(new OnHeapColumnVector(size, "", dt))(block) + withVector(new OffHeapColumnVector(size, "", dt))(block) } private def testVectors( @@ -259,7 +259,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { - withVector(new OffHeapColumnVector(8, arrayType)) { testVector => + withVector(new OffHeapColumnVector(8, "", arrayType)) { testVector => val data = testVector.arrayData() (0 until 8).foreach(i => data.putInt(i, i)) (0 until 8).foreach(i => testVector.putArray(i, i, 1)) @@ -275,7 +275,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") { - withVector(new OffHeapColumnVector(8, structType)) { testVector => + withVector(new OffHeapColumnVector(8, "", structType)) { testVector => (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) testVector.reserve(16) (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index f9ae611691a7f..eb9f70902add0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -147,7 +147,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { // Access through the column API with on heap memory val columnOnHeap = { i: Int => - val col = new OnHeapColumnVector(count, IntegerType) + val col = new OnHeapColumnVector(count, "", IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -166,7 +166,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { // Access through the column API with off heap memory def columnOffHeap = { i: Int => { - val col = new OffHeapColumnVector(count, IntegerType) + val col = new OffHeapColumnVector(count, "", IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -185,7 +185,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { // Access by directly getting the buffer backing the column. 
val columnOffheapDirect = { i: Int => - val col = new OffHeapColumnVector(count, IntegerType) + val col = new OffHeapColumnVector(count, "", IntegerType) var sum = 0L for (n <- 0L until iters) { var addr = col.valuesNativeAddress() @@ -251,7 +251,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { // Adding values by appending, instead of putting. val onHeapAppend = { i: Int => - val col = new OnHeapColumnVector(count, IntegerType) + val col = new OnHeapColumnVector(count, "", IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -347,9 +347,9 @@ object ColumnarBatchBenchmark extends BenchmarkBase { def column(memoryMode: MemoryMode) = { i: Int => val column = if (memoryMode == MemoryMode.OFF_HEAP) { - new OffHeapColumnVector(count, BinaryType) + new OffHeapColumnVector(count, "", BinaryType) } else { - new OnHeapColumnVector(count, BinaryType) + new OnHeapColumnVector(count, "", BinaryType) } var sum = 0L @@ -378,8 +378,8 @@ object ColumnarBatchBenchmark extends BenchmarkBase { val random = new Random(0) val count = 4 * 1000 - val onHeapVector = new OnHeapColumnVector(count, ArrayType(IntegerType)) - val offHeapVector = new OffHeapColumnVector(count, ArrayType(IntegerType)) + val onHeapVector = new OnHeapColumnVector(count, "", ArrayType(IntegerType)) + val offHeapVector = new OffHeapColumnVector(count, "", ArrayType(IntegerType)) val minSize = 3 val maxSize = 32 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index bd69bab6f5da2..a6b210277b6ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -44,9 +44,9 @@ class ColumnarBatchSuite extends SparkFunSuite { private def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { if (memMode == MemoryMode.OFF_HEAP) { - new OffHeapColumnVector(capacity, dt) + new OffHeapColumnVector(capacity, "", dt) } else { - new OnHeapColumnVector(capacity, dt) + new OnHeapColumnVector(capacity, "", dt) } } From 76a746e69728083bad28e33144e70fdf024c5aa7 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Sun, 18 Apr 2021 12:23:35 +0800 Subject: [PATCH 18/22] Revert "[SPARK-35097][SQL] Add column name to SparkUpgradeException about ancient datetime" This reverts commit b5a42684eae2af1d271456a97a434258b4882962. 
--- .../spark/sql/avro/AvroDeserializer.scala | 11 +++-- .../sql/errors/QueryExecutionErrors.scala | 9 +---- .../orc/OrcColumnarBatchReader.java | 4 +- .../parquet/VectorizedColumnReader.java | 40 +++++++------------ .../parquet/VectorizedPlainValuesReader.java | 4 +- .../parquet/VectorizedRleValuesReader.java | 5 +-- .../vectorized/OffHeapColumnVector.java | 10 ++--- .../vectorized/OnHeapColumnVector.java | 10 ++--- .../vectorized/WritableColumnVector.java | 27 +++++-------- .../datasources/DataSourceUtils.scala | 15 +++---- .../parquet/ParquetRowConverter.scala | 14 +++---- .../connector/JavaColumnarDataSourceV2.java | 4 +- .../sql/SparkSessionExtensionSuite.scala | 6 +-- .../sql/connector/DataSourceV2Suite.scala | 4 +- .../compression/BooleanBitSetSuite.scala | 4 +- .../compression/DictionaryEncodingSuite.scala | 2 +- .../compression/IntegralDeltaSuite.scala | 2 +- .../PassThroughEncodingSuite.scala | 2 +- .../compression/RunLengthEncodingSuite.scala | 2 +- .../vectorized/ColumnVectorSuite.scala | 8 ++-- .../vectorized/ColumnarBatchBenchmark.scala | 16 ++++---- .../vectorized/ColumnarBatchSuite.scala | 4 +- 22 files changed, 88 insertions(+), 115 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 6597c56efbede..a19a7b0d0edd1 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -62,10 +62,10 @@ private[sql] class AvroDeserializer( private lazy val decimalConversions = new DecimalConversion() private val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( - datetimeRebaseMode, "Avro")(_, _) + datetimeRebaseMode, "Avro") private val timestampRebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - datetimeRebaseMode, "Avro")(_, _) + datetimeRebaseMode, "Avro") private val converter: Any => Option[Any] = try { rootCatalystType match { @@ -126,8 +126,7 @@ private[sql] class AvroDeserializer( updater.setInt(ordinal, value.asInstanceOf[Int]) case (INT, DateType) => (updater, ordinal, value) => - updater.setInt(ordinal, - dateRebaseFunc(avroType.getName, catalystType)(value.asInstanceOf[Int])) + updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) case (LONG, LongType) => (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) @@ -138,10 +137,10 @@ private[sql] class AvroDeserializer( case null | _: TimestampMillis => (updater, ordinal, value) => val millis = value.asInstanceOf[Long] val micros = DateTimeUtils.millisToMicros(millis) - updater.setLong(ordinal, timestampRebaseFunc(avroType.getName, catalystType)(micros)) + updater.setLong(ordinal, timestampRebaseFunc(micros)) case _: TimestampMicros => (updater, ordinal, value) => val micros = value.asInstanceOf[Long] - updater.setLong(ordinal, timestampRebaseFunc(avroType.getName, catalystType)(micros)) + updater.setLong(ordinal, timestampRebaseFunc(micros)) case other => throw new IncompatibleSchemaException(errorPrefix + s"Avro logical type $other cannot be converted to SQL type ${TimestampType.sql}.") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 8b2827ae14c5f..c5a608e38da56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -394,16 +394,11 @@ object QueryExecutionErrors { } def sparkUpgradeInReadingDatesError( - colName: String, - dataType: DataType, - format: String, - config: String, - option: String): SparkUpgradeException = { + format: String, config: String, option: String): SparkUpgradeException = { new SparkUpgradeException("3.0", s""" |reading dates before 1582-10-15 or timestamps before 1900-01-01T00:00:00Z from $format - |files can be ambiguous when read column `${colName}` of datatype `${dataType}`, - |as the files may be written by Spark 2.x or legacy versions of + |files can be ambiguous, as the files may be written by Spark 2.x or legacy versions of |Hive, which uses a legacy hybrid calendar that is different from Spark 3.0+'s Proleptic |Gregorian calendar. See more details in SPARK-31404. You can set the SQL config |'$config' or the datasource option '$option' to 'LEGACY' to rebase the datetime values diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java index e316ec0c830f4..40ed0b2454c12 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java @@ -167,7 +167,7 @@ public void initBatch( for (int i = 0; i < requiredFields.length; i++) { DataType dt = requiredFields[i].dataType(); if (requestedPartitionColIds[i] != -1) { - OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, requiredFields[i].name(),dt); + OnHeapColumnVector partitionCol = new OnHeapColumnVector(capacity, dt); ColumnVectorUtils.populate(partitionCol, partitionValues, requestedPartitionColIds[i]); partitionCol.setIsConstant(); orcVectorWrappers[i] = partitionCol; @@ -175,7 +175,7 @@ public void initBatch( int colId = requestedDataColIds[i]; // Initialize the missing columns once. 
if (colId == -1) { - OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, requiredFields[i].name(), dt); + OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt); missingCol.putNulls(0, capacity); missingCol.setIsConstant(); orcVectorWrappers[i] = missingCol; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 36fe7d2252a65..52620b0740851 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -190,13 +190,10 @@ private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName return isSupported; } - static int rebaseDays( - int julianDays, - final boolean failIfRebase, - WritableColumnVector c) { + static int rebaseDays(int julianDays, final boolean failIfRebase) { if (failIfRebase) { if (julianDays < RebaseDateTime.lastSwitchJulianDay()) { - throw DataSourceUtils.newRebaseExceptionInRead(c.colName, c.dataType(), "Parquet"); + throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); } else { return julianDays; } @@ -208,11 +205,10 @@ static int rebaseDays( private static long rebaseTimestamp( long julianMicros, final boolean failIfRebase, - WritableColumnVector c, final String format) { if (failIfRebase) { if (julianMicros < RebaseDateTime.lastSwitchJulianTs()) { - throw DataSourceUtils.newRebaseExceptionInRead(c.colName, c.dataType(), format); + throw DataSourceUtils.newRebaseExceptionInRead(format); } else { return julianMicros; } @@ -221,18 +217,12 @@ private static long rebaseTimestamp( } } - static long rebaseMicros( - long julianMicros, - final boolean failIfRebase, - WritableColumnVector c) { - return rebaseTimestamp(julianMicros, failIfRebase, c, "Parquet"); + static long rebaseMicros(long julianMicros, final boolean failIfRebase) { + return rebaseTimestamp(julianMicros, failIfRebase, "Parquet"); } - static long rebaseInt96( - long julianMicros, - final boolean failIfRebase, - WritableColumnVector c) { - return rebaseTimestamp(julianMicros, failIfRebase, c, "Parquet INT96"); + static long rebaseInt96(long julianMicros, final boolean failIfRebase) { + return rebaseTimestamp(julianMicros, failIfRebase, "Parquet INT96"); } /** @@ -397,7 +387,7 @@ private void decodeDictionaryIds( for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { int julianDays = dictionary.decodeToInt(dictionaryIds.getDictId(i)); - column.putInt(i, rebaseDays(julianDays, failIfRebase, column)); + column.putInt(i, rebaseDays(julianDays, failIfRebase)); } } } else { @@ -442,7 +432,7 @@ private void decodeDictionaryIds( if (!column.isNullAt(i)) { long julianMillis = dictionary.decodeToLong(dictionaryIds.getDictId(i)); long julianMicros = DateTimeUtils.millisToMicros(julianMillis); - column.putLong(i, rebaseMicros(julianMicros, failIfRebase, column)); + column.putLong(i, rebaseMicros(julianMicros, failIfRebase)); } } } @@ -451,7 +441,7 @@ private void decodeDictionaryIds( for (int i = rowId; i < rowId + num; ++i) { if (!column.isNullAt(i)) { long julianMicros = dictionary.decodeToLong(dictionaryIds.getDictId(i)); - column.putLong(i, rebaseMicros(julianMicros, failIfRebase, column)); + column.putLong(i, rebaseMicros(julianMicros, failIfRebase)); } } } else { @@ -490,7 +480,7 @@ private void decodeDictionaryIds( if 
(!column.isNullAt(i)) { Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(v); - long gregorianMicros = rebaseInt96(julianMicros, failIfRebase, column); + long gregorianMicros = rebaseInt96(julianMicros, failIfRebase); column.putLong(i, gregorianMicros); } } @@ -510,7 +500,7 @@ private void decodeDictionaryIds( if (!column.isNullAt(i)) { Binary v = dictionary.decodeToBinary(dictionaryIds.getDictId(i)); long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(v); - long gregorianMicros = rebaseInt96(julianMicros, failIfRebase, column); + long gregorianMicros = rebaseInt96(julianMicros, failIfRebase); long adjTime = DateTimeUtils.convertTz(gregorianMicros, convertTz, UTC); column.putLong(i, adjTime); } @@ -650,7 +640,7 @@ private void readLongBatch(int rowId, int num, WritableColumnVector column) thro for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { long julianMicros = DateTimeUtils.millisToMicros(dataColumn.readLong()); - column.putLong(rowId + i, rebaseMicros(julianMicros, failIfRebase, column)); + column.putLong(rowId + i, rebaseMicros(julianMicros, failIfRebase)); } else { column.putNull(rowId + i); } @@ -708,7 +698,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) th if (defColumn.readInteger() == maxDefLevel) { // Read 12 bytes for INT96 long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12)); - long gregorianMicros = rebaseInt96(julianMicros, failIfRebase, column); + long gregorianMicros = rebaseInt96(julianMicros, failIfRebase); column.putLong(rowId + i, gregorianMicros); } else { column.putNull(rowId + i); @@ -732,7 +722,7 @@ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) th if (defColumn.readInteger() == maxDefLevel) { // Read 12 bytes for INT96 long julianMicros = ParquetRowConverter.binaryToSQLTimestamp(data.readBinary(12)); - long gregorianMicros = rebaseInt96(julianMicros, failIfRebase, column); + long gregorianMicros = rebaseInt96(julianMicros, failIfRebase); long adjTime = DateTimeUtils.convertTz(gregorianMicros, convertTz, UTC); column.putLong(rowId + i, adjTime); } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 95a4dafb7ad8c..6a0038dbdc44c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -107,7 +107,7 @@ public final void readIntegersWithRebase( } if (rebase) { if (failIfRebase) { - throw DataSourceUtils.newRebaseExceptionInRead(c.colName, c.dataType(), "Parquet"); + throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); } else { for (int i = 0; i < total; i += 1) { c.putInt(rowId + i, RebaseDateTime.rebaseJulianToGregorianDays(buffer.getInt())); @@ -164,7 +164,7 @@ public final void readLongsWithRebase( } if (rebase) { if (failIfRebase) { - throw DataSourceUtils.newRebaseExceptionInRead(c.colName, c.dataType(), "Parquet"); + throw DataSourceUtils.newRebaseExceptionInRead("Parquet"); } else { for (int i = 0; i < total; i += 1) { c.putLong(rowId + i, RebaseDateTime.rebaseJulianToGregorianMicros(buffer.getLong())); diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 7297861e5dfce..125506d4d5013 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -264,8 +264,7 @@ public void readIntegersWithRebase( for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { int julianDays = data.readInteger(); - c.putInt(rowId + i, - VectorizedColumnReader.rebaseDays(julianDays, failIfRebase, c)); + c.putInt(rowId + i, VectorizedColumnReader.rebaseDays(julianDays, failIfRebase)); } else { c.putNull(rowId + i); } @@ -493,7 +492,7 @@ public void readLongsWithRebase( for (int i = 0; i < n; ++i) { if (currentBuffer[currentBufferIdx++] == level) { long julianMicros = data.readLong(); - c.putLong(rowId + i, VectorizedColumnReader.rebaseMicros(julianMicros, failIfRebase, c)); + c.putLong(rowId + i, VectorizedColumnReader.rebaseMicros(julianMicros, failIfRebase)); } else { c.putNull(rowId + i); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 80bd2ee590c70..7da5a287710eb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -50,7 +50,7 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructType sch public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) { OffHeapColumnVector[] vectors = new OffHeapColumnVector[fields.length]; for (int i = 0; i < fields.length; i++) { - vectors[i] = new OffHeapColumnVector(capacity, fields[i].name(), fields[i].dataType()); + vectors[i] = new OffHeapColumnVector(capacity, fields[i].dataType()); } return vectors; } @@ -64,8 +64,8 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] private long lengthData; private long offsetData; - public OffHeapColumnVector(int capacity, String colName, DataType type) { - super(capacity, colName, type); + public OffHeapColumnVector(int capacity, DataType type) { + super(capacity, type); nulls = 0; data = 0; @@ -566,7 +566,7 @@ protected void reserveInternal(int newCapacity) { } @Override - protected OffHeapColumnVector reserveNewColumn(int capacity, String colName, DataType type) { - return new OffHeapColumnVector(capacity, colName, type); + protected OffHeapColumnVector reserveNewColumn(int capacity, DataType type) { + return new OffHeapColumnVector(capacity, type); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index f7f28f095f11b..5a7d6cc20971b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -50,7 +50,7 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructType sche public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) { OnHeapColumnVector[] vectors = new 
OnHeapColumnVector[fields.length]; for (int i = 0; i < fields.length; i++) { - vectors[i] = new OnHeapColumnVector(capacity, fields[i].name(), fields[i].dataType()); + vectors[i] = new OnHeapColumnVector(capacity, fields[i].dataType()); } return vectors; } @@ -73,8 +73,8 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f private int[] arrayLengths; private int[] arrayOffsets; - public OnHeapColumnVector(int capacity, String colName, DataType type) { - super(capacity, colName, type); + public OnHeapColumnVector(int capacity, DataType type) { + super(capacity, type); reserveInternal(capacity); reset(); @@ -580,7 +580,7 @@ protected void reserveInternal(int newCapacity) { } @Override - protected OnHeapColumnVector reserveNewColumn(int capacity, String colName, DataType type) { - return new OnHeapColumnVector(capacity, colName, type); + protected OnHeapColumnVector reserveNewColumn(int capacity, DataType type) { + return new OnHeapColumnVector(capacity, type); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 64ee74dba0d44..8c0f1e1257503 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -165,7 +165,7 @@ public void setDictionary(Dictionary dictionary) { */ public WritableColumnVector reserveDictionaryIds(int capacity) { if (dictionaryIds == null) { - dictionaryIds = reserveNewColumn(capacity, colName, DataTypes.IntegerType); + dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType); } else { dictionaryIds.reset(); dictionaryIds.reserve(capacity); @@ -677,11 +677,6 @@ public WritableColumnVector arrayData() { */ public final void setIsConstant() { isConstant = true; } - /** - * Column name of this column. - */ - public String colName; - /** * Maximum number of rows that can be stored in this column. */ @@ -722,7 +717,7 @@ public WritableColumnVector arrayData() { /** * Reserve a new column. */ - protected abstract WritableColumnVector reserveNewColumn(int capacity, String colName, DataType type); + protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type); protected boolean isArray() { return type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType || @@ -733,9 +728,8 @@ protected boolean isArray() { * Sets up the common state and also handles creating the child columns if this is a nested * type. */ - protected WritableColumnVector(int capacity, String colName, DataType type) { + protected WritableColumnVector(int capacity, DataType type) { super(type); - this.colName = colName; this.capacity = capacity; if (isArray()) { @@ -748,25 +742,24 @@ protected WritableColumnVector(int capacity, String colName, DataType type) { childCapacity *= DEFAULT_ARRAY_LENGTH; } this.childColumns = new WritableColumnVector[1]; - this.childColumns[0] = reserveNewColumn(childCapacity, colName + ".elem", childType); + this.childColumns[0] = reserveNewColumn(childCapacity, childType); } else if (type instanceof StructType) { StructType st = (StructType)type; this.childColumns = new WritableColumnVector[st.fields().length]; for (int i = 0; i < childColumns.length; ++i) { - this.childColumns[i] = reserveNewColumn(capacity, colName + "." 
+ st.fields()[i].name(), - st.fields()[i].dataType()); + this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); } } else if (type instanceof MapType) { MapType mapType = (MapType) type; this.childColumns = new WritableColumnVector[2]; - this.childColumns[0] = reserveNewColumn(capacity, colName + ".key", mapType.keyType()); - this.childColumns[1] = reserveNewColumn(capacity, colName + ".value", mapType.valueType()); + this.childColumns[0] = reserveNewColumn(capacity, mapType.keyType()); + this.childColumns[1] = reserveNewColumn(capacity, mapType.valueType()); } else if (type instanceof CalendarIntervalType) { // Three columns. Months as int. Days as Int. Microseconds as Long. this.childColumns = new WritableColumnVector[3]; - this.childColumns[0] = reserveNewColumn(capacity, colName + ".months", DataTypes.IntegerType); - this.childColumns[1] = reserveNewColumn(capacity, colName + ".days", DataTypes.IntegerType); - this.childColumns[2] = reserveNewColumn(capacity, colName + ".microseconds", DataTypes.LongType); + this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); + this.childColumns[1] = reserveNewColumn(capacity, DataTypes.IntegerType); + this.childColumns[2] = reserveNewColumn(capacity, DataTypes.LongType); } else { this.childColumns = null; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 2dc680a0e5691..2b10e4efd9ab8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -131,10 +131,7 @@ object DataSourceUtils { }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) } - def newRebaseExceptionInRead( - colName: String, - dataType: DataType, - format: String): SparkUpgradeException = { + def newRebaseExceptionInRead(format: String): SparkUpgradeException = { val (config, option) = format match { case "Parquet INT96" => (SQLConf.PARQUET_INT96_REBASE_MODE_IN_READ.key, ParquetOptions.INT96_REBASE_MODE) @@ -144,7 +141,7 @@ object DataSourceUtils { (SQLConf.AVRO_REBASE_MODE_IN_READ.key, "datetimeRebaseMode") case _ => throw QueryExecutionErrors.unrecognizedFileFormatError(format) } - QueryExecutionErrors.sparkUpgradeInReadingDatesError(colName, dataType, format, config, option) + QueryExecutionErrors.sparkUpgradeInReadingDatesError(format, config, option) } def newRebaseExceptionInWrite(format: String): SparkUpgradeException = { @@ -159,10 +156,10 @@ object DataSourceUtils { def creteDateRebaseFuncInRead( rebaseMode: LegacyBehaviorPolicy.Value, - format: String)(colName: String, dataType: DataType): Int => Int = rebaseMode match { + format: String): Int => Int = rebaseMode match { case LegacyBehaviorPolicy.EXCEPTION => days: Int => if (days < RebaseDateTime.lastSwitchJulianDay) { - throw DataSourceUtils.newRebaseExceptionInRead(colName, dataType, format) + throw DataSourceUtils.newRebaseExceptionInRead(format) } days case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays @@ -183,10 +180,10 @@ object DataSourceUtils { def creteTimestampRebaseFuncInRead( rebaseMode: LegacyBehaviorPolicy.Value, - format: String)(colName: String, dataType: DataType): Long => Long = rebaseMode match { + format: String): Long => Long = rebaseMode match { case LegacyBehaviorPolicy.EXCEPTION => micros: Long => if (micros < RebaseDateTime.lastSwitchJulianTs) 
{ - throw DataSourceUtils.newRebaseExceptionInRead(colName, dataType, format) + throw DataSourceUtils.newRebaseExceptionInRead(format) } micros case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianMicros diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index d02e162c2b612..0a1cca7ed0f3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -189,13 +189,13 @@ private[parquet] class ParquetRowConverter( def currentRecord: InternalRow = currentRow private val dateRebaseFunc = DataSourceUtils.creteDateRebaseFuncInRead( - datetimeRebaseMode, "Parquet")(_, _) + datetimeRebaseMode, "Parquet") private val timestampRebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - datetimeRebaseMode, "Parquet")(_, _) + datetimeRebaseMode, "Parquet") private val int96RebaseFunc = DataSourceUtils.creteTimestampRebaseFuncInRead( - int96RebaseMode, "Parquet INT96")(_, _) + int96RebaseMode, "Parquet INT96") // Converters for each field. private[this] val fieldConverters: Array[Converter with HasParentContainerUpdater] = { @@ -332,7 +332,7 @@ private[parquet] class ParquetRowConverter( case TimestampType if parquetType.getOriginalType == OriginalType.TIMESTAMP_MICROS => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { - updater.setLong(timestampRebaseFunc(parquetType.getName, catalystType)(value)) + updater.setLong(timestampRebaseFunc(value)) } } @@ -340,7 +340,7 @@ private[parquet] class ParquetRowConverter( new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { val micros = DateTimeUtils.millisToMicros(value) - updater.setLong(timestampRebaseFunc(parquetType.getName, catalystType)(micros)) + updater.setLong(timestampRebaseFunc(micros)) } } @@ -350,7 +350,7 @@ private[parquet] class ParquetRowConverter( // Converts nanosecond timestamps stored as INT96 override def addBinary(value: Binary): Unit = { val julianMicros = ParquetRowConverter.binaryToSQLTimestamp(value) - val gregorianMicros = int96RebaseFunc(parquetType.getName, catalystType)(julianMicros) + val gregorianMicros = int96RebaseFunc(julianMicros) val adjTime = convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) .getOrElse(gregorianMicros) updater.setLong(adjTime) @@ -360,7 +360,7 @@ private[parquet] class ParquetRowConverter( case DateType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { - updater.set(dateRebaseFunc(parquetType.getName, catalystType)(value)) + updater.set(dateRebaseFunc(value)) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java index 3bdc669dbf311..2f10c84c999f9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaColumnarDataSourceV2.java @@ -77,8 +77,8 @@ public PartitionReader createReader(InputPartition partition) { @Override public PartitionReader createColumnarReader(InputPartition partition) { JavaRangeInputPartition p = (JavaRangeInputPartition) partition; - OnHeapColumnVector i = new 
OnHeapColumnVector(BATCH_SIZE, "", DataTypes.IntegerType); - OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, "", DataTypes.IntegerType); + OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + OnHeapColumnVector j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); ColumnVector[] vectors = new ColumnVector[2]; vectors[0] = i; vectors[1] = j; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 951ee81ac7de2..d4a6d84ce2b30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -675,7 +675,7 @@ class BrokenColumnarAdd( } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) { val l = lhs.asInstanceOf[ColumnVector] val r = rhs.asInstanceOf[ColumnVector] - val result = new OnHeapColumnVector(batch.numRows(), "", dataType) + val result = new OnHeapColumnVector(batch.numRows(), dataType) ret = result for (i <- 0 until batch.numRows()) { @@ -684,7 +684,7 @@ class BrokenColumnarAdd( } else if (rhs.isInstanceOf[ColumnVector]) { val l = lhs.asInstanceOf[Long] val r = rhs.asInstanceOf[ColumnVector] - val result = new OnHeapColumnVector(batch.numRows(), "", dataType) + val result = new OnHeapColumnVector(batch.numRows(), dataType) ret = result for (i <- 0 until batch.numRows()) { @@ -693,7 +693,7 @@ class BrokenColumnarAdd( } else if (lhs.isInstanceOf[ColumnVector]) { val l = lhs.asInstanceOf[ColumnVector] val r = rhs.asInstanceOf[Long] - val result = new OnHeapColumnVector(batch.numRows(), "", dataType) + val result = new OnHeapColumnVector(batch.numRows(), dataType) ret = result for (i <- 0 until batch.numRows()) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 562596c9fbf8e..49a1078800552 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -682,8 +682,8 @@ object ColumnarReaderFactory extends PartitionReaderFactory { override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { val RangeInputPartition(start, end) = partition new PartitionReader[ColumnarBatch] { - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, "", IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, "", IntegerType) + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) private lazy val batch = new ColumnarBatch(Array(i, j)) private var current = start diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index 790cda9bfa5b1..111a620df8c24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -105,7 +105,7 @@ class BooleanBitSetSuite extends SparkFunSuite { assertResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) - 
val columnVector = new OnHeapColumnVector(values.length, "", BooleanType) + val columnVector = new OnHeapColumnVector(values.length, BooleanType) decoder.decompress(columnVector, values.length) if (values.nonEmpty) { @@ -175,7 +175,7 @@ class BooleanBitSetSuite extends SparkFunSuite { assertResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) - val columnVector = new OnHeapColumnVector(numRows, "", BooleanType) + val columnVector = new OnHeapColumnVector(numRows, BooleanType) decoder.decompress(columnVector, numRows) (0 until numRows).foreach { rowNum => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index 3fc556557fa0e..61e4cc068fa80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -142,7 +142,7 @@ class DictionaryEncodingSuite extends SparkFunSuite { assertResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = DictionaryEncoding.decoder(buffer, columnType) - val columnVector = new OnHeapColumnVector(inputSeq.length, "", columnType.dataType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) decoder.decompress(columnVector, inputSeq.length) if (inputSeq.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index af1c5a34e35b0..b5630488b3667 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -136,7 +136,7 @@ class IntegralDeltaSuite extends SparkFunSuite { assertResult(scheme.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = scheme.decoder(buffer, columnType) - val columnVector = new OnHeapColumnVector(input.length, "", columnType.dataType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) decoder.decompress(columnVector, input.length) if (input.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala index 081d4dfaae4f3..c6fe64d1058ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala @@ -117,7 +117,7 @@ class PassThroughSuite extends SparkFunSuite { assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = PassThrough.decoder(buffer, columnType) - val columnVector = new OnHeapColumnVector(input.length, "", columnType.dataType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) decoder.decompress(columnVector, input.length) if (input.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index 110dc6681200b..29dbc13b59c6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -126,7 +126,7 @@ class RunLengthEncodingSuite extends SparkFunSuite { assertResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) val decoder = RunLengthEncoding.decoder(buffer, columnType) - val columnVector = new OnHeapColumnVector(inputSeq.length, "", columnType.dataType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) decoder.decompress(columnVector, inputSeq.length) if (inputSeq.nonEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 97be5e5e1221a..247efd5554a8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -38,8 +38,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { size: Int, dt: DataType)( block: WritableColumnVector => Unit): Unit = { - withVector(new OnHeapColumnVector(size, "", dt))(block) - withVector(new OffHeapColumnVector(size, "", dt))(block) + withVector(new OnHeapColumnVector(size, dt))(block) + withVector(new OffHeapColumnVector(size, dt))(block) } private def testVectors( @@ -259,7 +259,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { - withVector(new OffHeapColumnVector(8, "", arrayType)) { testVector => + withVector(new OffHeapColumnVector(8, arrayType)) { testVector => val data = testVector.arrayData() (0 until 8).foreach(i => data.putInt(i, i)) (0 until 8).foreach(i => testVector.putArray(i, i, 1)) @@ -275,7 +275,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") { - withVector(new OffHeapColumnVector(8, "", structType)) { testVector => + withVector(new OffHeapColumnVector(8, structType)) { testVector => (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) testVector.reserve(16) (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index eb9f70902add0..f9ae611691a7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -147,7 +147,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { // Access through the column API with on heap memory val columnOnHeap = { i: Int => - val col = new OnHeapColumnVector(count, "", IntegerType) + val col = new OnHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -166,7 +166,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { // Access through the column API with off heap memory def 
columnOffHeap = { i: Int => { - val col = new OffHeapColumnVector(count, "", IntegerType) + val col = new OffHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -185,7 +185,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { // Access by directly getting the buffer backing the column. val columnOffheapDirect = { i: Int => - val col = new OffHeapColumnVector(count, "", IntegerType) + val col = new OffHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var addr = col.valuesNativeAddress() @@ -251,7 +251,7 @@ object ColumnarBatchBenchmark extends BenchmarkBase { // Adding values by appending, instead of putting. val onHeapAppend = { i: Int => - val col = new OnHeapColumnVector(count, "", IntegerType) + val col = new OnHeapColumnVector(count, IntegerType) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -347,9 +347,9 @@ object ColumnarBatchBenchmark extends BenchmarkBase { def column(memoryMode: MemoryMode) = { i: Int => val column = if (memoryMode == MemoryMode.OFF_HEAP) { - new OffHeapColumnVector(count, "", BinaryType) + new OffHeapColumnVector(count, BinaryType) } else { - new OnHeapColumnVector(count, "", BinaryType) + new OnHeapColumnVector(count, BinaryType) } var sum = 0L @@ -378,8 +378,8 @@ object ColumnarBatchBenchmark extends BenchmarkBase { val random = new Random(0) val count = 4 * 1000 - val onHeapVector = new OnHeapColumnVector(count, "", ArrayType(IntegerType)) - val offHeapVector = new OffHeapColumnVector(count, "", ArrayType(IntegerType)) + val onHeapVector = new OnHeapColumnVector(count, ArrayType(IntegerType)) + val offHeapVector = new OffHeapColumnVector(count, ArrayType(IntegerType)) val minSize = 3 val maxSize = 32 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index a6b210277b6ca..bd69bab6f5da2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -44,9 +44,9 @@ class ColumnarBatchSuite extends SparkFunSuite { private def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { if (memMode == MemoryMode.OFF_HEAP) { - new OffHeapColumnVector(capacity, "", dt) + new OffHeapColumnVector(capacity, dt) } else { - new OnHeapColumnVector(capacity, "", dt) + new OnHeapColumnVector(capacity, dt) } } From 6aa05fca305139fd345be7d5b219728f349eddd5 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Mon, 19 Apr 2021 10:57:56 +0800 Subject: [PATCH 19/22] fix UT --- .../resources/sql-tests/inputs/transform.sql | 19 ++++++++------ .../sql-tests/results/transform.sql.out | 25 +++++++++++-------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index e3ae34ee89300..a279a0e4dcf09 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -263,14 +263,17 @@ FROM script_trans WHERE a <= 4 WINDOW w AS (PARTITION BY b ORDER BY a); -SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) - USING 'cat' AS (a, b, c, d, e) -FROM script_trans -LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol -LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 -WHERE a <= 4 -GROUP BY 
b, myCol, myCol2 -HAVING max(a) > 1; +SELECT a, b, c, CAST(d AS STriNG), e +FROM ( + SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) + USING 'cat' AS (a, b, c, d, e) + FROM script_trans + LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol + LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 + WHERE a <= 4 + GROUP BY b, myCol, myCol2 + HAVING max(a) > 1 +) tmp; FROM( FROM script_trans diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index c20ec4ca20715..3ada90e7de966 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -494,20 +494,23 @@ struct -- !query -SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) - USING 'cat' AS (a, b, c, d, e) -FROM script_trans -LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol -LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 -WHERE a <= 4 -GROUP BY b, myCol, myCol2 -HAVING max(a) > 1 +SELECT a, b, c, CAST(d AS STriNG), e +FROM ( + SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) + USING 'cat' AS (a, b, c, d, e) + FROM script_trans + LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol + LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 + WHERE a <= 4 + GROUP BY b, myCol, myCol2 + HAVING max(a) > 1 +) tmp -- !query schema struct -- !query output -5 4 6 [1, 2, 3] 1 -5 4 6 [1, 2, 3] 2 -5 4 6 [1, 2, 3] 3 +5 4 6 [1,2,3] 1 +5 4 6 [1,2,3] 2 +5 4 6 [1,2,3] 3 -- !query From 9e3f808cd64db622456249bad5bd4fe0a0c566cb Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Mon, 19 Apr 2021 11:14:53 +0800 Subject: [PATCH 20/22] Revert "fix UT" This reverts commit 6aa05fca305139fd345be7d5b219728f349eddd5. 
--- .../resources/sql-tests/inputs/transform.sql | 19 ++++++-------- .../sql-tests/results/transform.sql.out | 25 ++++++++----------- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index a279a0e4dcf09..e3ae34ee89300 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -263,17 +263,14 @@ FROM script_trans WHERE a <= 4 WINDOW w AS (PARTITION BY b ORDER BY a); -SELECT a, b, c, CAST(d AS STriNG), e -FROM ( - SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) - USING 'cat' AS (a, b, c, d, e) - FROM script_trans - LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol - LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 - WHERE a <= 4 - GROUP BY b, myCol, myCol2 - HAVING max(a) > 1 -) tmp; +SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) + USING 'cat' AS (a, b, c, d, e) +FROM script_trans +LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol +LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 +WHERE a <= 4 +GROUP BY b, myCol, myCol2 +HAVING max(a) > 1; FROM( FROM script_trans diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out b/sql/core/src/test/resources/sql-tests/results/transform.sql.out index 3ada90e7de966..c20ec4ca20715 100644 --- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out @@ -494,23 +494,20 @@ struct -- !query -SELECT a, b, c, CAST(d AS STriNG), e -FROM ( - SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) - USING 'cat' AS (a, b, c, d, e) - FROM script_trans - LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol - LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 - WHERE a <= 4 - GROUP BY b, myCol, myCol2 - HAVING max(a) > 1 -) tmp +SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) + USING 'cat' AS (a, b, c, d, e) +FROM script_trans +LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol +LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 +WHERE a <= 4 +GROUP BY b, myCol, myCol2 +HAVING max(a) > 1 -- !query schema struct -- !query output -5 4 6 [1,2,3] 1 -5 4 6 [1,2,3] 2 -5 4 6 [1,2,3] 3 +5 4 6 [1, 2, 3] 1 +5 4 6 [1, 2, 3] 2 +5 4 6 [1, 2, 3] 3 -- !query From 3f51d278a29f1a468a02a83325d2c71cc264ac02 Mon Sep 17 00:00:00 2001 From: Angerszhuuuu Date: Mon, 19 Apr 2021 11:19:35 +0800 Subject: [PATCH 21/22] fix UT --- .../src/test/resources/sql-tests/inputs/transform.sql | 2 +- .../test/resources/sql-tests/results/transform.sql.out | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/transform.sql b/sql/core/src/test/resources/sql-tests/inputs/transform.sql index e3ae34ee89300..7419ca1bd0a80 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/transform.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/transform.sql @@ -264,7 +264,7 @@ WHERE a <= 4 WINDOW w AS (PARTITION BY b ORDER BY a); SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2) - USING 'cat' AS (a, b, c, d, e) + USING 'cat' AS (a STRING, b STRING, c STRING, d ARRAY, e STRING) FROM script_trans LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2 diff --git a/sql/core/src/test/resources/sql-tests/results/transform.sql.out 
b/sql/core/src/test/resources/sql-tests/results/transform.sql.out
index c20ec4ca20715..1d7e9cdb430e0 100644
--- a/sql/core/src/test/resources/sql-tests/results/transform.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/transform.sql.out
@@ -495,7 +495,7 @@ struct
 -- !query
 SELECT TRANSFORM(b, MAX(a), CAST(SUM(c) AS STRING), myCol, myCol2)
-  USING 'cat' AS (a, b, c, d, e)
+  USING 'cat' AS (a STRING, b STRING, c STRING, d ARRAY<INT>, e STRING)
 FROM script_trans
 LATERAL VIEW explode(array(array(1,2,3))) myTable AS myCol
 LATERAL VIEW explode(myTable.myCol) myTable2 AS myCol2
@@ -503,11 +503,11 @@ WHERE a <= 4
 GROUP BY b, myCol, myCol2
 HAVING max(a) > 1
 -- !query schema
-struct<a:string,b:string,c:string,d:string,e:string>
+struct<a:string,b:string,c:string,d:array<int>,e:string>
 -- !query output
-5 4 6 [1, 2, 3] 1
-5 4 6 [1, 2, 3] 2
-5 4 6 [1, 2, 3] 3
+5 4 6 [1,2,3] 1
+5 4 6 [1,2,3] 2
+5 4 6 [1,2,3] 3


 -- !query

From adf8a6682d6c9b2ff759c456d26e0d4358c0d965 Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 19 Apr 2021 11:26:16 +0800
Subject: [PATCH 22/22] Update sql-migration-guide.md

---
 docs/sql-migration-guide.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 4f18ff3268443..e9cf49a7ca4bf 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -79,6 +79,8 @@ license: |

 - In Spark 3.2, `TRANSFORM` operator can't support alias in inputs. In Spark 3.1 and earlier, we can write script transform like `SELECT TRANSFORM(a AS c1, b AS c2) USING 'cat' FROM TBL`.

+ - In Spark 3.2, `TRANSFORM` operator can support `ArrayType/MapType/StructType` without Hive SerDe. In this mode, we use `StructsToJson` to convert `ArrayType/MapType/StructType` columns to `STRING` and `JsonToStructs` to parse `STRING` back into `ArrayType/MapType/StructType`. In Spark 3.1, Spark only supports casting `ArrayType/MapType/StructType` columns to `STRING`, but cannot parse `STRING` back into `ArrayType/MapType/StructType` output columns.
+
 ## Upgrading from Spark SQL 3.0 to 3.1

 - In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`.
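
A minimal sketch (not part of these patches) of the behavior the new migration-guide entry describes: script transform over a complex-typed column without Hive SerDe, with the declared complex output type parsed back from the script's output. It assumes a local SparkSession running a Spark build that contains SPARK-31937, that 'cat' is available on the executors, and that the round trip works as the note states; the object name TransformComplexTypeSketch is illustrative only.

import org.apache.spark.sql.SparkSession

// Sketch: TRANSFORM without Hive SerDe on an array column. The array input is
// serialized to a string for the script, and the declared ARRAY<BIGINT> output
// column is parsed back from the script's stdout, per the migration note above.
object TransformComplexTypeSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("transform-complex-types-sketch")
      .master("local[*]")
      .getOrCreate()

    spark.range(3).createOrReplaceTempView("t")

    val df = spark.sql(
      """SELECT TRANSFORM(id, array(id, id + 1))
        |  USING 'cat' AS (a BIGINT, b ARRAY<BIGINT>)
        |FROM t""".stripMargin)

    df.printSchema()            // b should come back as array<bigint>, not string
    df.show(truncate = false)

    spark.stop()
  }
}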