diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 55a8f8a921c73..214f22bc5b603 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1210,7 +1210,8 @@ abstract class RDD[T: ClassTag](
   /**
    * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
    * directory set with SparkContext.setCheckpointDir() and all references to its parent
-   * RDDs will be removed. It is strongly recommended that this RDD is persisted in
+   * RDDs will be removed. This function must be called before any job has been
+   * executed on this RDD. It is strongly recommended that this RDD is persisted in
    * memory, otherwise saving it on a file will require recomputation.
    */
   def checkpoint() {
@@ -1278,7 +1279,7 @@ abstract class RDD[T: ClassTag](
   }
 
   // Avoid handling doCheckpoint multiple times to prevent excessive recursion
-  @transient private var doCheckpointCalled = 0
+  @transient private var doCheckpointCalled = false
 
   /**
    * Performs the checkpointing of this RDD by saving this. It is called after a job using this RDD
@@ -1286,12 +1287,13 @@ abstract class RDD[T: ClassTag](
    * doCheckpoint() is called recursively on the parent RDDs.
    */
   private[spark] def doCheckpoint() {
-    if (checkpointData == None && doCheckpointCalled == 0) {
-      dependencies.foreach(_.rdd.doCheckpoint())
-      doCheckpointCalled = 1
-    } else if (checkpointData.isDefined && doCheckpointCalled < 2) {
-      checkpointData.get.doCheckpoint()
-      doCheckpointCalled = 2
+    if (!doCheckpointCalled) {
+      doCheckpointCalled = true
+      if (checkpointData.isDefined) {
+        checkpointData.get.doCheckpoint()
+      } else {
+        dependencies.foreach(_.rdd.doCheckpoint())
+      }
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index b62744eb69396..3b10b3a042317 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -55,17 +55,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     assert(flatMappedRDD.collect() === result)
   }
 
-  test("After call count method, checkpoint should also work") {
-    val parCollection = sc.makeRDD(1 to 4)
-    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
-    flatMappedRDD.count
-    flatMappedRDD.checkpoint()
-    assert(flatMappedRDD.dependencies.head.rdd == parCollection)
-    val result = flatMappedRDD.collect()
-    assert(flatMappedRDD.dependencies.head.rdd != parCollection)
-    assert(flatMappedRDD.collect() === result)
-  }
-
   test("RDDs with one-to-one dependencies") {
     testRDD(_.map(x => x.toString))
     testRDD(_.flatMap(x => 1 to x))
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 3991de30b4131..9da0064104fb6 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -376,18 +376,13 @@ class GraphSuite extends FunSuite with LocalSparkContext {
       val rdd = sc.parallelize(ring)
       val graph = Graph.fromEdges(rdd, 1.0F)
       graph.checkpoint()
+      graph.edges.map(_.attr).count()
+      graph.vertices.map(_._2).count()
+
       val edgesDependencies = graph.edges.partitionsRDD.dependencies
       val verticesDependencies = graph.vertices.partitionsRDD.dependencies
-      val edges = graph.edges.collect().map(_.attr)
-      val vertices = graph.vertices.collect().map(_._2)
-
-      graph.vertices.count()
-      graph.edges.count()
-
-      assert(graph.edges.partitionsRDD.dependencies != edgesDependencies)
-      assert(graph.vertices.partitionsRDD.dependencies != verticesDependencies)
-      assert(graph.vertices.collect().map(_._2) === vertices)
-      assert(graph.edges.collect().map(_.attr) === edges)
+      assert(edgesDependencies.forall(_.rdd.isInstanceOf[CheckpointRDD[_]]))
+      assert(verticesDependencies.forall(_.rdd.isInstanceOf[CheckpointRDD[_]]))
     }
   }
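
For reference, a minimal sketch (not part of the patch) of the usage contract the updated scaladoc describes: checkpoint() is marked before the first job runs on the RDD, and the first action then materializes and checkpoints it exactly once. The local master, app name, and checkpoint directory below are illustrative assumptions.

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointBeforeJobSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("checkpoint-sketch")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/spark-checkpoints")  // assumed scratch directory

    val rdd = sc.parallelize(1 to 4).flatMap(x => 1 to x)
    rdd.cache()        // recommended: avoids recomputing the lineage when the checkpoint is written
    rdd.checkpoint()   // mark for checkpointing before any job has been executed on this RDD

    rdd.count()        // first action runs a job and triggers doCheckpoint() exactly once
    assert(rdd.isCheckpointed)

    sc.stop()
  }
}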