Graph should support the checkpoint operation
witgo committed Dec 6, 2014
1 parent 6eb1b6f commit e682724
Showing 7 changed files with 59 additions and 8 deletions.
core/src/main/scala/org/apache/spark/rdd/RDD.scala (15 changes: 7 additions & 8 deletions)
@@ -1279,21 +1279,20 @@ abstract class RDD[T: ClassTag](
   }

   // Avoid handling doCheckpoint multiple times to prevent excessive recursion
-  @transient private var doCheckpointCalled = false
+  @transient private var doCheckpointCalled = 0

   /**
    * Performs the checkpointing of this RDD by saving this. It is called after a job using this RDD
    * has completed (therefore the RDD has been materialized and potentially stored in memory).
    * doCheckpoint() is called recursively on the parent RDDs.
    */
   private[spark] def doCheckpoint() {
-    if (!doCheckpointCalled) {
-      doCheckpointCalled = true
-      if (checkpointData.isDefined) {
-        checkpointData.get.doCheckpoint()
-      } else {
-        dependencies.foreach(_.rdd.doCheckpoint())
-      }
+    if (checkpointData == None && doCheckpointCalled == 0) {
+      dependencies.foreach(_.rdd.doCheckpoint())
+      doCheckpointCalled = 1
+    } else if (checkpointData.isDefined && doCheckpointCalled < 2) {
+      checkpointData.get.doCheckpoint()
+      doCheckpointCalled = 2
     }
   }

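With this change, doCheckpointCalled is a three-state flag rather than a boolean: 0 means doCheckpoint() has not run yet, 1 means a job completed while no checkpoint had been requested (only the parents were visited), and 2 means this RDD's own checkpoint has been performed. The practical effect is that a checkpoint() requested after an earlier job can still take effect on the next job. A minimal sketch of that interaction (assuming a SparkContext sc with a checkpoint directory already set):

    val rdd = sc.makeRDD(1 to 4).flatMap(x => 1 to x)
    rdd.count()       // first job: doCheckpoint() only recurses to parents (flag -> 1)
    rdd.checkpoint()  // checkpoint requested after the first job has finished
    rdd.count()       // second job: the checkpoint is still performed (flag -> 2)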
core/src/test/scala/org/apache/spark/CheckpointSuite.scala (11 changes: 11 additions & 0 deletions)
@@ -55,6 +55,17 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     assert(flatMappedRDD.collect() === result)
   }

+  test("checkpoint should also work after calling the count method") {
+    val parCollection = sc.makeRDD(1 to 4)
+    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+    flatMappedRDD.count()
+    flatMappedRDD.checkpoint()
+    assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+    val result = flatMappedRDD.collect()
+    assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+    assert(flatMappedRDD.collect() === result)
+  }
+
   test("RDDs with one-to-one dependencies") {
     testRDD(_.map(x => x.toString))
     testRDD(_.flatMap(x => 1 to x))
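Under the old boolean flag, the count() in this test would have set doCheckpointCalled during its job, so the checkpoint() requested afterwards would never run and the second dependency assertion would fail. With the tri-state flag, the next action still performs the checkpoint, replacing the parent dependency with one rooted at the checkpoint data, while collect() returns the same result.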
graphx/src/main/scala/org/apache/spark/graphx/Graph.scala (2 changes: 2 additions & 0 deletions)
@@ -96,6 +96,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializable {
    */
   def cache(): Graph[VD, ED]

+  def checkpoint(): Unit
+
   /**
    * Uncaches only the vertices of this graph, leaving the edges alone. This is useful in iterative
    * algorithms that modify the vertex attributes but reuse the edges. This method can be used to
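Graph.checkpoint(), judging from the implementations below, checkpoints both the vertex and the edge data of the graph. A hedged usage sketch (assuming a live SparkContext sc; the directory path and the graph construction are illustrative only):

    import org.apache.spark.graphx.{Edge, Graph}

    sc.setCheckpointDir("/tmp/spark-checkpoints")  // illustrative path
    val edges = sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1)))
    val graph = Graph.fromEdges(edges, defaultValue = 1.0F)
    graph.checkpoint()       // marks the vertex and edge partition RDDs for checkpointing
    graph.vertices.count()   // the next jobs materialize the checkpoint
    graph.edges.count()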
graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala (4 changes: 4 additions & 0 deletions)
@@ -64,6 +64,10 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
     this
   }

+  override def checkpoint(): Unit = {
+    partitionsRDD.checkpoint()
+  }
+
   /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */
   override def cache(): this.type = {
     partitionsRDD.persist(targetStorageLevel)
graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala (5 changes: 5 additions & 0 deletions)
@@ -65,6 +65,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
     this
   }

+  override def checkpoint(): Unit = {
+    vertices.checkpoint()
+    replicatedVertexView.edges.checkpoint()
+  }
+
   override def unpersistVertices(blocking: Boolean = true): Graph[VD, ED] = {
     vertices.unpersist(blocking)
     // TODO: unpersist the replicated vertices in `replicatedVertexView` but leave the edges alone
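Design note: GraphImpl checkpoints both halves of the graph, the vertex RDD and the edge RDD held by the replicated vertex view, while EdgeRDDImpl (above) and VertexRDDImpl (below) each delegate to their underlying partitionsRDD. That is where the partition data and its lineage actually live; checkpointing only the thin wrapper RDDs would not truncate that lineage.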
graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala (4 changes: 4 additions & 0 deletions)
@@ -65,6 +65,10 @@ class VertexRDDImpl[VD] private[graphx] (
     this
   }

+  override def checkpoint(): Unit = {
+    partitionsRDD.checkpoint()
+  }
+
   /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */
   override def cache(): this.type = {
     partitionsRDD.persist(targetStorageLevel)
graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala (26 changes: 26 additions & 0 deletions)
@@ -19,6 +19,8 @@ package org.apache.spark.graphx

 import org.scalatest.FunSuite

+import com.google.common.io.Files
+
 import org.apache.spark.SparkContext
 import org.apache.spark.graphx.Graph._
 import org.apache.spark.graphx.PartitionStrategy._
@@ -365,4 +367,28 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     }
   }

+  test("checkpoint") {
+    val checkpointDir = Files.createTempDir()
+    checkpointDir.deleteOnExit()
+    withSpark { sc =>
+      sc.setCheckpointDir(checkpointDir.getAbsolutePath)
+      val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1) }
+      val rdd = sc.parallelize(ring)
+      val graph = Graph.fromEdges(rdd, 1.0F)
+      graph.checkpoint()
+      val edgesDependencies = graph.edges.partitionsRDD.dependencies
+      val verticesDependencies = graph.vertices.partitionsRDD.dependencies
+      val edges = graph.edges.collect().map(_.attr)
+      val vertices = graph.vertices.collect().map(_._2)
+
+      graph.vertices.count()
+      graph.edges.count()
+
+      assert(graph.edges.partitionsRDD.dependencies != edgesDependencies)
+      assert(graph.vertices.partitionsRDD.dependencies != verticesDependencies)
+      assert(graph.vertices.collect().map(_._2) === vertices)
+      assert(graph.edges.collect().map(_.attr) === edges)
+    }
+  }
+
 }
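The assertions work because checkpointing rewrites lineage: after the two counts materialize the checkpoint, the dependencies of the underlying partitionsRDDs are replaced with ones rooted at the checkpoint files, while the collected vertex and edge data must come back unchanged.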
