
Commit

fix input and output format
AngersZhuuuu committed Jul 14, 2020
1 parent 5bfa669 commit ec754e2
Showing 5 changed files with 75 additions and 16 deletions.
BaseScriptTransformationExec.scala
@@ -27,12 +27,15 @@ import org.apache.hadoop.conf.Configuration

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{CircularBuffer, SerializableConfiguration, Utils}

trait BaseScriptTransformationExec extends UnaryExecNode {
@@ -87,6 +90,41 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
}
}
}

def wrapper(data: String, dt: DataType): Any = {
dt match {
case StringType => data
case ByteType => data.toByte
case IntegerType => data.toInt
case ShortType => data.toShort
case LongType => data.toLong
case FloatType => data.toFloat
case DoubleType => data.toDouble
case _: DecimalType => BigDecimal(data)
case DateType if conf.datetimeJava8ApiEnabled =>
DateTimeUtils.stringToDate(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.daysToLocalDate).orNull
case DateType =>
DateTimeUtils.stringToDate(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.toJavaDate).orNull
case TimestampType if conf.datetimeJava8ApiEnabled =>
DateTimeUtils.stringToTimestamp(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.microsToInstant).orNull
case TimestampType =>
DateTimeUtils.stringToTimestamp(
UTF8String.fromString(data),
DateTimeUtils.getZoneId(conf.sessionLocalTimeZone))
.map(DateTimeUtils.toJavaTimestamp).orNull
case CalendarIntervalType => IntervalUtils.stringToInterval(UTF8String.fromString(data))
case _ => data
}
}
}

abstract class BaseScriptTransformationWriterThread extends Thread with Logging {
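To make the conversion path above concrete, here is a minimal, self-contained sketch of the idea behind wrapper(): parse each raw string field from the script output according to its declared column type, with unparseable date-like values falling back to null via Option#orNull. All names in the sketch (WrapperSketch, convert, the string type tags) are illustrative and not part of Spark's API.

import java.time.LocalDate
import scala.util.Try

object WrapperSketch {
  // Parse one raw field by its declared type; unknown types fall back to
  // the raw string, mirroring the default case of wrapper() above.
  def convert(data: String, typeName: String): Any = typeName match {
    case "int"    => data.toInt
    case "long"   => data.toLong
    case "double" => data.toDouble
    case "date"   => Try(LocalDate.parse(data)).toOption.orNull // null on bad input
    case _        => data
  }

  def main(args: Array[String]): Unit = {
    println(convert("42", "int"))           // 42
    println(convert("2020-07-14", "date"))  // 2020-07-14
    println(convert("oops", "date"))        // null
  }
}
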
SparkScriptTransformationExec.scala
@@ -29,7 +29,7 @@ import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types._
import org.apache.spark.util.{CircularBuffer, RedirectThread}

/**
@@ -67,7 +67,9 @@ case class SparkScriptTransformationExec(
stderrBuffer,
"Thread-ScriptTransformation-STDERR-Consumer").start()

val outputProjection = new InterpretedProjection(input, child.output)
val finalInput = input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))

val outputProjection = new InterpretedProjection(finalInput, child.output)

// This new thread will consume the ScriptTransformation's input rows and write them to the
// external process. That process's output will be read by this current thread.
@@ -116,11 +118,17 @@
if (!ioschema.schemaLess) {
new GenericInternalRow(
prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
.map(CatalystTypeConverters.convertToCatalyst))
.zip(output)
.map { case (data, attr) =>
CatalystTypeConverters.convertToCatalyst(wrapper(data, attr.dataType))
})
} else {
new GenericInternalRow(
prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
.map(CatalystTypeConverters.convertToCatalyst))
.zip(output)
.map { case (data, attr) =>
CatalystTypeConverters.convertToCatalyst(wrapper(data, attr.dataType))
})
}
}
}
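A rough, self-contained model of the two parsing branches above, with made-up names: when the schema is known, each delimited field is converted against its column type before the row is built; in the schema-less branch, split(delimiter, 2) caps the result at two fields, so everything after the first delimiter stays in one string.

object RowParseSketch {
  // Schema known: convert each field by type.
  // Schema-less: at most two string fields, like split(..., 2) above.
  def parse(line: String, schema: Option[Seq[String]]): Seq[Any] = schema match {
    case Some(types) =>
      line.split("\t").zip(types).map {
        case (field, "int")    => field.toInt
        case (field, "double") => field.toDouble
        case (field, _)        => field
      }.toList
    case None =>
      line.split("\t", 2).toList
  }

  def main(args: Array[String]): Unit = {
    println(parse("1\t2.5\tok", Some(Seq("int", "double", "string")))) // List(1, 2.5, ok)
    println(parse("1\t2.5\tok", None)) // List(1, 2.5<tab>ok): the tab stays in field two
  }
}
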
SparkSqlParser.scala
@@ -713,7 +713,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
None
}
(Seq.empty, Option(name), props.toSeq, recordHandler)

// SPARK-32106: When no row format is defined, we return an empty result,
// and the query is ultimately executed with SparkScriptTransformationExec
case null =>
(Nil, None, Seq.empty, None)
}
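For context, this is the shape of query that now reaches the case null branch, a TRANSFORM with no ROW FORMAT or SERDE clause, and is therefore planned as SparkScriptTransformationExec. The table and script names are illustrative, and a running SparkSession named spark is assumed:

// Illustrative only: no row format is defined, so the parser returns the
// empty defaults and planning falls to SparkScriptTransformationExec.
spark.sql("""
  SELECT TRANSFORM(key, value)
  USING 'cat' AS (k STRING, v STRING)
  FROM src
""").show()
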
SparkStrategies.scala
@@ -17,8 +17,6 @@

package org.apache.spark.sql.execution

import java.time.ZoneId

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, AnalysisException, Strategy}
import org.apache.spark.sql.catalyst.InternalRow
@@ -40,7 +38,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.types.StructType

/**
* Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -539,7 +537,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.ScriptTransformation(input, script, output, child, ioschema)
if ioschema.inputSerdeClass.isEmpty && ioschema.outputSerdeClass.isEmpty =>
SparkScriptTransformationExec(
input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone)),
input,
script,
output,
planLater(child),
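The guard in this case clause is what routes a script transform between the two exec nodes: only when both SerDe classes are empty does planning choose SparkScriptTransformationExec, and the Hive-side strategy handles the rest. A compact, self-contained model of that dispatch, with invented names:

object PlanDispatchSketch {
  // Mirrors the guard above: no input SerDe and no output SerDe means the
  // pure-Spark exec node; any SerDe involvement requires the Hive node.
  def chooseExec(inputSerdeClass: Option[String],
                 outputSerdeClass: Option[String]): String =
    if (inputSerdeClass.isEmpty && outputSerdeClass.isEmpty)
      "SparkScriptTransformationExec"
    else
      "HiveScriptTransformationExec"

  def main(args: Array[String]): Unit = {
    println(chooseExec(None, None))                     // SparkScriptTransformationExec
    println(chooseExec(Some("LazySimpleSerDe"), None))  // HiveScriptTransformationExec
  }
}
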
HiveScriptTransformationExec.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive.HiveInspectors
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}

/**
@@ -78,17 +78,25 @@ case class HiveScriptTransformationExec(
stderrBuffer,
"Thread-ScriptTransformation-STDERR-Consumer").start()

val outputProjection = new InterpretedProjection(input, child.output)

// This nullability is a performance optimization in order to avoid an Option.foreach() call
// inside of a loop
@Nullable val (inputSerde, inputSoi) = ioschema.initInputSerDe(input).getOrElse((null, null))

// For HiveScriptTransformationExec, if inputSerde == null but outputSerde != null,
// data is passed to the script as plain text, so the input must be cast to string as well.
val finalInput = if (inputSerde == null) {
input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
} else {
input
}

val outputProjection = new InterpretedProjection(finalInput, child.output)

// This new thread will consume the ScriptTransformation's input rows and write them to the
// external process. That process's output will be read by this current thread.
val writerThread = HiveScriptTransformationWriterThread(
inputIterator.map(outputProjection),
input.map(_.dataType),
finalInput.map(_.dataType),
inputSerde,
inputSoi,
ioschema,
@@ -178,11 +186,17 @@
if (!ioschema.schemaLess) {
new GenericInternalRow(
prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))
.map(CatalystTypeConverters.convertToCatalyst))
.zip(output)
.map { case (data, attr) =>
CatalystTypeConverters.convertToCatalyst(wrapper(data, attr.dataType))
})
} else {
new GenericInternalRow(
prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)
.map(CatalystTypeConverters.convertToCatalyst))
.zip(output)
.map { case (data, attr) =>
CatalystTypeConverters.convertToCatalyst(wrapper(data, attr.dataType))
})
}
} else {
val raw = outputSerde.deserialize(scriptOutputWritable)
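Taken together with the SparkStrategies change above, the cast to string no longer happens unconditionally in the planner; each exec node decides based on whether an input SerDe is present. The toy model below (self-contained, with made-up Expr types rather than Spark's expression API) captures that division of labor: a SerDe serializes typed values itself, while the no-SerDe path writes delimited text and therefore needs every input cast to string first.

object InputPrepSketch {
  sealed trait Expr
  case class Col(name: String) extends Expr
  case class CastToString(child: Expr) extends Expr

  // Cast only on the text path (no input SerDe); a SerDe handles typed
  // values itself, so those expressions pass through unchanged.
  def prepareInput(input: Seq[Expr], hasInputSerde: Boolean): Seq[Expr] =
    if (hasInputSerde) input
    else input.map(CastToString(_))

  def main(args: Array[String]): Unit = {
    val cols = Seq(Col("key"), Col("value"))
    println(prepareInput(cols, hasInputSerde = true))  // List(Col(key), Col(value))
    println(prepareInput(cols, hasInputSerde = false)) // List(CastToString(Col(key)), ...)
  }
}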
