diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index d161843dd2230..b267fc1a7f934 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -443,7 +443,7 @@ private[spark] object UIUtils extends Logging {
       val xml = XML.loadString(s"""<span class="description-input">$desc</span>""")
 
       // Verify that this has only anchors and span (we are wrapping in span)
-      val allowedNodeLabels = Set("a", "span")
+      val allowedNodeLabels = Set("a", "span", "br")
       val illegalNodes = xml \\ "_" filterNot { case node: Node =>
         allowedNodeLabels.contains(node.label)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 7b7fb9c0e5cbe..affc2018c43cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -252,6 +252,8 @@ class StreamExecution(
    */
   private def runBatches(): Unit = {
     try {
+      sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString,
+        interruptOnCancel = true)
       if (sparkSession.sessionState.conf.streamingMetricsEnabled) {
         sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics)
       }
@@ -284,42 +286,40 @@ class StreamExecution(
         triggerExecutor.execute(() => {
           startTrigger()
 
-          val continueToRun =
-            if (isActive) {
-              reportTimeTaken("triggerExecution") {
-                if (currentBatchId < 0) {
-                  // We'll do this initialization only once
-                  populateStartOffsets(sparkSessionToRunBatches)
-                  logDebug(s"Stream running from $committedOffsets to $availableOffsets")
-                } else {
-                  constructNextBatch()
-                }
-                if (dataAvailable) {
-                  currentStatus = currentStatus.copy(isDataAvailable = true)
-                  updateStatusMessage("Processing new data")
-                  runBatch(sparkSessionToRunBatches)
-                }
+          if (isActive) {
+            reportTimeTaken("triggerExecution") {
+              if (currentBatchId < 0) {
+                // We'll do this initialization only once
+                populateStartOffsets(sparkSessionToRunBatches)
+                sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
+                logDebug(s"Stream running from $committedOffsets to $availableOffsets")
+              } else {
+                constructNextBatch()
               }
-              // Report trigger as finished and construct progress object.
-              finishTrigger(dataAvailable)
               if (dataAvailable) {
-                // Update committed offsets.
-                committedOffsets ++= availableOffsets
-                batchCommitLog.add(currentBatchId)
-                logDebug(s"batch ${currentBatchId} committed")
-                // We'll increase currentBatchId after we complete processing current batch's data
-                currentBatchId += 1
-              } else {
-                currentStatus = currentStatus.copy(isDataAvailable = false)
-                updateStatusMessage("Waiting for data to arrive")
-                Thread.sleep(pollingDelayMs)
+                currentStatus = currentStatus.copy(isDataAvailable = true)
+                updateStatusMessage("Processing new data")
+                runBatch(sparkSessionToRunBatches)
               }
-              true
+            }
+            // Report trigger as finished and construct progress object.
+            finishTrigger(dataAvailable)
+            if (dataAvailable) {
+              // Update committed offsets.
+              batchCommitLog.add(currentBatchId)
+              committedOffsets ++= availableOffsets
+              logDebug(s"batch ${currentBatchId} committed")
+              // We'll increase currentBatchId after we complete processing current batch's data
+              currentBatchId += 1
+              sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
             } else {
-              false
+              currentStatus = currentStatus.copy(isDataAvailable = false)
+              updateStatusMessage("Waiting for data to arrive")
+              Thread.sleep(pollingDelayMs)
             }
+          }
           updateStatusMessage("Waiting for next trigger")
-          continueToRun
+          isActive
         })
         updateStatusMessage("Stopped")
       } else {
@@ -633,7 +633,7 @@ class StreamExecution(
           ct.dataType)
       case cd: CurrentDate =>
         CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
-          cd.dataType)
+          cd.dataType, cd.timeZoneId)
     }
 
     reportTimeTaken("queryPlanning") {
@@ -688,8 +688,11 @@ class StreamExecution(
     // intentionally
     state.set(TERMINATED)
     if (microBatchThread.isAlive) {
+      sparkSession.sparkContext.cancelJobGroup(runId.toString)
       microBatchThread.interrupt()
       microBatchThread.join()
+      // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak
+      sparkSession.sparkContext.cancelJobGroup(runId.toString)
     }
     logInfo(s"Query $prettyIdString was stopped")
   }
@@ -829,6 +832,11 @@ class StreamExecution(
     }
   }
 
+  private def getBatchDescriptionString: String = {
+    val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString
+    Option(name).map(_ + "<br/>").getOrElse("") +
+      s"id = $id<br/>runId = $runId<br/>batch = $batchDescription"
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 6daee4b5b3a88..6335646d37988 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,17 +17,26 @@
 
 package org.apache.spark.sql.streaming
 
+import java.io.{File, InterruptedIOException, IOException}
+import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+
 import scala.reflect.ClassTag
 import scala.util.control.ControlThrowable
 
+import org.apache.commons.io.FileUtils
+
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.util.Utils
 
 class StreamSuite extends StreamTest {
 
@@ -353,7 +362,7 @@ class StreamSuite extends StreamTest {
         .writeStream
         .format("memory")
         .queryName("testquery")
-        .outputMode("complete")
+        .outputMode("append")
         .start()
       try {
         query.processAllAvailable()
@@ -365,13 +374,137 @@
       }
     }
   }
-}
 
-/**
- * A fake StreamSourceProvider thats creates a fake Source that cannot be reused.
- */
-class FakeDefaultSource extends StreamSourceProvider {
+  test("handle IOException when the streaming thread is interrupted (pre Hadoop 2.8)") {
+    // This test uses a fake source to throw the same IOException as pre Hadoop 2.8 when the
+    // streaming thread is interrupted. We should handle it properly by not failing the query.
+    ThrowingIOExceptionLikeHadoop12074.createSourceLatch = new CountDownLatch(1)
+    val query = spark
+      .readStream
+      .format(classOf[ThrowingIOExceptionLikeHadoop12074].getName)
+      .load()
+      .writeStream
+      .format("console")
+      .start()
+    assert(ThrowingIOExceptionLikeHadoop12074.createSourceLatch
+      .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS),
+      "ThrowingIOExceptionLikeHadoop12074.createSource wasn't called before timeout")
+    query.stop()
+    assert(query.exception.isEmpty)
+  }
+
+  test("handle InterruptedIOException when the streaming thread is interrupted (Hadoop 2.8+)") {
+    // This test uses a fake source to throw the same InterruptedIOException as Hadoop 2.8+ when the
+    // streaming thread is interrupted. We should handle it properly by not failing the query.
+    ThrowingInterruptedIOException.createSourceLatch = new CountDownLatch(1)
+    val query = spark
+      .readStream
+      .format(classOf[ThrowingInterruptedIOException].getName)
+      .load()
+      .writeStream
+      .format("console")
+      .start()
+    assert(ThrowingInterruptedIOException.createSourceLatch
+      .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS),
+      "ThrowingInterruptedIOException.createSource wasn't called before timeout")
+    query.stop()
+    assert(query.exception.isEmpty)
+  }
+
+  test("SPARK-19873: streaming aggregation with change in number of partitions") {
+    val inputData = MemoryStream[(Int, Int)]
+    val agg = inputData.toDS().groupBy("_1").count()
+
+    testStream(agg, OutputMode.Complete())(
+      AddData(inputData, (1, 0), (2, 0)),
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "2")),
+      CheckAnswer((1, 1), (2, 1)),
+      StopStream,
+      AddData(inputData, (3, 0), (2, 0)),
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "5")),
+      CheckAnswer((1, 1), (2, 2), (3, 1)),
+      StopStream,
+      AddData(inputData, (3, 0), (1, 0)),
+      StartStream(additionalConfs = Map(SQLConf.SHUFFLE_PARTITIONS.key -> "1")),
+      CheckAnswer((1, 2), (2, 2), (3, 2)))
+  }
+
+  testQuietly("recover from a Spark v2.1 checkpoint") {
+    var inputData: MemoryStream[Int] = null
+    var query: DataStreamWriter[Row] = null
+
+    def prepareMemoryStream(): Unit = {
+      inputData = MemoryStream[Int]
+      inputData.addData(1, 2, 3, 4)
+      inputData.addData(3, 4, 5, 6)
+      inputData.addData(5, 6, 7, 8)
+      query = inputData
+        .toDF()
+        .groupBy($"value")
+        .agg(count("*"))
+        .writeStream
+        .outputMode("complete")
+        .format("memory")
+    }
+
+    // Get an existing checkpoint generated by Spark v2.1.
+    // v2.1 does not record # shuffle partitions in the offset metadata.
+    val resourceUri =
+      this.getClass.getResource("/structured-streaming/checkpoint-version-2.1.0").toURI
+    val checkpointDir = new File(resourceUri)
+
+    // 1 - Test if recovery from the checkpoint is successful.
+    prepareMemoryStream()
+    val dir1 = Utils.createTempDir().getCanonicalFile // not using withTempDir {}, makes test flaky
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
+    FileUtils.copyDirectory(checkpointDir, dir1)
+    // Checkpoint data was generated by a query with 10 shuffle partitions.
+    // In order to test reading from the checkpoint, the checkpoint must have two or more batches,
+    // since the last batch may be rerun.
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "10") {
+      var streamingQuery: StreamingQuery = null
+      try {
+        streamingQuery =
+          query.queryName("counts").option("checkpointLocation", dir1.getCanonicalPath).start()
+        streamingQuery.processAllAvailable()
+        inputData.addData(9)
+        streamingQuery.processAllAvailable()
+
+        QueryTest.checkAnswer(spark.table("counts").toDF(),
+          Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) ::
+          Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil)
+      } finally {
+        if (streamingQuery ne null) {
+          streamingQuery.stop()
+        }
+      }
+    }
+
+    // 2 - Check recovery with wrong num shuffle partitions
+    prepareMemoryStream()
+    val dir2 = Utils.createTempDir().getCanonicalFile
+    FileUtils.copyDirectory(checkpointDir, dir2)
+    // Since the number of partitions is greater than 10, should throw exception.
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "15") {
+      var streamingQuery: StreamingQuery = null
+      try {
+        intercept[StreamingQueryException] {
+          streamingQuery =
+            query.queryName("badQuery").option("checkpointLocation", dir2.getCanonicalPath).start()
+          streamingQuery.processAllAvailable()
+        }
+      } finally {
+        if (streamingQuery ne null) {
+          streamingQuery.stop()
+        }
+      }
+    }
+  }
+}
+
+abstract class FakeSource extends StreamSourceProvider {
 
   private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
 
   override def sourceSchema(
@@ -379,6 +512,10 @@ class FakeDefaultSource extends StreamSourceProvider {
       schema: Option[StructType],
       providerName: String,
       parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema)
+}
+
+/** A fake StreamSourceProvider that creates a fake Source that cannot be reused. */
+class FakeDefaultSource extends FakeSource {
 
   override def createSource(
       spark: SQLContext,
@@ -410,3 +547,63 @@ class FakeDefaultSource extends StreamSourceProvider {
     }
   }
 }
+
+/** A fake source that throws the same IOException as pre Hadoop 2.8 when it's interrupted. */
+class ThrowingIOExceptionLikeHadoop12074 extends FakeSource {
+  import ThrowingIOExceptionLikeHadoop12074._
+
+  override def createSource(
+      spark: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    createSourceLatch.countDown()
+    try {
+      Thread.sleep(30000)
+      throw new TimeoutException("sleep was not interrupted in 30 seconds")
+    } catch {
+      case ie: InterruptedException =>
+        throw new IOException(ie.toString)
+    }
+  }
+}
+
+object ThrowingIOExceptionLikeHadoop12074 {
+  /**
+   * A latch to allow the user to wait until `ThrowingIOExceptionLikeHadoop12074.createSource` is
+   * called.
+   */
+  @volatile var createSourceLatch: CountDownLatch = null
+}
+
+/** A fake source that throws InterruptedIOException like Hadoop 2.8+ when it's interrupted. */
+class ThrowingInterruptedIOException extends FakeSource {
+  import ThrowingInterruptedIOException._
+
+  override def createSource(
+      spark: SQLContext,
+      metadataPath: String,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    createSourceLatch.countDown()
+    try {
+      Thread.sleep(30000)
+      throw new TimeoutException("sleep was not interrupted in 30 seconds")
+    } catch {
+      case ie: InterruptedException =>
+        val iie = new InterruptedIOException(ie.toString)
+        iie.initCause(ie)
+        throw iie
+    }
+  }
+}
+
+object ThrowingInterruptedIOException {
+  /**
+   * A latch to allow the user to wait until `ThrowingInterruptedIOException.createSource` is
+   * called.
+   */
+  @volatile var createSourceLatch: CountDownLatch = null
+}
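Editor's note, appended after the patch: below is a minimal, self-contained sketch (not part of the diff above, and the object name and local setup are hypothetical) of the SparkContext job-group API that the StreamExecution changes rely on. It only illustrates the pattern: a worker thread tags its jobs with a group id via setJobGroup(..., interruptOnCancel = true), and another thread cancels every job in that run with cancelJobGroup, which is what stop() does with runId.

```scala
// Hedged illustration of the job-group pattern used by the patch; not Spark source code.
import java.util.UUID
import org.apache.spark.sql.SparkSession

object JobGroupCancellationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("job-group-sketch").getOrCreate()
    val sc = spark.sparkContext
    val runId = UUID.randomUUID().toString  // plays the role of StreamExecution.runId

    // Stand-in for the micro-batch thread: every job it submits is tagged with the group id,
    // and interruptOnCancel = true means cancellation interrupts the running task threads.
    val batchThread = new Thread {
      override def run(): Unit = {
        sc.setJobGroup(runId, s"runId = $runId<br/>batch = 0", interruptOnCancel = true)
        try {
          sc.parallelize(1 to 4, 4).foreach(_ => Thread.sleep(60000))  // long-running "batch"
        } catch {
          case _: Exception => // job was cancelled; StreamExecution handles this case itself
        }
      }
    }
    batchThread.start()
    Thread.sleep(3000)

    // Stand-in for StreamExecution.stop(): cancel everything tagged with this run's group id.
    sc.cancelJobGroup(runId)
    batchThread.join()
    spark.stop()
  }
}
```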