diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 3f03f42270252..562fe35fb43f3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -17,20 +17,17 @@
 
 package org.apache.spark.streaming.dstream
 
-import org.apache.spark.streaming.StreamingContext._
-
-import org.apache.spark.{Partitioner, HashPartitioner}
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.RDD
-
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.hadoop.mapred.OutputFormat
 import org.apache.hadoop.conf.Configuration
-import org.apache.spark.streaming.{Time, Duration}
+import org.apache.hadoop.mapred.{JobConf, OutputFormat}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+
+import org.apache.spark.{HashPartitioner, Partitioner, SerializableWritable}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{Duration, Time}
+import org.apache.spark.streaming.StreamingContext._
 
 /**
  * Extra functions available on DStream of (key, value) pairs through an implicit conversion.
@@ -702,11 +699,14 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
       keyClass: Class[_],
       valueClass: Class[_],
       outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
-      conf: Configuration = new Configuration
+      conf: Configuration = ssc.sparkContext.hadoopConfiguration
     ) {
+    // Wrap this in SerializableWritable so that ForeachDStream can be serialized for checkpoints
+    val serializableConf = new SerializableWritable(conf)
     val saveFunc = (rdd: RDD[(K, V)], time: Time) => {
       val file = rddToFileName(prefix, suffix, time)
-      rdd.saveAsNewAPIHadoopFile(file, keyClass, valueClass, outputFormatClass, conf)
+      rdd.saveAsNewAPIHadoopFile(
+        file, keyClass, valueClass, outputFormatClass, serializableConf.value)
     }
     self.foreachRDD(saveFunc)
   }
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 77ff1ca780a58..acd793dd76281 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.streaming.StreamingContext._
 import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
 import org.apache.spark.streaming.util.ManualClock
 import org.apache.spark.util.Utils
+import org.apache.hadoop.io.{Text, IntWritable}
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 /**
  * This test suites tests the checkpointing functionality of DStreams -
@@ -205,6 +207,30 @@ class CheckpointSuite extends TestSuiteBase {
     testCheckpointedOperation(input, operation, output, 7)
   }
 
+  test("recovery with saveAsNewAPIHadoopFiles") {
+    val tempDir = Files.createTempDir()
+    try {
+      testCheckpointedOperation(
+        Seq(Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq()),
+        (s: DStream[String]) => {
+          val output = s.map(x => (x, 1)).reduceByKey(_ + _)
+          output.saveAsNewAPIHadoopFiles(
+            tempDir.toURI.toString,
+            "result",
+            classOf[Text],
+            classOf[IntWritable],
+            classOf[TextOutputFormat[Text, IntWritable]])
+          (tempDir.toString, "result")
+          output
+        },
+        Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()),
+        3
+      )
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
   // This tests whether the StateDStream's RDD checkpoints works correctly such
   // that the system can recover from a master failure. This assumes as reliable,
   // replayable input source - TestInputDStream.
@@ -391,7 +417,9 @@ class CheckpointSuite extends TestSuiteBase {
     logInfo("Manual clock after advancing = " + clock.time)
     Thread.sleep(batchDuration.milliseconds)
 
-    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+    val outputStream = ssc.graph.getOutputStreams.filter { dstream =>
+      dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
+    }.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
     outputStream.output.map(_.flatten)
   }
 }
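
Note on the PairDStreamFunctions change above: Hadoop's Configuration implements Writable but not java.io.Serializable, so a ForeachDStream closure that captures the configuration directly cannot be written into a streaming checkpoint; wrapping it in org.apache.spark.SerializableWritable, as the patch does, lets the captured value survive Java serialization. The standalone sketch below is not part of the patch; it only illustrates that wrapping pattern against the Spark 1.x SerializableWritable API referenced in the diff, and the object name SerializableConfDemo and the demo.key property are made up for illustration.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

import org.apache.hadoop.conf.Configuration

import org.apache.spark.SerializableWritable

object SerializableConfDemo {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    conf.set("demo.key", "demo.value")

    // Writing the bare Configuration with writeObject would fail with
    // java.io.NotSerializableException, which is the failure the patch avoids
    // when the DStream graph is checkpointed. The wrapper delegates to
    // Hadoop's Writable machinery instead.
    val wrapped = new SerializableWritable(conf)

    // Round-trip the wrapper through Java serialization, the same mechanism
    // used for DStream checkpoints.
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(wrapped)
    out.close()

    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    val restored = in.readObject().asInstanceOf[SerializableWritable[Configuration]]
    in.close()

    // The Hadoop configuration and its settings survive the round trip.
    println(restored.value.get("demo.key"))  // prints "demo.value"
  }
}

This mirrors what the patched saveAsNewAPIHadoopFiles does: the save function captures serializableConf and reads serializableConf.value at runtime, so serializing the ForeachDStream for a checkpoint no longer trips over the non-serializable Configuration.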