diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index 22bf6df58b040..9fb12c614e78c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution -import java.io.OutputStream +import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream} import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration @@ -28,14 +29,26 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} trait BaseScriptTransformationExec extends UnaryExecNode { + def input: Seq[Expression] + def script: String + def output: Seq[Attribute] + def child: SparkPlan + def ioschema: ScriptTransformationIOSchema + + protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = { + input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)) + } override def producedAttributes: AttributeSet = outputSet -- inputSet @@ -56,10 +69,91 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } - def processIterator( + protected def initProc: (OutputStream, Process, InputStream, CircularBuffer) = { + val cmd = List("/bin/bash", "-c", script) + val builder = new ProcessBuilder(cmd.asJava) + + val proc = builder.start() + val inputStream = proc.getInputStream + val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream + + // In order to avoid deadlocks, we need to consume the error output of the child process. + // To avoid issues caused by large error output, we use a circular buffer to limit the amount + // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang + // that motivates this. 
+ val stderrBuffer = new CircularBuffer(2048) + new RedirectThread( + errorStream, + stderrBuffer, + s"Thread-${this.getClass.getSimpleName}-STDERR-Consumer").start() + (outputStream, proc, inputStream, stderrBuffer) + } + + protected def processIterator( inputIterator: Iterator[InternalRow], hadoopConf: Configuration): Iterator[InternalRow] + protected def createOutputIteratorWithoutSerde( + writerThread: BaseScriptTransformationWriterThread, + inputStream: InputStream, + proc: Process, + stderrBuffer: CircularBuffer): Iterator[InternalRow] = { + new Iterator[InternalRow] { + var curLine: String = null + val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + + val outputRowFormat = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD") + val processRowWithoutSerde = if (!ioschema.schemaLess) { + prevLine: String => + new GenericInternalRow( + prevLine.split(outputRowFormat) + .zip(outputFieldWriters) + .map { case (data, writer) => writer(data) }) + } else { + // In schema-less mode, Hive's default serde picks the first two output columns as the output; + // if there are fewer than two output columns, it throws an ArrayIndexOutOfBoundsException. + // Here we make Spark's behavior match Hive's default serde. + // But in Hive, TRANSFORM without a serde in schema-less mode behaves like the original Spark; + // we will fix this in SPARK-32388 to keep Spark's and Hive's behavior consistent. + val kvWriter = CatalystTypeConverters.createToCatalystConverter(StringType) + prevLine: String => + new GenericInternalRow( + prevLine.split(outputRowFormat).slice(0, 2) + .map(kvWriter)) + } + + override def hasNext: Boolean = { + try { + if (curLine == null) { + curLine = reader.readLine() + if (curLine == null) { + checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) + return false + } + } + true + } catch { + case NonFatal(e) => + // If this exception is due to abrupt / unclean termination of `proc`, + // then detect it and propagate a better exception message for end users + checkFailureAndPropagate(writerThread, e, proc, stderrBuffer) + + throw e + } + } + + override def next(): InternalRow = { + if (!hasNext) { + throw new NoSuchElementException + } + val prevLine = curLine + curLine = reader.readLine() + processRowWithoutSerde(prevLine) + } + } + } + + protected def checkFailureAndPropagate( writerThread: BaseScriptTransformationWriterThread, cause: Throwable = null, @@ -87,17 +181,72 @@ trait BaseScriptTransformationExec extends UnaryExecNode { } } } + + private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr => + val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType) + attr.dataType match { + case StringType => wrapperConvertException(data => data, converter) + case BooleanType => wrapperConvertException(data => data.toBoolean, converter) + case ByteType => wrapperConvertException(data => data.toByte, converter) + case BinaryType => + wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter) + case IntegerType => wrapperConvertException(data => data.toInt, converter) + case ShortType => wrapperConvertException(data => data.toShort, converter) + case LongType => wrapperConvertException(data => data.toLong, converter) + case FloatType => wrapperConvertException(data => data.toFloat, converter) + case DoubleType => wrapperConvertException(data => data.toDouble, converter) + case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter) + case DateType if conf.datetimeJava8ApiEnabled => + 
wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.daysToLocalDate).orNull, converter) + case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaDate).orNull, converter) + case TimestampType if conf.datetimeJava8ApiEnabled => + wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.microsToInstant).orNull, converter) + case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp( + UTF8String.fromString(data), + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) + .map(DateTimeUtils.toJavaTimestamp).orNull, converter) + case CalendarIntervalType => wrapperConvertException( + data => IntervalUtils.stringToInterval(UTF8String.fromString(data)), + converter) + case udt: UserDefinedType[_] => + wrapperConvertException(data => udt.deserialize(data), converter) + case dt => + throw new SparkException(s"${nodeName} without serde does not support " + + s"${dt.getClass.getSimpleName} as output data type") + } + } + + // Keep consistent with Hive's `LazySimpleSerDe`: when there is a type cast error, return null + private val wrapperConvertException: (String => Any, Any => Any) => String => Any = + (f: String => Any, converter: Any => Any) => + (data: String) => converter { + try { + f(data) + } catch { + case NonFatal(_) => null + } + } } -abstract class BaseScriptTransformationWriterThread( - iter: Iterator[InternalRow], - inputSchema: Seq[DataType], - ioSchema: BaseScriptTransformIOSchema, - outputStream: OutputStream, - proc: Process, - stderrBuffer: CircularBuffer, - taskContext: TaskContext, - conf: Configuration) extends Thread with Logging { +abstract class BaseScriptTransformationWriterThread extends Thread with Logging { + + def iter: Iterator[InternalRow] + def inputSchema: Seq[DataType] + def ioSchema: ScriptTransformationIOSchema + def outputStream: OutputStream + def proc: Process + def stderrBuffer: CircularBuffer + def taskContext: TaskContext + def conf: Configuration setDaemon(true) @@ -169,34 +318,50 @@ abstract class BaseScriptTransformationWriterThread( /** * The wrapper class of input and output schema properties */ -abstract class BaseScriptTransformIOSchema extends Serializable { - import ScriptIOSchema._ - - def inputRowFormat: Seq[(String, String)] - - def outputRowFormat: Seq[(String, String)] - - def inputSerdeClass: Option[String] - - def outputSerdeClass: Option[String] - - def inputSerdeProps: Seq[(String, String)] - - def outputSerdeProps: Seq[(String, String)] - - def recordReaderClass: Option[String] - - def recordWriterClass: Option[String] - - def schemaLess: Boolean +case class ScriptTransformationIOSchema( + inputRowFormat: Seq[(String, String)], + outputRowFormat: Seq[(String, String)], + inputSerdeClass: Option[String], + outputSerdeClass: Option[String], + inputSerdeProps: Seq[(String, String)], + outputSerdeProps: Seq[(String, String)], + recordReaderClass: Option[String], + recordWriterClass: Option[String], + schemaLess: Boolean) extends Serializable { + import ScriptTransformationIOSchema._ val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k)) val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k)) } -object 
ScriptIOSchema { +object ScriptTransformationIOSchema { val defaultFormat = Map( ("TOK_TABLEROWFORMATFIELD", "\t"), ("TOK_TABLEROWFORMATLINES", "\n") ) + + val defaultIOSchema = ScriptTransformationIOSchema( + inputRowFormat = Seq.empty, + outputRowFormat = Seq.empty, + inputSerdeClass = None, + outputSerdeClass = None, + inputSerdeProps = Seq.empty, + outputSerdeProps = Seq.empty, + recordReaderClass = None, + recordWriterClass = None, + schemaLess = false + ) + + def apply(input: ScriptInputOutputSchema): ScriptTransformationIOSchema = { + ScriptTransformationIOSchema( + input.inputRowFormat, + input.outputRowFormat, + input.inputSerdeClass, + input.outputSerdeClass, + input.inputSerdeProps, + input.outputSerdeProps, + input.recordReaderClass, + input.recordWriterClass, + input.schemaLess) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 19aa5935a09d7..a6826a6546ae5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -1046,8 +1046,8 @@ private[hive] trait HiveInspectors { getListTypeInfo(elemType.toTypeInfo) case StructType(fields) => getStructTypeInfo( - java.util.Arrays.asList(fields.map(_.name) : _*), - java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) + java.util.Arrays.asList(fields.map(_.name): _*), + java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo): _*)) case MapType(keyType, valueType, _) => getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo) case BinaryType => binaryTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index dae68df08f32e..97e1dee5913a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils} import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.hive.execution.{HiveScriptIOSchema, HiveScriptTransformationExec} +import org.apache.spark.sql.hive.execution.HiveScriptTransformationExec import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -244,7 +244,7 @@ private[hive] trait HiveStrategies { object HiveScripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ScriptTransformation(input, script, output, child, ioschema) => - val hiveIoSchema = HiveScriptIOSchema(ioschema) + val hiveIoSchema = ScriptTransformationIOSchema(ioschema) HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil case _ => Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala similarity index 58% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala index 96fe646d39fde..4096916a100c3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformationExec.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.hive.execution import java.io._ -import java.nio.charset.StandardCharsets import java.util.Properties -import javax.annotation.Nullable import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -33,14 +31,13 @@ import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable import org.apache.spark.TaskContext -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types.DataType -import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} +import org.apache.spark.util.{CircularBuffer, Utils} /** * Transforms the input by forking and running the specified script. @@ -54,71 +51,27 @@ case class HiveScriptTransformationExec( script: String, output: Seq[Attribute], child: SparkPlan, - ioschema: HiveScriptIOSchema) + ioschema: ScriptTransformationIOSchema) extends BaseScriptTransformationExec { + import HiveScriptIOSchema._ - override def processIterator( - inputIterator: Iterator[InternalRow], + private def createOutputIteratorWithSerde( + writerThread: BaseScriptTransformationWriterThread, + inputStream: InputStream, + proc: Process, + stderrBuffer: CircularBuffer, + outputSerde: AbstractSerDe, + outputSoi: StructObjectInspector, hadoopConf: Configuration): Iterator[InternalRow] = { - val cmd = List("/bin/bash", "-c", script) - val builder = new ProcessBuilder(cmd.asJava) - - val proc = builder.start() - val inputStream = proc.getInputStream - val outputStream = proc.getOutputStream - val errorStream = proc.getErrorStream - - // In order to avoid deadlocks, we need to consume the error output of the child process. - // To avoid issues caused by large error output, we use a circular buffer to limit the amount - // of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang - // that motivates this. - val stderrBuffer = new CircularBuffer(2048) - new RedirectThread( - errorStream, - stderrBuffer, - "Thread-ScriptTransformation-STDERR-Consumer").start() - - val outputProjection = new InterpretedProjection(input, child.output) - - // This nullability is a performance optimization in order to avoid an Option.foreach() call - // inside of a loop - @Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null)) - - // This new thread will consume the ScriptTransformation's input rows and write them to the - // external process. That process's output will be read by this current thread. 
- val writerThread = new HiveScriptTransformationWriterThread( - inputIterator.map(outputProjection), - input.map(_.dataType), - inputSerde, - inputSoi, - ioschema, - outputStream, - proc, - stderrBuffer, - TaskContext.get(), - hadoopConf - ) - - // This nullability is a performance optimization in order to avoid an Option.foreach() call - // inside of a loop - @Nullable val (outputSerde, outputSoi) = { - ioschema.initOutputSerDe(output).getOrElse((null, null)) - } - - val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) - val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors { + new Iterator[InternalRow] with HiveInspectors { var curLine: String = null val scriptOutputStream = new DataInputStream(inputStream) - @Nullable val scriptOutputReader = - ioschema.recordReader(scriptOutputStream, hadoopConf).orNull + val scriptOutputReader = + recordReader(ioschema, scriptOutputStream, hadoopConf).orNull var scriptOutputWritable: Writable = null - val reusedWritableObject: Writable = if (null != outputSerde) { - outputSerde.getSerializedClass().getConstructor().newInstance() - } else { - null - } + val reusedWritableObject = outputSerde.getSerializedClass.getConstructor().newInstance() val mutableRow = new SpecificInternalRow(output.map(_.dataType)) @transient @@ -126,15 +79,7 @@ case class HiveScriptTransformationExec( override def hasNext: Boolean = { try { - if (outputSerde == null) { - if (curLine == null) { - curLine = reader.readLine() - if (curLine == null) { - checkFailureAndPropagate(writerThread, null, proc, stderrBuffer) - return false - } - } - } else if (scriptOutputWritable == null) { + if (scriptOutputWritable == null) { scriptOutputWritable = reusedWritableObject if (scriptOutputReader != null) { @@ -172,35 +117,66 @@ case class HiveScriptTransformationExec( if (!hasNext) { throw new NoSuchElementException } - if (outputSerde == null) { - val prevLine = curLine - curLine = reader.readLine() - if (!ioschema.schemaLess) { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")) - .map(CatalystTypeConverters.convertToCatalyst)) + val raw = outputSerde.deserialize(scriptOutputWritable) + scriptOutputWritable = null + val dataList = outputSoi.getStructFieldsDataAsList(raw) + var i = 0 + while (i < dataList.size()) { + if (dataList.get(i) == null) { + mutableRow.setNullAt(i) } else { - new GenericInternalRow( - prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2) - .map(CatalystTypeConverters.convertToCatalyst)) - } - } else { - val raw = outputSerde.deserialize(scriptOutputWritable) - scriptOutputWritable = null - val dataList = outputSoi.getStructFieldsDataAsList(raw) - var i = 0 - while (i < dataList.size()) { - if (dataList.get(i) == null) { - mutableRow.setNullAt(i) - } else { - unwrappers(i)(dataList.get(i), mutableRow, i) - } - i += 1 + unwrappers(i)(dataList.get(i), mutableRow, i) } - mutableRow + i += 1 } + mutableRow } } + } + + override def processIterator( + inputIterator: Iterator[InternalRow], + hadoopConf: Configuration): Iterator[InternalRow] = { + + val (outputStream, proc, inputStream, stderrBuffer) = initProc + + val (inputSerde, inputSoi) = initInputSerDe(ioschema, input).getOrElse((null, null)) + + // For HiveScriptTransformationExec, if inputSerde == null but outputSerde != null, + // we will use a StringBuffer to pass data, so in that case we should cast the input data to string too. 
+ val finalInput = if (inputSerde == null) { + inputExpressionsWithoutSerde + } else { + input + } + + val outputProjection = new InterpretedProjection(finalInput, child.output) + + // This new thread will consume the ScriptTransformation's input rows and write them to the + // external process. That process's output will be read by this current thread. + val writerThread = HiveScriptTransformationWriterThread( + inputIterator.map(outputProjection), + finalInput.map(_.dataType), + inputSerde, + inputSoi, + ioschema, + outputStream, + proc, + stderrBuffer, + TaskContext.get(), + hadoopConf + ) + + val (outputSerde, outputSoi) = { + initOutputSerDe(ioschema, output).getOrElse((null, null)) + } + + val outputIterator = if (outputSerde == null) { + createOutputIteratorWithoutSerde(writerThread, inputStream, proc, stderrBuffer) + } else { + createOutputIteratorWithSerde( + writerThread, inputStream, proc, stderrBuffer, outputSerde, outputSoi, hadoopConf) + } writerThread.start() @@ -208,30 +184,23 @@ case class HiveScriptTransformationExec( } } -private class HiveScriptTransformationWriterThread( +case class HiveScriptTransformationWriterThread( iter: Iterator[InternalRow], inputSchema: Seq[DataType], - @Nullable inputSerde: AbstractSerDe, - @Nullable inputSoi: StructObjectInspector, - ioSchema: HiveScriptIOSchema, + inputSerde: AbstractSerDe, + inputSoi: StructObjectInspector, + ioSchema: ScriptTransformationIOSchema, outputStream: OutputStream, proc: Process, stderrBuffer: CircularBuffer, taskContext: TaskContext, conf: Configuration) - extends BaseScriptTransformationWriterThread( - iter, - inputSchema, - ioSchema, - outputStream, - proc, - stderrBuffer, - taskContext, - conf) with HiveInspectors { + extends BaseScriptTransformationWriterThread with HiveInspectors { + import HiveScriptIOSchema._ override def processRows(): Unit = { val dataOutputStream = new DataOutputStream(outputStream) - @Nullable val scriptInputWriter = ioSchema.recordWriter(dataOutputStream, conf).orNull + val scriptInputWriter = recordWriter(ioSchema, dataOutputStream, conf).orNull if (inputSerde == null) { processRowsWithoutSerde() @@ -259,40 +228,14 @@ private class HiveScriptTransformationWriterThread( } } -object HiveScriptIOSchema { - def apply(input: ScriptInputOutputSchema): HiveScriptIOSchema = { - HiveScriptIOSchema( - input.inputRowFormat, - input.outputRowFormat, - input.inputSerdeClass, - input.outputSerdeClass, - input.inputSerdeProps, - input.outputSerdeProps, - input.recordReaderClass, - input.recordWriterClass, - input.schemaLess) - } -} +object HiveScriptIOSchema extends HiveInspectors { -/** - * The wrapper class of Hive input and output schema properties - */ -case class HiveScriptIOSchema ( - inputRowFormat: Seq[(String, String)], - outputRowFormat: Seq[(String, String)], - inputSerdeClass: Option[String], - outputSerdeClass: Option[String], - inputSerdeProps: Seq[(String, String)], - outputSerdeProps: Seq[(String, String)], - recordReaderClass: Option[String], - recordWriterClass: Option[String], - schemaLess: Boolean) - extends BaseScriptTransformIOSchema with HiveInspectors { - - def initInputSerDe(input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { - inputSerdeClass.map { serdeClass => + def initInputSerDe( + ioschema: ScriptTransformationIOSchema, + input: Seq[Expression]): Option[(AbstractSerDe, StructObjectInspector)] = { + ioschema.inputSerdeClass.map { serdeClass => val (columns, columnTypes) = parseAttrs(input) - val serde = initSerDe(serdeClass, columns, 
columnTypes, inputSerdeProps) + val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.inputSerdeProps) val fieldObjectInspectors = columnTypes.map(toInspector) val objectInspector = ObjectInspectorFactory .getStandardStructObjectInspector(columns.asJava, fieldObjectInspectors.asJava) @@ -300,10 +243,12 @@ case class HiveScriptIOSchema ( } } - def initOutputSerDe(output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { - outputSerdeClass.map { serdeClass => + def initOutputSerDe( + ioschema: ScriptTransformationIOSchema, + output: Seq[Attribute]): Option[(AbstractSerDe, StructObjectInspector)] = { + ioschema.outputSerdeClass.map { serdeClass => val (columns, columnTypes) = parseAttrs(output) - val serde = initSerDe(serdeClass, columns, columnTypes, outputSerdeProps) + val serde = initSerDe(serdeClass, columns, columnTypes, ioschema.outputSerdeProps) val structObjectInspector = serde.getObjectInspector().asInstanceOf[StructObjectInspector] (serde, structObjectInspector) } @@ -315,7 +260,7 @@ case class HiveScriptIOSchema ( (columns, columnTypes) } - private def initSerDe( + def initSerDe( serdeClassName: String, columns: Seq[String], columnTypes: Seq[DataType], @@ -339,22 +284,26 @@ case class HiveScriptIOSchema ( } def recordReader( + ioschema: ScriptTransformationIOSchema, inputStream: InputStream, conf: Configuration): Option[RecordReader] = { - recordReaderClass.map { klass => + ioschema.recordReaderClass.map { klass => val instance = Utils.classForName[RecordReader](klass).getConstructor(). newInstance() val props = new Properties() // Can not use props.putAll(outputSerdeProps.toMap.asJava) in scala-2.12 // See https://github.com/scala/bug/issues/10418 - outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } + ioschema.outputSerdeProps.toMap.foreach { case (k, v) => props.put(k, v) } instance.initialize(inputStream, conf, props) instance } } - def recordWriter(outputStream: OutputStream, conf: Configuration): Option[RecordWriter] = { - recordWriterClass.map { klass => + def recordWriter( + ioschema: ScriptTransformationIOSchema, + outputStream: OutputStream, + conf: Configuration): Option[RecordWriter] = { + ioschema.recordWriterClass.map { klass => val instance = Utils.classForName[RecordWriter](klass).getConstructor(). 
newInstance() instance.initialize(outputStream, conf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 35252fc47f49f..38bc8d0429135 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} +import org.apache.spark.sql.execution.{ScriptTransformationIOSchema, SparkPlan, SparkPlanTest, UnaryExecNode} import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -39,20 +39,9 @@ import org.apache.spark.sql.types.StringType class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { import spark.implicits._ + import ScriptTransformationIOSchema._ - private val noSerdeIOSchema = HiveScriptIOSchema( - inputRowFormat = Seq.empty, - outputRowFormat = Seq.empty, - inputSerdeClass = None, - outputSerdeClass = None, - inputSerdeProps = Seq.empty, - outputSerdeProps = Seq.empty, - recordReaderClass = None, - recordWriterClass = None, - schemaLess = false - ) - - private val serdeIOSchema = noSerdeIOSchema.copy( + private val serdeIOSchema = defaultIOSchema.copy( inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName), outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName) ) @@ -88,7 +77,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with script = "cat", output = Seq(AttributeReference("a", StringType)()), child = child, - ioschema = noSerdeIOSchema + ioschema = defaultIOSchema ), rowsDf.collect()) assert(uncaughtExceptionHandler.exception.isEmpty) @@ -123,7 +112,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with script = "cat", output = Seq(AttributeReference("a", StringType)()), child = ExceptionInjectingOperator(child), - ioschema = noSerdeIOSchema + ioschema = defaultIOSchema ), rowsDf.collect()) } @@ -239,7 +228,7 @@ class HiveScriptTransformationSuite extends SparkPlanTest with SQLTestUtils with script = "some_non_existent_command", output = Seq(AttributeReference("a", StringType)()), child = rowsDf.queryExecution.sparkPlan, - ioschema = noSerdeIOSchema) + ioschema = defaultIOSchema) SparkPlanTest.executePlan(plan, hiveContext) } assert(e.getMessage.contains("Subprocess exited with status"))
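
Note on the no-serde output path added in BaseScriptTransformationExec above: each line emitted by the script is split on the TOK_TABLEROWFORMATFIELD delimiter (tab by default), every field is parsed by a writer derived from the output attribute's data type, and a failed conversion yields null, consistent with Hive's LazySimpleSerDe. A minimal standalone sketch of that wrapping pattern follows; it is illustrative only and not part of the patch, and the object and method names (NoSerdeRowSketch, nullOnError, fieldWriters) are made up for the example.

    import scala.util.control.NonFatal

    object NoSerdeRowSketch {
      // Wrap a per-column parser so that a failed cast yields null instead of
      // throwing, mirroring wrapperConvertException in the patch.
      private def nullOnError[T](parse: String => T): String => Any =
        (data: String) => try parse(data) catch { case NonFatal(_) => null }

      // One writer per output column, analogous to outputFieldWriters
      // (here: a string, an int, and a double column).
      private val fieldWriters: Seq[String => Any] = Seq(
        nullOnError((s: String) => s),
        nullOnError((s: String) => s.toInt),
        nullOnError((s: String) => s.toDouble))

      def main(args: Array[String]): Unit = {
        val delimiter = "\t" // default value for TOK_TABLEROWFORMATFIELD
        val line = "hello\tnot-an-int\t3.14"
        val row = line.split(delimiter).zip(fieldWriters).map { case (v, w) => w(v) }
        println(row.mkString("[", ", ", "]")) // prints [hello, null, 3.14]
      }
    }

The real implementation additionally converts each parsed value to Catalyst's internal representation via CatalystTypeConverters.createToCatalystConverter before building a GenericInternalRow.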