[SPARK-45680][CONNECT] Release session
### What changes were proposed in this pull request?

Introduce a new `ReleaseSession` Spark Connect RPC, which cancels everything running in the session and removes the session on the server side. The code that manages the cache of sessions is refactored into `SparkConnectSessionManager`.

### Why are the changes needed?

Better session management.

### Does this PR introduce _any_ user-facing change?

Not really. The `SparkSession.stop()` API already existed on the client side, but it only closed the client's network connection; the session stayed cached on the server side for one hour.

Caveats, none of which were really supported user behaviour:
* After `session.stop()`, the user could create a new session with the same session_id in the configuration. That session would be a new session on the client side, but it would connect to the old cached session on the server, and could therefore access the old session's state, e.g. views or artifacts.
* If a session timed out and was removed on the server, a new request used to re-create the session. The client would then see this as the old session, while the server would see a new one without access to the removed session state.
* The user is no longer allowed to create a new session with the same session_id as before (see the sketch below).
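
A minimal sketch of the new behaviour from the client's point of view (assuming a Connect server at `localhost:15002` and that the client accepts a `session_id` parameter in the connection string; the UUID is only an example):

```scala
import org.apache.spark.sql.SparkSession

// Any fixed session_id illustrates the point; normally it is generated.
val sessionId = "0f2b3c4d-5e6f-4a1b-8c9d-0a1b2c3d4e5f"
val url = s"sc://localhost:15002/;session_id=$sessionId"

val spark = SparkSession.builder().remote(url).getOrCreate()
spark.range(10).count() // the session is created on the server

// stop() now also sends ReleaseSession, removing the server-side session
// instead of merely closing the client's network connection.
spark.stop()

// Previously, a new client session with the same session_id would silently
// attach to the old server-side state (views, artifacts, ...). Now the
// server rejects it: the first RPC fails with the SESSION_CLOSED error.
val spark2 = SparkSession.builder().remote(url).create()
// spark2.range(10).count() would throw
```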

### How was this patch tested?

Tests added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43546 from juliuszsompolski/release-session.

Lead-authored-by: Juliusz Sompolski <julek@databricks.com>
Co-authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
juliuszsompolski and HyukjinKwon committed Nov 2, 2023
1 parent b14c1f0 commit 59e291d
Showing 26 changed files with 819 additions and 168 deletions.
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -1737,6 +1737,11 @@
"Session already exists."
]
},
"SESSION_CLOSED" : {
"message" : [
"Session was closed."
]
},
"SESSION_NOT_FOUND" : {
"message" : [
"Session not found."
@@ -665,6 +665,9 @@ class SparkSession private[sql] (
* @since 3.4.0
*/
override def close(): Unit = {
if (releaseSessionOnClose) {
client.releaseSession()
}
client.shutdown()
allocator.close()
SparkSession.onSessionClose(this)
@@ -735,6 +738,11 @@
* We null out the instance for now.
*/
private def writeReplace(): Any = null

/**
* Set to false to prevent client.releaseSession on close() (testing only)
*/
private[sql] var releaseSessionOnClose = true
}

// The minimal builder needed to create a spark session.
@@ -120,7 +120,9 @@ class PlanGenerationTestSuite
}

override protected def afterAll(): Unit = {
session.close()
// Don't call client.releaseSession on close(), because the connection details are dummy.
session.releaseSessionOnClose = false
session.stop()
if (cleanOrphanedGoldenFiles) {
cleanOrphanedGoldenFile()
}
@@ -33,18 +33,24 @@ class SparkSessionSuite extends ConnectFunSuite {
private val connectionString2: String = "sc://test.me:14099"
private val connectionString3: String = "sc://doit:16845"

private def closeSession(session: SparkSession): Unit = {
// Don't call client.releaseSession on close(), because the connection details are dummy.
session.releaseSessionOnClose = false
session.close()
}

test("default") {
val session = SparkSession.builder().getOrCreate()
assert(session.client.configuration.host == "localhost")
assert(session.client.configuration.port == 15002)
session.close()
closeSession(session)
}

test("remote") {
val session = SparkSession.builder().remote(connectionString2).getOrCreate()
assert(session.client.configuration.host == "test.me")
assert(session.client.configuration.port == 14099)
session.close()
closeSession(session)
}

test("getOrCreate") {
@@ -53,8 +59,8 @@
try {
assert(session1 eq session2)
} finally {
session1.close()
session2.close()
closeSession(session1)
closeSession(session2)
}
}

@@ -65,8 +71,8 @@
assert(session1 ne session2)
assert(session1.client.configuration == session2.client.configuration)
} finally {
session1.close()
session2.close()
closeSession(session1)
closeSession(session2)
}
}

@@ -77,8 +83,8 @@
assert(session1 ne session2)
assert(session1.client.configuration == session2.client.configuration)
} finally {
session1.close()
session2.close()
closeSession(session1)
closeSession(session2)
}
}

@@ -98,7 +104,7 @@
assertThrows[RuntimeException] {
session.range(10).count()
}
session.close()
closeSession(session)
}

test("Default/Active session") {
@@ -136,12 +142,12 @@
assert(SparkSession.getActiveSession.contains(session1))

// Close session1
session1.close()
closeSession(session1)
assert(SparkSession.getDefaultSession.contains(session2))
assert(SparkSession.getActiveSession.isEmpty)

// Close session2
session2.close()
closeSession(session2)
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.isEmpty)
}
@@ -187,7 +193,7 @@

// Step 3 - close session 1, no more default session in both scripts
phaser.arriveAndAwaitAdvance()
session1.close()
closeSession(session1)

// Step 4 - no default session, same active session.
phaser.arriveAndAwaitAdvance()
@@ -240,27 +246,27 @@

// Step 7 - close active session in script2
phaser.arriveAndAwaitAdvance()
internalSession.close()
closeSession(internalSession)
assert(SparkSession.getActiveSession.isEmpty)
}
assert(script1.get())
assert(script2.get())
assert(SparkSession.getActiveSession.contains(session2))
session2.close()
closeSession(session2)
assert(SparkSession.getActiveSession.isEmpty)
} finally {
executor.shutdown()
}
}

test("deprecated methods") {
SparkSession
val session = SparkSession
.builder()
.master("yayay")
.appName("bob")
.enableHiveSupport()
.create()
.close()
closeSession(session)
}

test("serialize as null") {
@@ -784,6 +784,30 @@ message ReleaseExecuteResponse {
optional string operation_id = 2;
}

message ReleaseSessionRequest {
// (Required)
//
// The session_id of the session to release.
// This must be the id of an existing session.
string session_id = 1;

// (Required) User context
//
// user_context.user_id and session_id both identify a unique remote spark session on the
// server side.
UserContext user_context = 2;

// Provides optional information about the client sending the request. This field
// can be used for language or version specific information and is only intended for
// logging purposes and will not be interpreted by the server.
optional string client_type = 3;
}

message ReleaseSessionResponse {
// Session id of the session on which the release executed.
string session_id = 1;
}

message FetchErrorDetailsRequest {

// (Required)
@@ -934,6 +958,12 @@ service SparkConnectService {
// RPC and ReleaseExecute may not be used.
rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {}

// Release a session.
// All the executions in the session will be released. Any further requests for the session with
// that session_id for the given user_id will fail. If the session didn't exist or was already
// released, this is a noop.
rpc ReleaseSession(ReleaseSessionRequest) returns (ReleaseSessionResponse) {}

// FetchErrorDetails retrieves the matched exception with details based on a provided error id.
rpc FetchErrorDetails(FetchErrorDetailsRequest) returns (FetchErrorDetailsResponse) {}
}
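For reference, a minimal sketch of calling the new RPC through the generated blocking stub (the class and builder names follow the standard protobuf/gRPC codegen for this file; the host, user id and UUID are illustrative):

```scala
import io.grpc.ManagedChannelBuilder
import org.apache.spark.connect.proto

val channel = ManagedChannelBuilder
  .forAddress("localhost", 15002) // example server address
  .usePlaintext()
  .build()
val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)

val request = proto.ReleaseSessionRequest
  .newBuilder()
  .setSessionId("0f2b3c4d-5e6f-4a1b-8c9d-0a1b2c3d4e5f") // example UUID
  .setUserContext(proto.UserContext.newBuilder().setUserId("test_user"))
  .build()

// Per the contract above, releasing an unknown or already-released session
// is a no-op, so the call is safe to repeat.
val response = stub.releaseSession(request)
println(s"released session ${response.getSessionId}")
```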
@@ -96,6 +96,17 @@ private[connect] class CustomSparkConnectBlockingStub(
}
}

def releaseSession(request: ReleaseSessionRequest): ReleaseSessionResponse = {
grpcExceptionConverter.convert(
request.getSessionId,
request.getUserContext,
request.getClientType) {
retryHandler.retry {
stub.releaseSession(request)
}
}
}

def artifactStatus(request: ArtifactStatusesRequest): ArtifactStatusesResponse = {
grpcExceptionConverter.convert(
request.getSessionId,
@@ -243,6 +243,16 @@ private[sql] class SparkConnectClient(
bstub.interrupt(request)
}

private[sql] def releaseSession(): proto.ReleaseSessionResponse = {
val builder = proto.ReleaseSessionRequest.newBuilder()
val request = builder
.setUserContext(userContext)
.setSessionId(sessionId)
.setClientType(userAgent)
.build()
bstub.releaseSession(request)
}

private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] {
override def childValue(parent: mutable.Set[String]): mutable.Set[String] = {
// Note: make a clone such that changes in the parent tags aren't reflected in
@@ -74,6 +74,24 @@ object Connect {
.intConf
.createWithDefault(1024)

val CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT =
buildStaticConf("spark.connect.session.manager.defaultSessionTimeout")
.internal()
.doc("Timeout after which sessions without any new incoming RPC will be removed.")
.version("4.0.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("60m")

val CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE =
buildStaticConf("spark.connect.session.manager.closedSessionsTombstonesSize")
.internal()
.doc(
"Maximum number of closed sessions to keep as tombstones, after which the oldest " +
"entries will be removed.")
.version("4.0.0")
.intConf
.createWithDefaultString("1000")

val CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT =
buildStaticConf("spark.connect.execute.manager.detachedTimeout")
.internal()
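
Both new options are internal static configurations, so they take effect only when set at Connect server startup. A hypothetical tuning (values are examples; the defaults are `60m` and `1000` as above) could look like:

```
spark.connect.session.manager.defaultSessionTimeout=30m
spark.connect.session.manager.closedSessionsTombstonesSize=2000
```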