[SPARK-32403][SQL] Refactor current ScriptTransformationExec

### What changes were proposed in this pull request?

This PR follows up on a review comment: #29085 (comment)

- Extract the common script I/O schema into `ScriptTransformationIOSchema`.
- Avoid repeated serde checks by extracting the process-output row handling into `createOutputIteratorWithoutSerde` and `createOutputIteratorWithSerde`.
- Add a default no-serde I/O schema, `ScriptTransformationIOSchema.defaultIOSchema` (see the sketch after this list).
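
For context, a minimal sketch of how the new default no-serde schema behaves, based on the definitions in this diff (the assertions are illustrative, not part of the PR):

```scala
import org.apache.spark.sql.execution.ScriptTransformationIOSchema

// The no-serde default introduced by this PR: empty row formats, no serde
// classes, and schemaLess = false.
val ioschema = ScriptTransformationIOSchema.defaultIOSchema

// Row-format lookups fall back to the defaults: "\t" for fields, "\n" for lines.
assert(ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD") == "\t")
assert(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATLINES") == "\n")
```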

### Why are the changes needed?
Refactoring only: extract the logic shared between the Spark-native and Hive script transformation implementations to reduce duplication.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
No new tests; this is a pure refactoring.

Closes #29199 from AngersZhuuuu/spark-32105-followup.

Authored-by: angerszhu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
AngersZhuuuu authored and cloud-fan committed Aug 10, 2020
1 parent 99f50c6 commit d251443
Showing 5 changed files with 309 additions and 206 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
@@ -17,25 +17,38 @@

package org.apache.spark.sql.execution

import java.io.OutputStream
import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.hadoop.conf.Configuration

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Cast, Expression, GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils}

trait BaseScriptTransformationExec extends UnaryExecNode {
def input: Seq[Expression]
def script: String
def output: Seq[Attribute]
def child: SparkPlan
def ioschema: ScriptTransformationIOSchema

protected lazy val inputExpressionsWithoutSerde: Seq[Expression] = {
input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
}

override def producedAttributes: AttributeSet = outputSet -- inputSet

@@ -56,10 +69,91 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
}
}

def processIterator(
protected def initProc: (OutputStream, Process, InputStream, CircularBuffer) = {
val cmd = List("/bin/bash", "-c", script)
val builder = new ProcessBuilder(cmd.asJava)

val proc = builder.start()
val inputStream = proc.getInputStream
val outputStream = proc.getOutputStream
val errorStream = proc.getErrorStream

// In order to avoid deadlocks, we need to consume the error output of the child process.
// To avoid issues caused by large error output, we use a circular buffer to limit the amount
// of error output that we retain. See SPARK-7862 for more discussion of the deadlock / hang
// that motivates this.
val stderrBuffer = new CircularBuffer(2048)
new RedirectThread(
errorStream,
stderrBuffer,
s"Thread-${this.getClass.getSimpleName}-STDERR-Consumer").start()
(outputStream, proc, inputStream, stderrBuffer)
}

protected def processIterator(
inputIterator: Iterator[InternalRow],
hadoopConf: Configuration): Iterator[InternalRow]

protected def createOutputIteratorWithoutSerde(
writerThread: BaseScriptTransformationWriterThread,
inputStream: InputStream,
proc: Process,
stderrBuffer: CircularBuffer): Iterator[InternalRow] = {
new Iterator[InternalRow] {
var curLine: String = null
val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))

val outputRowFormat = ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD")
val processRowWithoutSerde = if (!ioschema.schemaLess) {
prevLine: String =>
new GenericInternalRow(
prevLine.split(outputRowFormat)
.zip(outputFieldWriters)
.map { case (data, writer) => writer(data) })
} else {
// In schema-less mode, Hive's default serde keeps only the first two output
// columns; if there are fewer than two output columns, it throws
// ArrayIndexOutOfBoundsException. Here we make Spark's behavior match Hive's
// default serde. Note that in Hive, TRANSFORM in schema-less mode behaves like
// the original Spark behavior; SPARK-32388 will make Spark and Hive consistent.
val kvWriter = CatalystTypeConverters.createToCatalystConverter(StringType)
prevLine: String =>
new GenericInternalRow(
prevLine.split(outputRowFormat).slice(0, 2)
.map(kvWriter))
}

override def hasNext: Boolean = {
try {
if (curLine == null) {
curLine = reader.readLine()
if (curLine == null) {
checkFailureAndPropagate(writerThread, null, proc, stderrBuffer)
return false
}
}
true
} catch {
case NonFatal(e) =>
// If this exception is due to abrupt / unclean termination of `proc`,
// then detect it and propagate a better exception message for end users
checkFailureAndPropagate(writerThread, e, proc, stderrBuffer)

throw e
}
}

override def next(): InternalRow = {
if (!hasNext) {
throw new NoSuchElementException
}
val prevLine = curLine
curLine = reader.readLine()
processRowWithoutSerde(prevLine)
}
}
}

protected def checkFailureAndPropagate(
writerThread: BaseScriptTransformationWriterThread,
cause: Throwable = null,
@@ -87,17 +181,72 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
}
}
}

private lazy val outputFieldWriters: Seq[String => Any] = output.map { attr =>
val converter = CatalystTypeConverters.createToCatalystConverter(attr.dataType)
attr.dataType match {
case StringType => wrapperConvertException(data => data, converter)
case BooleanType => wrapperConvertException(data => data.toBoolean, converter)
case ByteType => wrapperConvertException(data => data.toByte, converter)
case BinaryType =>
wrapperConvertException(data => UTF8String.fromString(data).getBytes, converter)
case IntegerType => wrapperConvertException(data => data.toInt, converter)
case ShortType => wrapperConvertException(data => data.toShort, converter)
case LongType => wrapperConvertException(data => data.toLong, converter)
case FloatType => wrapperConvertException(data => data.toFloat, converter)
case DoubleType => wrapperConvertException(data => data.toDouble, converter)
case _: DecimalType => wrapperConvertException(data => BigDecimal(data), converter)
case DateType if conf.datetimeJava8ApiEnabled =>
wrapperConvertException(data => DateTimeUtils.stringToDate(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.daysToLocalDate).orNull, converter)
case DateType => wrapperConvertException(data => DateTimeUtils.stringToDate(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.toJavaDate).orNull, converter)
case TimestampType if conf.datetimeJava8ApiEnabled =>
wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.microsToInstant).orNull, converter)
case TimestampType => wrapperConvertException(data => DateTimeUtils.stringToTimestamp(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.toJavaTimestamp).orNull, converter)
case CalendarIntervalType => wrapperConvertException(
data => IntervalUtils.stringToInterval(UTF8String.fromString(data)),
converter)
case udt: UserDefinedType[_] =>
wrapperConvertException(data => udt.deserialize(data), converter)
case dt =>
throw new SparkException(s"${nodeName} without serde does not support " +
s"${dt.getClass.getSimpleName} as output data type")
}
}

// Keep consistent with Hive's `LazySimpleSerDe`: when a type cast fails, return null.
private val wrapperConvertException: (String => Any, Any => Any) => String => Any =
(f: String => Any, converter: Any => Any) =>
(data: String) => converter {
try {
f(data)
} catch {
case NonFatal(_) => null
}
}
}

abstract class BaseScriptTransformationWriterThread(
iter: Iterator[InternalRow],
inputSchema: Seq[DataType],
ioSchema: BaseScriptTransformIOSchema,
outputStream: OutputStream,
proc: Process,
stderrBuffer: CircularBuffer,
taskContext: TaskContext,
conf: Configuration) extends Thread with Logging {
abstract class BaseScriptTransformationWriterThread extends Thread with Logging {

def iter: Iterator[InternalRow]
def inputSchema: Seq[DataType]
def ioSchema: ScriptTransformationIOSchema
def outputStream: OutputStream
def proc: Process
def stderrBuffer: CircularBuffer
def taskContext: TaskContext
def conf: Configuration

setDaemon(true)

@@ -169,34 +318,50 @@ abstract class BaseScriptTransformationWriterThread(
/**
* A wrapper for the input and output schema properties of a script transformation.
*/
abstract class BaseScriptTransformIOSchema extends Serializable {
import ScriptIOSchema._

def inputRowFormat: Seq[(String, String)]

def outputRowFormat: Seq[(String, String)]

def inputSerdeClass: Option[String]

def outputSerdeClass: Option[String]

def inputSerdeProps: Seq[(String, String)]

def outputSerdeProps: Seq[(String, String)]

def recordReaderClass: Option[String]

def recordWriterClass: Option[String]

def schemaLess: Boolean
case class ScriptTransformationIOSchema(
inputRowFormat: Seq[(String, String)],
outputRowFormat: Seq[(String, String)],
inputSerdeClass: Option[String],
outputSerdeClass: Option[String],
inputSerdeProps: Seq[(String, String)],
outputSerdeProps: Seq[(String, String)],
recordReaderClass: Option[String],
recordWriterClass: Option[String],
schemaLess: Boolean) extends Serializable {
import ScriptTransformationIOSchema._

val inputRowFormatMap = inputRowFormat.toMap.withDefault((k) => defaultFormat(k))
val outputRowFormatMap = outputRowFormat.toMap.withDefault((k) => defaultFormat(k))
}

object ScriptIOSchema {
object ScriptTransformationIOSchema {
val defaultFormat = Map(
("TOK_TABLEROWFORMATFIELD", "\t"),
("TOK_TABLEROWFORMATLINES", "\n")
)

val defaultIOSchema = ScriptTransformationIOSchema(
inputRowFormat = Seq.empty,
outputRowFormat = Seq.empty,
inputSerdeClass = None,
outputSerdeClass = None,
inputSerdeProps = Seq.empty,
outputSerdeProps = Seq.empty,
recordReaderClass = None,
recordWriterClass = None,
schemaLess = false
)

def apply(input: ScriptInputOutputSchema): ScriptTransformationIOSchema = {
ScriptTransformationIOSchema(
input.inputRowFormat,
input.outputRowFormat,
input.inputSerdeClass,
input.outputSerdeClass,
input.inputSerdeProps,
input.outputSerdeProps,
input.recordReaderClass,
input.recordWriterClass,
input.schemaLess)
}
}
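
For illustration, a small self-contained sketch (plain Scala, no Spark dependencies; all names are illustrative, not Spark APIs) of the no-serde output parsing that `createOutputIteratorWithoutSerde` performs: split each line on the field delimiter and run every field through a per-column writer that yields null on conversion failure, mirroring `wrapperConvertException`:

```scala
import scala.util.control.NonFatal

object NoSerdeParsingSketch {
  // Analogous to wrapperConvertException: convert a string field, returning
  // null when the conversion fails (Hive LazySimpleSerDe-style behavior).
  private def safeWriter(f: String => Any): String => Any =
    s => try f(s) catch { case NonFatal(_) => null }

  // Analogous to outputFieldWriters, for a two-column (int, string) output.
  val writers: Seq[String => Any] = Seq(
    safeWriter(_.toInt), // e.g. an IntegerType column
    safeWriter(identity) // e.g. a StringType column
  )

  // Analogous to processRowWithoutSerde in the non-schema-less branch.
  def parseLine(line: String, fieldDelim: String = "\t"): Seq[Any] =
    line.split(fieldDelim).toSeq.zip(writers).map { case (field, w) => w(field) }

  def main(args: Array[String]): Unit = {
    println(parseLine("42\thello"))   // List(42, hello)
    println(parseLine("oops\tworld")) // List(null, world): a failed cast yields null
  }
}
```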
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -1046,8 +1046,8 @@ private[hive] trait HiveInspectors {
getListTypeInfo(elemType.toTypeInfo)
case StructType(fields) =>
getStructTypeInfo(
java.util.Arrays.asList(fields.map(_.name) : _*),
java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*))
java.util.Arrays.asList(fields.map(_.name): _*),
java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo): _*))
case MapType(keyType, valueType, _) =>
getMapTypeInfo(keyType.toTypeInfo, valueType.toTypeInfo)
case BinaryType => binaryTypeInfo
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.hive.execution.{HiveScriptIOSchema, HiveScriptTransformationExec}
import org.apache.spark.sql.hive.execution.HiveScriptTransformationExec
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}


@@ -244,7 +244,7 @@ private[hive] trait HiveStrategies {
object HiveScripts extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ScriptTransformation(input, script, output, child, ioschema) =>
val hiveIoSchema = HiveScriptIOSchema(ioschema)
val hiveIoSchema = ScriptTransformationIOSchema(ioschema)
HiveScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil
case _ => Nil
}
(The diffs for the remaining two changed files are not shown.)