From a393d6cdf00ce95b2a3fb4bd15bfc4d82883d1d2 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 23 May 2024 12:19:07 +0900 Subject: [PATCH] [SPARK-48370][CONNECT] Checkpoint and localCheckpoint in Scala Spark Connect client ### What changes were proposed in this pull request? This PR adds `Dataset.checkpoint` and `Dataset.localCheckpoint` into Scala Spark Connect client. Python API was implemented at https://github.com/apache/spark/pull/46570 ### Why are the changes needed? For API parity. ### Does this PR introduce _any_ user-facing change? Yes, it adds `Dataset.checkpoint` and `Dataset.localCheckpoint` into Scala Spark Connect client. ### How was this patch tested? Unittests added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46683 from HyukjinKwon/SPARK-48370. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/sql/Dataset.scala | 107 +++++++++++-- .../org/apache/spark/sql/SparkSession.scala | 10 +- .../spark/sql/internal/SessionCleaner.scala | 146 ++++++++++++++++++ .../apache/spark/sql/CheckpointSuite.scala | 117 ++++++++++++++ .../CheckConnectJvmClientCompatibility.scala | 10 ++ .../connect/client/SparkConnectClient.scala | 2 +- 6 files changed, 379 insertions(+), 13 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala create mode 100644 connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 37f770319b695..fc9766357cb22 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3402,20 +3402,105 @@ class Dataset[T] private[sql] ( df } - def checkpoint(): Dataset[T] = { - throw new UnsupportedOperationException("checkpoint is not implemented.") - } + /** + * Eagerly checkpoint a Dataset and return the new Dataset. Checkpointing can be used to + * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms + * where the plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. + * + * @group basic + * @since 4.0.0 + */ + def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) - def checkpoint(eager: Boolean): Dataset[T] = { - throw new UnsupportedOperationException("checkpoint is not implemented.") - } + /** + * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the + * logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint directory set + * with `SparkContext#setCheckpointDir`. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * + * @note + * When checkpoint is used with eager = false, the final data that is checkpointed after the + * first action may be different from the data that was used during the job due to + * non-determinism of the underlying operation and retries. If checkpoint is used to achieve + * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is + * only deterministic after the first execution, after the checkpoint was finalized. 
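+ *
+ * A minimal usage sketch (assumes `df` is any Dataset in this session; with `eager = false`
+ * the checkpoint is only finalized by the first action):
+ * {{{
+ *   val cp = df.checkpoint(eager = false) // no Spark job is triggered yet
+ *   cp.count()                            // the first action materializes the checkpoint
+ * }}}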
+ *
+ * @group basic
+ * @since 4.0.0
+ */
+ def checkpoint(eager: Boolean): Dataset[T] =
+ checkpoint(eager = eager, reliableCheckpoint = true)

- def localCheckpoint(): Dataset[T] = {
- throw new UnsupportedOperationException("localCheckpoint is not implemented.")
- }
+ /**
+ * Eagerly locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be used
+ * to truncate the logical plan of this Dataset, which is especially useful in iterative
+ * algorithms where the plan may grow exponentially. Local checkpoints are written to executor
+ * storage and, although potentially faster, they are unreliable and may compromise job
+ * completion.
+ *
+ * @group basic
+ * @since 4.0.0
+ */
+ def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false)
+
+ /**
+ * Locally checkpoints a Dataset and returns the new Dataset. Checkpointing can be used to
+ * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms
+ * where the plan may grow exponentially. Local checkpoints are written to executor storage
+ * and, although potentially faster, they are unreliable and may compromise job completion.
+ *
+ * @param eager
+ * Whether to checkpoint this dataframe immediately
+ *
+ * @note
+ * When checkpoint is used with eager = false, the final data that is checkpointed after the
+ * first action may be different from the data that was used during the job due to
+ * non-determinism of the underlying operation and retries. If checkpoint is used to achieve
+ * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is
+ * only deterministic after the first execution, after the checkpoint was finalized.
+ *
+ * @group basic
+ * @since 4.0.0
+ */
+ def localCheckpoint(eager: Boolean): Dataset[T] =
+ checkpoint(eager = eager, reliableCheckpoint = false)

- def localCheckpoint(eager: Boolean): Dataset[T] = {
- throw new UnsupportedOperationException("localCheckpoint is not implemented.")
+ /**
+ * Returns a checkpointed version of this Dataset.
+ *
+ * @param eager
+ * Whether to checkpoint this dataframe immediately
+ * @param reliableCheckpoint
+ * Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If
+ * false, creates a local checkpoint using the caching subsystem.
+ */
+ private def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = {
+ sparkSession.newDataset(agnosticEncoder) { builder =>
+ val command = sparkSession.newCommand { builder =>
+ builder.getCheckpointCommandBuilder
+ .setLocal(!reliableCheckpoint)
+ .setEager(eager)
+ .setRelation(this.plan.getRoot)
+ }
+ val responseIter = sparkSession.execute(command)
+ try {
+ val response = responseIter
+ .find(_.hasCheckpointCommandResult)
+ .getOrElse(throw new RuntimeException("CheckpointCommandResult must be present"))
+
+ val cachedRemoteRelation = response.getCheckpointCommandResult.getRelation
+ sparkSession.cleaner.registerCachedRemoteRelationForCleanup(cachedRemoteRelation)
+
+ // Update the builder with the values from the result.
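+ // The result carries a CachedRemoteRelation pointing at the checkpointed data on the
+ // server, so the new Dataset's plan is reduced to that single reference and the original
+ // logical plan is truncated on the client side as well.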
+ builder.setCachedRemoteRelation(cachedRemoteRelation) + } finally { + // consume the rest of the iterator + responseIter.foreach(_ => ()) + } + } } /** diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1188fba60a2fe..91ee0f52e8bd0 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, Spar import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} +import org.apache.spark.sql.internal.{CatalogImpl, SessionCleaner, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType @@ -73,6 +73,11 @@ class SparkSession private[sql] ( with Logging { private[this] val allocator = new RootAllocator() + private var shouldStopCleaner = false + private[sql] lazy val cleaner = { + shouldStopCleaner = true + new SessionCleaner(this) + } // a unique session ID for this session from client. private[sql] def sessionId: String = client.sessionId @@ -714,6 +719,9 @@ class SparkSession private[sql] ( if (releaseSessionOnClose) { client.releaseSession() } + if (shouldStopCleaner) { + cleaner.stop() + } client.shutdown() allocator.close() SparkSession.onSessionClose(this) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala new file mode 100644 index 0000000000000..036ea4a84fa97 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/SessionCleaner.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import java.lang.ref.{ReferenceQueue, WeakReference} +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession + +/** + * Classes that represent cleaning tasks. + */ +private sealed trait CleanupTask +private case class CleanupCachedRemoteRelation(dfID: String) extends CleanupTask + +/** + * A WeakReference associated with a CleanupTask. 
+ *
+ * When the referent object becomes only weakly reachable, the corresponding
+ * CleanupTaskWeakReference is automatically added to the given reference queue.
+ */
+private class CleanupTaskWeakReference(
+ val task: CleanupTask,
+ referent: AnyRef,
+ referenceQueue: ReferenceQueue[AnyRef])
+ extends WeakReference(referent, referenceQueue)
+
+/**
+ * An asynchronous cleaner for objects.
+ *
+ * This maintains a weak reference for each CachedRemoteRelation, etc. of interest, to be
+ * processed when the associated object goes out of scope of the application. Actual cleanup is
+ * performed in a separate daemon thread.
+ */
+private[sql] class SessionCleaner(session: SparkSession) extends Logging {
+
+ /**
+ * How long (in milliseconds) to block on the reference queue before re-checking whether the
+ * cleaner has been stopped. Cleanups are triggered only when weak references are garbage
+ * collected. In long-running applications with large driver JVMs, where there is little memory
+ * pressure on the driver, this may happen very occasionally or not at all. Not cleaning at all
+ * may lead to executors running out of disk space after a while.
+ */
+ private val refQueuePollTimeout: Long = 100
+
+ /**
+ * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they
+ * have not been handled by the reference queue.
+ */
+ private val referenceBuffer =
+ Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap)
+
+ private val referenceQueue = new ReferenceQueue[AnyRef]
+
+ private val cleaningThread = new Thread() { override def run(): Unit = keepCleaning() }
+
+ @volatile private var started = false
+ @volatile private var stopped = false
+
+ /** Start the cleaner. */
+ def start(): Unit = {
+ cleaningThread.setDaemon(true)
+ cleaningThread.setName("Spark Connect Context Cleaner")
+ cleaningThread.start()
+ }
+
+ /**
+ * Stop the cleaning thread and wait until the thread has finished running its current task.
+ */
+ def stop(): Unit = {
+ stopped = true
+ // Interrupt the cleaning thread, but wait until the current task has finished before
+ // doing so. This guards against the race condition where a cleaning thread may
+ // potentially clean similarly named variables created by a different SparkSession.
+ synchronized {
+ cleaningThread.interrupt()
+ }
+ cleaningThread.join()
+ }
+
+ /** Register a CachedRemoteRelation for cleanup when it is garbage collected. */
+ def registerCachedRemoteRelationForCleanup(relation: proto.CachedRemoteRelation): Unit = {
+ registerForCleanup(relation, CleanupCachedRemoteRelation(relation.getRelationId))
+ }
+
+ /** Register an object for cleanup. */
+ private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask): Unit = {
+ if (!started) {
+ // Lazily starts when the first cleanup is registered.
+ start()
+ started = true
+ }
+ referenceBuffer.add(new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue))
+ }
+
+ /**
+ * Keep cleaning objects.
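+ * Polls the reference queue and, for each garbage-collected reference, runs the matching
+ * cleanup task, until the cleaner is stopped or the client channel shuts down.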
*/ + private def keepCleaning(): Unit = { + while (!stopped && !session.client.channel.isShutdown) { + try { + val reference = Option(referenceQueue.remove(refQueuePollTimeout)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) + // Synchronize here to avoid being interrupted on stop() + synchronized { + reference.foreach { ref => + logDebug("Got cleaning task " + ref.task) + referenceBuffer.remove(ref) + ref.task match { + case CleanupCachedRemoteRelation(dfID) => + doCleanupCachedRemoteRelation(dfID) + } + } + } + } catch { + case e: Throwable => logError("Error in cleaning thread", e) + } + } + } + + /** Perform CleanupCachedRemoteRelation cleanup. */ + private[spark] def doCleanupCachedRemoteRelation(dfID: String): Unit = { + session.execute { + session.newCommand { builder => + builder.getRemoveCachedRemoteRelationCommandBuilder + .setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfID).build()) + } + } + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala new file mode 100644 index 0000000000000..e57b051890f56 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +import java.io.{ByteArrayOutputStream, PrintStream} + +import scala.concurrent.duration.DurationInt + +import org.apache.commons.io.output.TeeOutputStream +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} +import org.scalatest.exceptions.TestFailedDueToTimeoutException + +import org.apache.spark.SparkException +import org.apache.spark.connect.proto +import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} + +class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper { + + private def captureStdOut(block: => Unit): String = { + val currentOut = Console.out + val capturedOut = new ByteArrayOutputStream() + val newOut = new PrintStream(new TeeOutputStream(currentOut, capturedOut)) + Console.withOut(newOut) { + block + } + capturedOut.toString + } + + private def checkFragments(result: String, fragmentsToCheck: Seq[String]): Unit = { + fragmentsToCheck.foreach { fragment => + assert(result.contains(fragment)) + } + } + + private def testCapturedStdOut(block: => Unit, fragmentsToCheck: String*): Unit = { + checkFragments(captureStdOut(block), fragmentsToCheck) + } + + test("checkpoint") { + val df = spark.range(100).localCheckpoint() + testCapturedStdOut(df.explain(), "ExistingRDD") + } + + test("checkpoint gc") { + val df = spark.range(100).localCheckpoint(eager = true) + val encoder = df.agnosticEncoder + val dfId = df.plan.getRoot.getCachedRemoteRelation.getRelationId + spark.cleaner.doCleanupCachedRemoteRelation(dfId) + + val ex = intercept[SparkException] { + spark + .newDataset(encoder) { builder => + builder.setCachedRemoteRelation( + proto.CachedRemoteRelation + .newBuilder() + .setRelationId(dfId) + .build()) + } + .collect() + } + assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) + } + + // This test is flaky because cannot guarantee GC + // You can locally run this to verify the behavior. + ignore("checkpoint gc derived DataFrame") { + var df1 = spark.range(100).localCheckpoint(eager = true) + var derived = df1.repartition(10) + val encoder = df1.agnosticEncoder + val dfId = df1.plan.getRoot.getCachedRemoteRelation.getRelationId + + df1 = null + System.gc() + Thread.sleep(3000L) + + def condition(): Unit = { + val ex = intercept[SparkException] { + spark + .newDataset(encoder) { builder => + builder.setCachedRemoteRelation( + proto.CachedRemoteRelation + .newBuilder() + .setRelationId(dfId) + .build()) + } + .collect() + } + assert(ex.getMessage.contains(s"No DataFrame with id $dfId is found")) + } + + intercept[TestFailedDueToTimeoutException] { + eventually(timeout(5.seconds), interval(1.second))(condition()) + } + + // GC triggers remove the cached remote relation + derived = null + System.gc() + Thread.sleep(3000L) + + // Check the state was removed up on garbage-collection. 
+ eventually(timeout(60.seconds), interval(1.second))(condition()) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 374d8464deebf..2e4bbab8d3a41 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -334,6 +334,16 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[ReversedMissingMethodProblem]( "org.apache.spark.sql.SQLImplicits._sqlContext" // protected ), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.internal.SessionCleaner"), + + // private + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.internal.CleanupTask"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupTaskWeakReference"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupCachedRemoteRelation"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.internal.CleanupCachedRemoteRelation$"), // Catalyst Refactoring ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.util.SparkCollectionUtils"), diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 1e7b4e6574ddb..b5eda024bfb3c 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.connect.common.config.ConnectCommon */ private[sql] class SparkConnectClient( private[sql] val configuration: SparkConnectClient.Configuration, - private val channel: ManagedChannel) { + private[sql] val channel: ManagedChannel) { private val userContext: UserContext = configuration.userContext