[SPARK-31937][SQL] Support processing array and map type using spark noserde mode
AngersZhuuuu committed Dec 29, 2020
1 parent b33fa53 commit adc9ded
Showing 4 changed files with 464 additions and 60 deletions.
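A minimal sketch of the query shape this commit enables (TRANSFORM without a serde over array and map columns), assuming a build with this patch applied and a `cat` script available on the executor hosts; the table and column names are illustrative:

import org.apache.spark.sql.SparkSession

object NoSerdeComplexTypesDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._
    Seq((1, Seq("a", "b"), Map("k" -> 1))).toDF("id", "arr", "m").createOrReplaceTempView("t")
    // Array and map columns previously failed here without a serde; they are now
    // serialized with the nested separators (\u0002 between items, \u0003 inside entries).
    spark.sql(
      """SELECT TRANSFORM(id, arr, m)
        |USING 'cat' AS (id STRING, arr STRING, m STRING)
        |FROM t""".stripMargin).show(truncate = false)
    spark.stop()
  }
}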
@@ -174,8 +174,9 @@ object CatalystTypeConverters {
convertedIterable += elementConverter.toCatalyst(item)
}
new GenericArrayData(convertedIterable.toArray)
case g: GenericArrayData => new GenericArrayData(g.array.map(elementConverter.toCatalyst))
case other => throw new IllegalArgumentException(
s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
s"AAAThe value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ s"cannot be converted to an array of ${elementType.catalogString}")
}
}
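As a rough illustration of the new `GenericArrayData` case: a value already wrapped in Catalyst's array container is now re-converted element by element instead of falling through to the `IllegalArgumentException` branch. A sketch, assuming it is compiled under the `org.apache.spark.sql` package since `CatalystTypeConverters` is `private[sql]`:

package org.apache.spark.sql

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types.{ArrayType, StringType}

object ArrayConverterDemo {
  def main(args: Array[String]): Unit = {
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(ArrayType(StringType))
    // Plain Scala collections were already handled by the existing cases.
    val fromSeq = toCatalyst(Seq("a", "b"))
    // New: a GenericArrayData holding external values ("a" as a JVM String) is
    // rebuilt with each element passed through the element converter.
    val fromArrayData = toCatalyst(new GenericArrayData(Array[Any]("a", "b")))
    println(s"$fromSeq / $fromArrayData")
  }
}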
@@ -213,6 +214,9 @@ object CatalystTypeConverters {
scalaValue match {
case map: Map[_, _] => ArrayBasedMapData(map, keyFunction, valueFunction)
case javaMap: JavaMap[_, _] => ArrayBasedMapData(javaMap, keyFunction, valueFunction)
case map: ArrayBasedMapData =>
ArrayBasedMapData(map.keyArray.array.zip(map.valueArray.array).toMap,
keyFunction, valueFunction)
case other => throw new IllegalArgumentException(
s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ "cannot be converted to a map type with "
@@ -263,6 +267,15 @@ object CatalystTypeConverters {
idx += 1
}
new GenericInternalRow(ar)
case g: GenericInternalRow =>
val ar = new Array[Any](structType.size)
val values = g.values
var idx = 0
while (idx < structType.size) {
ar(idx) = converters(idx).toCatalyst(values(idx))
idx += 1
}
new GenericInternalRow(ar)
case other => throw new IllegalArgumentException(
s"The value (${other.toString}) of the type (${other.getClass.getCanonicalName}) "
+ s"cannot be converted to ${structType.catalogString}")
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
import java.nio.charset.StandardCharsets
import java.util.Map.Entry
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
@@ -33,10 +34,8 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}

trait BaseScriptTransformationExec extends UnaryExecNode {
@@ -47,7 +46,12 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
def ioschema: ScriptTransformationIOSchema

protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
input.map { in: Expression =>
in.dataType match {
case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in
case _ => Cast(in, StringType).withTimeZone(conf.sessionLocalTimeZone)
}
}
}
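Net effect of this hunk: primitive columns are still pre-cast to strings before being piped to the script, while array, map, and struct columns pass through unchanged so the writer thread below can serialize them with the nested separators. A compact restatement with hypothetical attributes (names are illustrative; the time-zone-aware cast is elided):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast}
import org.apache.spark.sql.types._

object WithoutSerdeInputSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical attributes standing in for the node's input expressions.
    val id  = AttributeReference("id", IntegerType)()
    val arr = AttributeReference("arr", ArrayType(StringType))()
    val exprs = Seq(id, arr).map { in =>
      in.dataType match {
        case ArrayType(_, _) | MapType(_, _, _) | StructType(_) => in // passed through
        case _ => Cast(in, StringType)                                // still stringified
      }
    }
    exprs.foreach(e => println(s"${e.dataType.simpleString} -> $e"))
  }
}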

override def producedAttributes: AttributeSet = outputSet -- inputSet
@@ -182,58 +186,8 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
}

private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
attr.dataType match {
case StringType => wrapperConvertException(data => data, converter)
case BooleanType => wrapperConvertException(data => data.toBoolean, converter)
case ByteType => wrapperConvertException(data => data.toByte, converter)
case BinaryType =>
wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter)
case IntegerType => wrapperConvertException(data => data.toInt, converter)
case ShortType => wrapperConvertException(data => data.toShort, converter)
case LongType => wrapperConvertException(data => data.toLong, converter)
case FloatType => wrapperConvertException(data => data.toFloat, converter)
case DoubleType => wrapperConvertException(data => data.toDouble, converter)
case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter)
case DateType if conf.datetimeJava8ApiEnabled =>
wrapperConvertException(data => DateTimeUtils.stringToDate(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.daysToLocalDate).orNull, converter)
case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.toJavaDate).orNull, converter)
case TimestampType if conf.datetimeJava8ApiEnabled =>
wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.microsToInstant).orNull, converter)
case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.toJavaTimestamp).orNull, converter)
case CalendarIntervalType => wrapperConvertException(
data => IntervalUtils.stringToInterval(UTF8String.fromString(data)),
converter)
case udt: UserDefinedType[_] =>
wrapperConvertException(data => udt.deserialize(data), converter)
case dt =>
throw new SparkException(s"${nodeName} without serde does not support " +
s"${dt.getClass.getSimpleName} as output data type")
}
SparkInspectors.unwrapper(attr.dataType, conf, ioschema, 1)
}

// Keep consistent with Hive `LazySimpleSerde`, when there is a type case error, return null
private val wrapperConvertException: (String => Any, Any => Any) => String => Any =
(f: String => Any, converter: Any => Any) =>
(data: String) => converter {
try {
f(data)
} catch {
case NonFatal(_) => null
}
}
}

abstract class BaseScriptTransformationWriterThread extends Thread with Logging {
@@ -256,18 +210,23 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging

protected def processRows(): Unit

val wrappers = inputSchema.map(dt => SparkInspectors.wrapper(dt))

protected def processRowsWithoutSerde(): Unit = {
val len = inputSchema.length
iter.foreach { row =>
val values = row.asInstanceOf[GenericInternalRow].values.zip(wrappers).map {
case (value, wrapper) => wrapper(value)
}
val data = if (len == 0) {
ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")
} else {
val sb = new StringBuilder
sb.append(row.get(0, inputSchema(0)))
buildString(sb, values(0), inputSchema(0), 1)
var i = 1
while (i < len) {
sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
sb.append(row.get(i, inputSchema(i)))
buildString(sb, values(i), inputSchema(i), 1)
i += 1
}
sb.append(ioSchema.inputRowFormatMap("TOK_TABLEROWFORMATLINES"))
@@ -277,6 +236,50 @@ abstract class BaseScriptTransformationWriterThread extends Thread with Logging
}
}

/**
* Convert data to string according to the data type.
*
* @param sb The StringBuilder to store the serialized data.
* @param obj The object for the current field.
* @param dataType The DataType of the current object.
* @param level The current separator nesting level.
*/
private def buildString(sb: StringBuilder, obj: Any, dataType: DataType, level: Int): Unit = {
(obj, dataType) match {
case (list: java.util.List[_], ArrayType(typ, _)) =>
val separator = ioSchema.getSeparator(level)
(0 until list.size).foreach { i =>
if (i > 0) {
sb.append(separator)
}
buildString(sb, list.get(i), typ, level + 1)
}
case (map: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
val separator = ioSchema.getSeparator(level)
val keyValueSeparator = ioSchema.getSeparator(level + 1)
val entries = map.entrySet().toArray()
(0 until entries.size).foreach { i =>
if (i > 0) {
sb.append(separator)
}
val entry = entries(i).asInstanceOf[Entry[_, _]]
buildString(sb, entry.getKey, keyType, level + 2)
sb.append(keyValueSeparator)
buildString(sb, entry.getValue, valueType, level + 2)
}
case (arrayList: java.util.ArrayList[_], StructType(fields)) =>
val separator = ioSchema.getSeparator(level)
(0 until arrayList.size).foreach { i =>
if (i > 0) {
sb.append(separator)
}
buildString(sb, arrayList.get(i), fields(i).dataType, level + 1)
}
case (other, _) =>
sb.append(other)
}
}
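To make the separator scheme concrete, a standalone simplification (not the method above; the assumed default delimiters are hard-coded) showing how one array field and one map field are rendered on the wire:

object SeparatorDemo {
  def main(args: Array[String]): Unit = {
    // Level-0/1/2 delimiters: field, collection items, map keys (assumed defaults).
    val sep = Array('\t', '\u0002', '\u0003')
    val arrField = Seq("x", "y").mkString(sep(1).toString)   // "x\u0002y"
    val mapField = Map("a" -> 1, "b" -> 2)
      .map { case (k, v) => s"$k${sep(2)}$v" }
      .mkString(sep(1).toString)                             // "a\u00031\u0002b\u00032"
    // Top-level fields are joined by TOK_TABLEROWFORMATFIELD, i.e. '\t'.
    println(Seq(arrField, mapField).mkString(sep(0).toString))
  }
}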

override def run(): Unit = Utils.logUncaughtExceptions {
TaskContext.setTaskContext(taskContext)

@@ -329,14 +332,45 @@ case class ScriptTransformationIOSchema(
schemaLess: Boolean) extends Serializable {
import ScriptTransformationIOSchema._

val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val inputRowFormatMap = inputRowFormat.toMap.withDefault(k => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault(k => defaultFormat(k))

val separators = (getByte(inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 0.toByte) ::
getByte(inputRowFormatMap("TOK_TABLEROWFORMATCOLLITEMS"), 1.toByte) ::
getByte(inputRowFormatMap("TOK_TABLEROWFORMATMAPKEYS"), 2.toByte) :: Nil) ++
(4 to 8).map(_.toByte)

def getByte(altValue: String, defaultVal: Byte): Byte = {
if (altValue != null && altValue.length > 0) {
try {
java.lang.Byte.parseByte(altValue)
} catch {
case _: NumberFormatException =>
altValue.charAt(0).toByte
}
} else {
defaultVal
}
}

def getSeparator(level: Int): Char = {
try {
separators(level).toChar
} catch {
case _: IndexOutOfBoundsException =>
val msg = "Number of levels of nesting supported for Spark SQL script transform" +
" is " + (separators.length - 1) + " Unable to work with level " + level
throw new RuntimeException(msg)
}
}
}
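With the defaults above, `separators` works out to the bytes 9, 2, 3, 4, 5, 6, 7, 8 ('\t', '\u0002', '\u0003', then the literal bytes 4 through 8), so `getSeparator` supports levels 0 through 7 before raising. A standalone re-check of the `getByte` parsing rule (a simplified re-implementation, not the class itself):

object GetByteDemo {
  // Same rule as getByte above: a numeric string parses as a byte; otherwise the
  // first character's code point is used; empty or null falls back to the default.
  def byteFor(alt: String, dflt: Byte): Byte =
    if (alt != null && alt.nonEmpty) {
      try java.lang.Byte.parseByte(alt)
      catch { case _: NumberFormatException => alt.charAt(0).toByte }
    } else dflt

  def main(args: Array[String]): Unit = {
    println(byteFor("\t", 0)) // 9  ("\t" is not numeric, so its char code is used)
    println(byteFor("2", 0))  // 2  (numeric string parses directly)
    println(byteFor("", 5))   // 5  (default)
  }
}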

object ScriptTransformationIOSchema {
val defaultFormat = Map(
("TOK_TABLEROWFORMATFIELD", "\t"),
("TOK_TABLEROWFORMATLINES", "\n")
("TOK_TABLEROWFORMATLINES", "\n"),
("TOK_TABLEROWFORMATCOLLITEMS", "\u0002"),
("TOK_TABLEROWFORMATMAPKEYS", "\u0003")
)
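The two added defaults match Hive LazySimpleSerDe's customary secondary delimiters, Ctrl-B for collection items and Ctrl-C for map keys; a quick code-point check:

object DelimCodePoints {
  def main(args: Array[String]): Unit = {
    println("\u0002".charAt(0).toInt) // 2 (^B): collection-item delimiter
    println("\u0003".charAt(0).toInt) // 3 (^C): map-key delimiter
  }
}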

val defaultIOSchema = ScriptTransformationIOSchema(
Expand Down