diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index d161843dd2230..b267fc1a7f934 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -443,7 +443,7 @@ private[spark] object UIUtils extends Logging {
val xml = XML.loadString(s"""$desc""")
// Verify that this has only anchors and span (we are wrapping in span)
- val allowedNodeLabels = Set("a", "span")
+ val allowedNodeLabels = Set("a", "span", "br")
val illegalNodes = xml \\ "_" filterNot { case node: Node =>
allowedNodeLabels.contains(node.label)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 7b7fb9c0e5cbe..affc2018c43cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -252,6 +252,8 @@ class StreamExecution(
*/
private def runBatches(): Unit = {
try {
+ sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString,
+ interruptOnCancel = true)
if (sparkSession.sessionState.conf.streamingMetricsEnabled) {
sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics)
}
@@ -284,42 +286,40 @@ class StreamExecution(
triggerExecutor.execute(() => {
startTrigger()
- val continueToRun =
- if (isActive) {
- reportTimeTaken("triggerExecution") {
- if (currentBatchId < 0) {
- // We'll do this initialization only once
- populateStartOffsets(sparkSessionToRunBatches)
- logDebug(s"Stream running from $committedOffsets to $availableOffsets")
- } else {
- constructNextBatch()
- }
- if (dataAvailable) {
- currentStatus = currentStatus.copy(isDataAvailable = true)
- updateStatusMessage("Processing new data")
- runBatch(sparkSessionToRunBatches)
- }
+ if (isActive) {
+ reportTimeTaken("triggerExecution") {
+ if (currentBatchId < 0) {
+ // We'll do this initialization only once
+ populateStartOffsets(sparkSessionToRunBatches)
+ sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
+ logDebug(s"Stream running from $committedOffsets to $availableOffsets")
+ } else {
+ constructNextBatch()
}
- // Report trigger as finished and construct progress object.
- finishTrigger(dataAvailable)
if (dataAvailable) {
- // Update committed offsets.
- committedOffsets ++= availableOffsets
- batchCommitLog.add(currentBatchId)
- logDebug(s"batch ${currentBatchId} committed")
- // We'll increase currentBatchId after we complete processing current batch's data
- currentBatchId += 1
- } else {
- currentStatus = currentStatus.copy(isDataAvailable = false)
- updateStatusMessage("Waiting for data to arrive")
- Thread.sleep(pollingDelayMs)
+ currentStatus = currentStatus.copy(isDataAvailable = true)
+ updateStatusMessage("Processing new data")
+ runBatch(sparkSessionToRunBatches)
}
- true
+ }
+ // Report trigger as finished and construct progress object.
+ finishTrigger(dataAvailable)
+ if (dataAvailable) {
+ // Update committed offsets.
+ batchCommitLog.add(currentBatchId)
+ committedOffsets ++= availableOffsets
+ logDebug(s"batch ${currentBatchId} committed")
+ // We'll increase currentBatchId after we complete processing current batch's data
+ currentBatchId += 1
+ sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
} else {
- false
+ currentStatus = currentStatus.copy(isDataAvailable = false)
+ updateStatusMessage("Waiting for data to arrive")
+ Thread.sleep(pollingDelayMs)
}
+ }
updateStatusMessage("Waiting for next trigger")
- continueToRun
+ isActive
})
updateStatusMessage("Stopped")
} else {
@@ -633,7 +633,7 @@ class StreamExecution(
ct.dataType)
case cd: CurrentDate =>
CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
- cd.dataType)
+ cd.dataType, cd.timeZoneId)
}
reportTimeTaken("queryPlanning") {
@@ -688,8 +688,11 @@ class StreamExecution(
// intentionally
state.set(TERMINATED)
if (microBatchThread.isAlive) {
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
microBatchThread.interrupt()
microBatchThread.join()
+ // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
}
logInfo(s"Query $prettyIdString was stopped")
}
@@ -829,6 +832,11 @@ class StreamExecution(
}
}
+ private def getBatchDescriptionString: String = {
+ val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString
+ Option(name).map(_ + "
").getOrElse("") +
+ s"id = $id
runId = $runId
batch = $batchDescription"
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 6daee4b5b3a88..6335646d37988 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,17 +17,26 @@
package org.apache.spark.sql.streaming
+import java.io.{File, InterruptedIOException, IOException}
+import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+
import scala.reflect.ClassTag
import scala.util.control.ControlThrowable
+import org.apache.commons.io.FileUtils
+
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.util.Utils
class StreamSuite extends StreamTest {
@@ -353,7 +362,7 @@ class StreamSuite extends StreamTest {
.writeStream
.format("memory")
.queryName("testquery")
- .outputMode("complete")
+ .outputMode("append")
.start()
try {
query.processAllAvailable()
@@ -365,13 +374,137 @@ class StreamSuite extends StreamTest {
}
}
}
-}
-/**
- * A fake StreamSourceProvider thats creates a fake Source that cannot be reused.
- */
-class FakeDefaultSource extends StreamSourceProvider {
+ test("handle IOException when the streaming thread is interrupted (pre Hadoop 2.8)") {
+ // This test uses a fake source to throw the same IOException as pre Hadoop 2.8 when the
+ // streaming thread is interrupted. We should handle it properly by not failing the query.
+ ThrowingIOExceptionLikeHadoop12074.createSourceLatch = new CountDownLatch(1)
+ val query = spark
+ .readStream
+ .format(classOf[ThrowingIOExceptionLikeHadoop12074].getName)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+ assert(ThrowingIOExceptionLikeHadoop12074.createSourceLatch
+ .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS),
+ "ThrowingIOExceptionLikeHadoop12074.createSource wasn't called before timeout")
+ query.stop()
+ assert(query.exception.isEmpty)
+ }
+
+ test("handle InterruptedIOException when the streaming thread is interrupted (Hadoop 2.8+)") {
+ // This test uses a fake source to throw the same InterruptedIOException as Hadoop 2.8+ when the
+ // streaming thread is interrupted. We should handle it properly by not failing the query.
+ ThrowingInterruptedIOException.createSourceLatch = new CountDownLatch(1)
+ val query = spark
+ .readStream
+ .format(classOf[ThrowingInterruptedIOException].getName)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+ assert(ThrowingInterruptedIOException.createSourceLatch
+ .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS),
+ "ThrowingInterruptedIOException.createSource wasn't called before timeout")
+ query.stop()
+ assert(query.exception.isEmpty)
+ }
+
+ test("SPARK-19873: streaming aggregation with change in number of partitions") {
+ val inputData = MemoryStream[(Int, Int)]
+ val agg = inputData.toDS().groupBy("_1").count()
+
+ testStream(agg, OutputMode.Complete())(
+ AddData(inputData, (1, 0), (2, 0)),
+ StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "2")),
+ CheckAnswer((1, 1), (2, 1)),
+ StopStream,
+ AddData(inputData, (3, 0), (2, 0)),
+ StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "5")),
+ CheckAnswer((1, 1), (2, 2), (3, 1)),
+ StopStream,
+ AddData(inputData, (3, 0), (1, 0)),
+ StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1")),
+ CheckAnswer((1, 2), (2, 2), (3, 2)))
+ }
+
+ testQuietly("recover from a Spark v2.1 checkpoint") {
+ var inputData: MemoryStream[Int] = null
+ var query: DataStreamWriter[Row] = null
+
+ def prepareMemoryStream(): Unit = {
+ inputData = MemoryStream[Int]
+ inputData.addData(1, 2, 3, 4)
+ inputData.addData(3, 4, 5, 6)
+ inputData.addData(5, 6, 7, 8)
+ query = inputData
+ .toDF()
+ .groupBy($"value")
+ .agg(count("*"))
+ .writeStream
+ .outputMode("complete")
+ .format("memory")
+ }
+
+ // Get an existing checkpoint generated by Spark v2.1.
+ // v2.1 does not record # shuffle partitions in the offset metadata.
+ val resourceUri =
+ this.getClass.getResource("/structured-streaming/checkpoint-version-2.1.0").toURI
+ val checkpointDir = new File(resourceUri)
+
+ // 1 - Test if recovery from the checkpoint is successful.
+ prepareMemoryStream()
+ val dir1 = Utils.createTempDir().getCanonicalFile // not using withTempDir {}, makes test flaky
+ // Copy the checkpoint to a temp dir to prevent changes to the original.
+ // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
+ FileUtils.copyDirectory(checkpointDir, dir1)
+ // Checkpoint data was generated by a query with 10 shuffle partitions.
+ // In order to test reading from the checkpoint, the checkpoint must have two or more batches,
+ // since the last batch may be rerun.
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+ var streamingQuery: StreamingQuery = null
+ try {
+ streamingQuery =
+ query.queryName("counts").option("checkpointLocation", dir1.getCanonicalPath).start()
+ streamingQuery.processAllAvailable()
+ inputData.addData(9)
+ streamingQuery.processAllAvailable()
+
+ QueryTest.checkAnswer(spark.table("counts").toDF(),
+ Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) ::
+ Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil)
+ } finally {
+ if (streamingQuery ne null) {
+ streamingQuery.stop()
+ }
+ }
+ }
+
+ // 2 - Check recovery with wrong num shuffle partitions
+ prepareMemoryStream()
+ val dir2 = Utils.createTempDir().getCanonicalFile
+ FileUtils.copyDirectory(checkpointDir, dir2)
+ // Since the number of partitions is greater than 10, should throw exception.
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") {
+ var streamingQuery: StreamingQuery = null
+ try {
+ intercept[StreamingQueryException] {
+ streamingQuery =
+ query.queryName("badQuery").option("checkpointLocation", dir2.getCanonicalPath).start()
+ streamingQuery.processAllAvailable()
+ }
+ } finally {
+ if (streamingQuery ne null) {
+ streamingQuery.stop()
+ }
+ }
+ }
+ }
+}
+
+abstract class FakeSource extends StreamSourceProvider {
private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
override def sourceSchema(
@@ -379,6 +512,10 @@ class FakeDefaultSource extends StreamSourceProvider {
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema)
+}
+
+/** A fake StreamSourceProvider that creates a fake Source that cannot be reused. */
+class FakeDefaultSource extends FakeSource {
override def createSource(
spark: SQLContext,
@@ -410,3 +547,63 @@ class FakeDefaultSource extends StreamSourceProvider {
}
}
}
+
+/** A fake source that throws the same IOException like pre Hadoop 2.8 when it's interrupted. */
+class ThrowingIOExceptionLikeHadoop12074 extends FakeSource {
+ import ThrowingIOExceptionLikeHadoop12074._
+
+ override def createSource(
+ spark: SQLContext,
+ metadataPath: String,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ createSourceLatch.countDown()
+ try {
+ Thread.sleep(30000)
+ throw new TimeoutException("sleep was not interrupted in 30 seconds")
+ } catch {
+ case ie: InterruptedException =>
+ throw new IOException(ie.toString)
+ }
+ }
+}
+
+object ThrowingIOExceptionLikeHadoop12074 {
+ /**
+ * A latch to allow the user to wait until `ThrowingIOExceptionLikeHadoop12074.createSource` is
+ * called.
+ */
+ @volatile var createSourceLatch: CountDownLatch = null
+}
+
+/** A fake source that throws InterruptedIOException like Hadoop 2.8+ when it's interrupted. */
+class ThrowingInterruptedIOException extends FakeSource {
+ import ThrowingInterruptedIOException._
+
+ override def createSource(
+ spark: SQLContext,
+ metadataPath: String,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ createSourceLatch.countDown()
+ try {
+ Thread.sleep(30000)
+ throw new TimeoutException("sleep was not interrupted in 30 seconds")
+ } catch {
+ case ie: InterruptedException =>
+ val iie = new InterruptedIOException(ie.toString)
+ iie.initCause(ie)
+ throw iie
+ }
+ }
+}
+
+object ThrowingInterruptedIOException {
+ /**
+ * A latch to allow the user to wait until `ThrowingInterruptedIOException.createSource` is
+ * called.
+ */
+ @volatile var createSourceLatch: CountDownLatch = null
+}