diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
index 61d0360249781..e243acd7def80 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import java.io.{BufferedReader, InputStream, OutputStream}
+import java.io.{BufferedReader, InputStream, InputStreamReader, OutputStream}
 import java.nio.charset.StandardCharsets
 import java.util.concurrent.TimeUnit
 
@@ -98,6 +98,46 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
       .map { case (data, writer) => writer(data) })
   }
 
+  protected def createOutputIteratorWithoutSerde(
+      writerThread: BaseScriptTransformationWriterThread,
+      inputStream: InputStream,
+      proc: Process,
+      stderrBuffer: CircularBuffer): Iterator[InternalRow] = {
+    new Iterator[InternalRow] {
+      var curLine: String = null
+      val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
+
+      override def hasNext: Boolean = {
+        try {
+          if (curLine == null) {
+            curLine = reader.readLine()
+            if (curLine == null) {
+              checkFailureAndPropagate(writerThread, null, proc, stderrBuffer)
+              return false
+            }
+          }
+          true
+        } catch {
+          case NonFatal(e) =>
+            // If this exception is due to abrupt / unclean termination of `proc`,
+            // then detect it and propagate a better exception message for end users
+            checkFailureAndPropagate(writerThread, e, proc, stderrBuffer)
+
+            throw e
+        }
+      }
+
+      override def next(): InternalRow = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        val prevLine = curLine
+        curLine = reader.readLine()
+        processOutputWithoutSerde(prevLine, reader)
+      }
+    }
+  }
+
   protected def checkFailureAndPropagate(
       writerThread: BaseScriptTransformationWriterThread,
       cause: Throwable = null,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
index 4909feae20ad5..103eaf869039d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkScriptTransformationExec.scala
@@ -18,9 +18,6 @@
 package org.apache.spark.sql.execution
 
 import java.io._
-import java.nio.charset.StandardCharsets
-
-import scala.util.control.NonFatal
 
 import org.apache.hadoop.conf.Configuration
 
@@ -68,39 +65,8 @@ case class SparkScriptTransformationExec(
       hadoopConf
     )
 
-    val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
-    val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] {
-      var curLine: String = null
-
-      override def hasNext: Boolean = {
-        try {
-          if (curLine == null) {
-            curLine = reader.readLine()
-            if (curLine == null) {
-              checkFailureAndPropagate(writerThread, null, proc, stderrBuffer)
-              return false
-            }
-          }
-          true
-        } catch {
-          case NonFatal(e) =>
-            // If this exception is due to abrupt / unclean termination of `proc`,
-            // then detect it and propagate a better exception message for end users
-            checkFailureAndPropagate(writerThread, e, proc, stderrBuffer)
-
-            throw e
-        }
-      }
-
-      override def next(): InternalRow = {
-        if (!hasNext) {
-          throw new NoSuchElementException
-        }
-        val prevLine = curLine
-        curLine = reader.readLine()
-        processOutputWithoutSerde(prevLine, reader)
-      }
-    }
+    val outputIterator = createOutputIteratorWithoutSerde(
+      writerThread, inputStream, proc, stderrBuffer)
 
     writerThread.start()
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
similarity index 83%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
index 8398b82da0c2e..26c08c3f513c0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/BaseScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala
@@ -15,10 +15,9 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution
 
 import java.sql.Timestamp
-import java.util.Locale
 
 import org.scalatest.Assertions._
 import org.scalatest.BeforeAndAfterEach
@@ -30,26 +29,18 @@ import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution._
-import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 
 abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestUtils
-  with TestHiveSingleton with BeforeAndAfterEach {
+  with BeforeAndAfterEach {
+  import testImplicits._
+  import ScriptTransformationIOSchema._
 
-  def scriptType: String
-
-  def isHive23OrSpark: Boolean = true
-
-  import spark.implicits._
-
-  var noSerdeIOSchema: ScriptTransformationIOSchema = ScriptTransformationIOSchema.defaultIOSchema
+  protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler
 
   private var defaultUncaughtExceptionHandler: Thread.UncaughtExceptionHandler = _
 
-  protected val uncaughtExceptionHandler = new TestUncaughtExceptionHandler
-
   protected override def beforeAll(): Unit = {
     super.beforeAll()
     defaultUncaughtExceptionHandler = Thread.getDefaultUncaughtExceptionHandler
@@ -66,32 +57,14 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
     uncaughtExceptionHandler.cleanStatus()
   }
 
+  def isHive23OrSpark: Boolean
+
   def createScriptTransformationExec(
       input: Seq[Expression],
       script: String,
       output: Seq[Attribute],
       child: SparkPlan,
-      ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = {
-    scriptType.toUpperCase(Locale.ROOT) match {
-      case "SPARK" => new SparkScriptTransformationExec(
-        input = input,
-        script = script,
-        output = output,
-        child = child,
-        ioschema = ioschema
-      )
-      case "HIVE" => new HiveScriptTransformationExec(
-        input = input,
-        script = script,
-        output = output,
-        child = child,
-        ioschema = ioschema
-      )
-      case _ => throw new TestFailedException(
-        "Test class implement from BaseScriptTransformationSuite" +
-          " should override method `scriptType` to Spark or Hive", 0)
-    }
-  }
+      ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec
 
   test("cat without SerDe") {
     assume(TestUtils.testCommandAvailable("/bin/bash"))
@@ -104,7 +77,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
         script = "cat",
         output = Seq(AttributeReference("a", StringType)()),
         child = child,
-        ioschema = noSerdeIOSchema
+        ioschema = defaultIOSchema
       ),
       rowsDf.collect())
     assert(uncaughtExceptionHandler.exception.isEmpty)
@@ -122,7 +95,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
           script = "cat",
           output = Seq(AttributeReference("a", StringType)()),
           child = ExceptionInjectingOperator(child),
-          ioschema = noSerdeIOSchema
+          ioschema = defaultIOSchema
         ),
         rowsDf.collect())
     }
@@ -178,8 +151,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU
          script = "some_non_existent_command",
          output = Seq(AttributeReference("a", StringType)()),
          child = rowsDf.queryExecution.sparkPlan,
-         ioschema = noSerdeIOSchema)
-      SparkPlanTest.executePlan(plan, hiveContext)
+         ioschema = defaultIOSchema)
+      SparkPlanTest.executePlan(plan, spark.sqlContext)
     }
     assert(e.getMessage.contains("Subprocess exited with status"))
     assert(uncaughtExceptionHandler.exception.isEmpty)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala
similarity index 84%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala
index 381679075c0f1..1abf298a6123e 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SparkScriptTransformationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkScriptTransformationSuite.scala
@@ -15,22 +15,37 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution
 
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.TestUtils
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.functions.struct
+import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
-class SparkScriptTransformationSuite extends BaseScriptTransformationSuite {
+class SparkScriptTransformationSuite extends BaseScriptTransformationSuite with SharedSparkSession {
+  import testImplicits._
+  import ScriptTransformationIOSchema._
 
-  import spark.implicits._
+  override def isHive23OrSpark: Boolean = true
 
-  override def scriptType: String = "SPARK"
+  override def createScriptTransformationExec(
+      input: Seq[Expression],
+      script: String,
+      output: Seq[Attribute],
+      child: SparkPlan,
+      ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = {
+    SparkScriptTransformationExec(
+      input = input,
+      script = script,
+      output = output,
+      child = child,
+      ioschema = ioschema
+    )
+  }
 
   test("SPARK-32106: SparkScriptTransformExec should handle different data types correctly") {
     assume(TestUtils.testCommandAvailable("python"))
@@ -97,7 +112,7 @@ class SparkScriptTransformationSuite extends BaseScriptTransformationSuite {
           AttributeReference("j", StringType)(),
           AttributeReference("k", StringType)()),
         child = child,
-        ioschema = noSerdeIOSchema
+        ioschema = defaultIOSchema
       ),
       df.select(
         'a, 'b, 'c, 'd, 'e, 'f, 'g, 'h,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala
similarity index 96%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala
index 681eb4e255dbc..360f4658345e9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestUncaughtExceptionHandler.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestUncaughtExceptionHandler.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.hive.execution
+package org.apache.spark.sql.execution
 
 class TestUncaughtExceptionHandler extends Thread.UncaughtExceptionHandler {
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
index 4f17d1f40afdc..37a3789205b5d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationExec.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.hive.execution
 
 import java.io._
-import java.nio.charset.StandardCharsets
 import java.util.Properties
 import javax.annotation.Nullable
 
@@ -129,61 +128,22 @@ case class HiveScriptTransformationExec(
     }
   }
 
-  override def processIterator(
-      inputIterator: Iterator[InternalRow],
+  private def createOutputIteratorWithSerde(
+      writerThread: BaseScriptTransformationWriterThread,
+      inputStream: InputStream,
+      proc: Process,
+      stderrBuffer: CircularBuffer,
+      outputSerde: AbstractSerDe,
+      outputSoi: StructObjectInspector,
       hadoopConf: Configuration): Iterator[InternalRow] = {
-
-    val (outputStream, proc, inputStream, stderrBuffer) = initProc
-
-    // This nullability is a performance optimization in order to avoid an Option.foreach() call
-    // inside of a loop
-    @Nullable val (inputSerde, inputSoi) = initInputSerDe(input).getOrElse((null, null))
-
-    // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null
-    // We will use StringBuffer to pass data, in this case, we should cast data as string too.
-    val finalInput = if (inputSerde == null) {
-      input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
-    } else {
-      input
-    }
-
-    val outputProjection = new InterpretedProjection(finalInput, child.output)
-
-    // This new thread will consume the ScriptTransformation's input rows and write them to the
-    // external process. That process's output will be read by this current thread.
-    val writerThread = HiveScriptTransformationWriterThread(
-      inputIterator.map(outputProjection),
-      finalInput.map(_.dataType),
-      inputSerde,
-      inputSoi,
-      ioschema,
-      outputStream,
-      recordWriter,
-      proc,
-      stderrBuffer,
-      TaskContext.get(),
-      hadoopConf
-    )
-
-    // This nullability is a performance optimization in order to avoid an Option.foreach() call
-    // inside of a loop
-    @Nullable val (outputSerde, outputSoi) = {
-      initOutputSerDe(output).getOrElse((null, null))
-    }
-
-    val reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))
-    val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] with HiveInspectors {
+    new Iterator[InternalRow] with HiveInspectors {
       var curLine: String = null
       val scriptOutputStream = new DataInputStream(inputStream)
 
       @Nullable val scriptOutputReader = recordReader(scriptOutputStream, hadoopConf).orNull
 
       var scriptOutputWritable: Writable = null
-      val reusedWritableObject: Writable = if (null != outputSerde) {
-        outputSerde.getSerializedClass().getConstructor().newInstance()
-      } else {
-        null
-      }
+      val reusedWritableObject = outputSerde.getSerializedClass.getConstructor().newInstance()
       val mutableRow = new SpecificInternalRow(output.map(_.dataType))
 
       @transient
@@ -191,15 +151,7 @@ case class HiveScriptTransformationExec(
 
       override def hasNext: Boolean = {
         try {
-          if (outputSerde == null) {
-            if (curLine == null) {
-              curLine = reader.readLine()
-              if (curLine == null) {
-                checkFailureAndPropagate(writerThread, null, proc, stderrBuffer)
-                return false
-              }
-            }
-          } else if (scriptOutputWritable == null) {
+          if (scriptOutputWritable == null) {
             scriptOutputWritable = reusedWritableObject
 
             if (scriptOutputReader != null) {
@@ -240,13 +192,7 @@ case class HiveScriptTransformationExec(
         nextRow()
       }
 
-      val nextRow: () => InternalRow = if (outputSerde == null) {
-        () => {
-          val prevLine = curLine
-          curLine = reader.readLine()
-          processOutputWithoutSerde(prevLine, reader)
-        }
-      } else {
+      val nextRow: () => InternalRow = {
         () => {
           val raw = outputSerde.deserialize(scriptOutputWritable)
           scriptOutputWritable = null
@@ -264,6 +210,56 @@ case class HiveScriptTransformationExec(
         }
       }
     }
+  }
+
+  override def processIterator(
+      inputIterator: Iterator[InternalRow],
+      hadoopConf: Configuration): Iterator[InternalRow] = {
+
+    val (outputStream, proc, inputStream, stderrBuffer) = initProc
+
+    // This nullability is a performance optimization in order to avoid an Option.foreach() call
+    // inside of a loop
+    @Nullable val (inputSerde, inputSoi) = initInputSerDe(input).getOrElse((null, null))
+
+    // For HiveScriptTransformationExec, if inputSerde == null, but outputSerde != null
+    // We will use StringBuffer to pass data, in this case, we should cast data as string too.
+    val finalInput = if (inputSerde == null) {
+      input.map(Cast(_, StringType).withTimeZone(conf.sessionLocalTimeZone))
+    } else {
+      input
+    }
+
+    val outputProjection = new InterpretedProjection(finalInput, child.output)
+
+    // This new thread will consume the ScriptTransformation's input rows and write them to the
+    // external process. That process's output will be read by this current thread.
+    val writerThread = HiveScriptTransformationWriterThread(
+      inputIterator.map(outputProjection),
+      finalInput.map(_.dataType),
+      inputSerde,
+      inputSoi,
+      ioschema,
+      outputStream,
+      recordWriter,
+      proc,
+      stderrBuffer,
+      TaskContext.get(),
+      hadoopConf
+    )
+
+    // This nullability is a performance optimization in order to avoid an Option.foreach() call
+    // inside of a loop
+    @Nullable val (outputSerde, outputSoi) = {
+      initOutputSerDe(output).getOrElse((null, null))
+    }
+
+    val outputIterator = if (outputSerde == null) {
+      createOutputIteratorWithoutSerde(writerThread, inputStream, proc, stderrBuffer)
+    } else {
+      createOutputIteratorWithSerde(
+        writerThread, inputStream, proc, stderrBuffer, outputSerde, outputSoi, hadoopConf)
+    }
 
     writerThread.start()
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
index 11928fbc2ef0f..7ba1deb101a65 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala
@@ -21,20 +21,35 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.{SparkException, TestUtils}
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.execution.{ScriptTransformationIOSchema, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.hive.HiveUtils
+import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.types.StringType
 
-class HiveScriptTransformationSuite extends BaseScriptTransformationSuite {
-  override def scriptType: String = "HIVE"
+class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with TestHiveSingleton {
+  import testImplicits._
+  import ScriptTransformationIOSchema._
 
   override def isHive23OrSpark: Boolean = HiveUtils.isHive23
 
-  import spark.implicits._
+  override def createScriptTransformationExec(
+      input: Seq[Expression],
+      script: String,
+      output: Seq[Attribute],
+      child: SparkPlan,
+      ioschema: ScriptTransformationIOSchema): BaseScriptTransformationExec = {
+    HiveScriptTransformationExec(
+      input = input,
+      script = script,
+      output = output,
+      child = child,
+      ioschema = ioschema
+    )
+  }
 
   private val serdeIOSchema: ScriptTransformationIOSchema = {
-    noSerdeIOSchema.copy(
+    defaultIOSchema.copy(
       inputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName),
       outputSerdeClass = Some(classOf[LazySimpleSerDe].getCanonicalName)
     )
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 920f6385f8e19..24b9d25ed94f2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, Functio
 import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
+import org.apache.spark.sql.execution.TestUncaughtExceptionHandler
 import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
 import org.apache.spark.sql.execution.command.{FunctionsCommand, LoadDataCommand}
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
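
For illustration (not part of the patch): the core of this change hoists the no-serde output iterator out of SparkScriptTransformationExec into the shared BaseScriptTransformationExec trait, so the Spark and Hive execs both reuse it. Below is a minimal, self-contained Scala sketch of that iterator's shape; the names are simplified stand-ins, and the patch's checkFailureAndPropagate / process-failure handling is deliberately omitted.

import java.io.{BufferedReader, ByteArrayInputStream, InputStream, InputStreamReader}
import java.nio.charset.StandardCharsets

trait LineIteratorBase {
  // Mirrors the shape of createOutputIteratorWithoutSerde: wrap the child
  // process's stdout in a one-line-lookahead iterator. The real method also
  // calls checkFailureAndPropagate when the stream ends or a read fails.
  def createLineIterator(in: InputStream): Iterator[String] = new Iterator[String] {
    private val reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8))
    private var curLine: String = reader.readLine()

    override def hasNext: Boolean = curLine != null

    override def next(): String = {
      if (!hasNext) throw new NoSuchElementException
      val prevLine = curLine
      curLine = reader.readLine()
      prevLine
    }
  }
}

object LineIteratorDemo extends LineIteratorBase {
  def main(args: Array[String]): Unit = {
    // Stand-in for a subprocess's stdout.
    val stdout = new ByteArrayInputStream("a\nb\nc".getBytes(StandardCharsets.UTF_8))
    createLineIterator(stdout).foreach(println) // prints a, b, c
  }
}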
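Likewise for the test refactor: the base suite's scriptType string dispatch (matching "SPARK"/"HIVE" and throwing TestFailedException otherwise) is replaced by an abstract factory method, createScriptTransformationExec, that each concrete suite overrides. A compact sketch of that shape, using hypothetical stand-in types rather than the real suite classes:

// Stand-ins for the two exec node types.
sealed trait TransformExec { def script: String }
final case class SparkExec(script: String) extends TransformExec
final case class HiveExec(script: String) extends TransformExec

abstract class BaseSuite {
  // Each concrete suite builds its own exec; no string tag, no unmatched-case path.
  def createExec(script: String): TransformExec

  def runCatTest(): Unit = assert(createExec("cat").script == "cat")
}

class SparkSuite extends BaseSuite {
  override def createExec(script: String): TransformExec = SparkExec(script)
}

class HiveSuite extends BaseSuite {
  override def createExec(script: String): TransformExec = HiveExec(script)
}

object SuiteDemo {
  def main(args: Array[String]): Unit = {
    new SparkSuite().runCatTest()
    new HiveSuite().runCatTest()
  }
}

With the dispatch gone, the base suite no longer references Hive classes, so it moves to sql/core; the Spark suite mixes in SharedSparkSession, and only HiveScriptTransformationSuite keeps the TestHiveSingleton dependency — which is what the file moves in this patch accomplish.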