Commit cdc8af7

Address a comment

HyukjinKwon committed May 16, 2024
1 parent f40ffe0 commit cdc8af7
Showing 2 changed files with 18 additions and 2 deletions.
core/src/main/scala/org/apache/spark/SparkContext.scala: 2 additions & 1 deletion
@@ -374,7 +374,6 @@ class SparkContext(config: SparkConf) extends Logging {
   private[spark] def cleaner: Option[ContextCleaner] = _cleaner
 
   private[spark] var checkpointDir: Option[String] = None
-  config.getOption(CHECKPOINT_DIR.key).foreach(setCheckpointDir)
 
   // Thread Local variable that can be used by users to pass information down the stack
   protected[spark] val localProperties = new InheritableThreadLocal[Properties] {
@@ -602,6 +601,8 @@ class SparkContext(config: SparkConf) extends Logging {
         .foreach(logLevel => _schedulerBackend.updateExecutorsLogLevel(logLevel))
     }
 
+    _conf.get(CHECKPOINT_DIR).foreach(setCheckpointDir)
+
     val _executorMetricsSource =
       if (_conf.get(METRICS_EXECUTORMETRICS_SOURCE_ENABLED)) {
        Some(new ExecutorMetricsSource)

core/src/test/scala/org/apache/spark/CheckpointSuite.scala: 16 additions & 1 deletion
@@ -588,7 +588,6 @@ object CheckpointSuite {
 }
 
 class CheckpointStorageSuite extends SparkFunSuite with LocalSparkContext {
-
   test("checkpoint compression") {
     withTempDir { checkpointDir =>
       val conf = new SparkConf()
@@ -669,4 +668,20 @@ class CheckpointStorageSuite extends SparkFunSuite with LocalSparkContext {
       assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]])
     }
   }
+
+  test("SPARK-48268: checkpoint directory via configuration") {
+    withTempDir { checkpointDir =>
+      val conf = new SparkConf()
+        .set("spark.checkpoint.dir", checkpointDir.toString)
+        .set(UI_ENABLED.key, "false")
+      sc = new SparkContext("local", "test", conf)
+      val parCollection = sc.makeRDD(1 to 4)
+      val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+      flatMappedRDD.checkpoint()
+      assert(flatMappedRDD.dependencies.head.rdd === parCollection)
+      val result = flatMappedRDD.collect()
+      assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+      assert(flatMappedRDD.collect() === result)
+    }
+  }
 }
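For reference, a minimal usage sketch (not part of this diff) of the behavior the new test exercises: once spark.checkpoint.dir is set in the SparkConf, SparkContext applies it via setCheckpointDir during initialization, so reliable checkpointing works without an explicit sc.setCheckpointDir(...) call. The object name and the local checkpoint path below are illustrative, not taken from the commit.

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointDirConfigExample {
  def main(args: Array[String]): Unit = {
    // Hypothetical local path; any directory the cluster can reach (e.g. an HDFS path) works.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("checkpoint-dir-config")
      .set("spark.checkpoint.dir", "/tmp/spark-checkpoints")

    val sc = new SparkContext(conf)
    try {
      // No sc.setCheckpointDir(...) call: the directory comes from the configuration.
      val rdd = sc.makeRDD(1 to 4).flatMap(x => 1 to x)
      rdd.checkpoint()
      println(rdd.collect().mkString(", "))
    } finally {
      sc.stop()
    }
  }
}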
