Commit

resolve
itholic committed Aug 8, 2023
2 parents 607c041 + d2b60ff commit 55a2ec5
Showing 30 changed files with 582 additions and 189 deletions.
17 changes: 17 additions & 0 deletions .github/workflows/build_and_test.yml
@@ -657,7 +657,22 @@ jobs:
- name: Spark connect jvm client mima check
if: inputs.branch != 'branch-3.3'
run: ./dev/connect-jvm-client-mima-check
- name: Install Python linter dependencies for branch-3.3
if: inputs.branch == 'branch-3.3'
run: |
# SPARK-44554: Copy from https://github.com/apache/spark/blob/073d0b60d31bf68ebacdc005f59b928a5902670f/.github/workflows/build_and_test.yml#L501-L508
# Remove this section after Spark 3.3 reaches EOL.
python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==21.12b0'
python3.9 -m pip install 'pandas-stubs==1.2.0.53'
- name: Install Python linter dependencies for branch-3.4
if: inputs.branch == 'branch-3.4'
run: |
# SPARK-44554: Copy from https://github.com/apache/spark/blob/a05c27e85829fe742c1828507a1fd180cdc84b54/.github/workflows/build_and_test.yml#L571-L578
# Remove this section after Spark 3.4 reaches EOL.
python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0'
python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.48.1' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0'
- name: Install Python linter dependencies
if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4'
run: |
# TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes.
# See also https://github.com/sphinx-doc/sphinx/issues/7551.
@@ -668,6 +683,7 @@
- name: Python linter
run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python
- name: Install dependencies for Python code generation check
if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4'
run: |
# See more in "Installation" https://docs.buf.build/installation#tarball
curl -LO https://github.com/bufbuild/buf/releases/download/v1.24.0/buf-Linux-x86_64.tar.gz
@@ -676,6 +692,7 @@
rm buf-Linux-x86_64.tar.gz
python3.9 -m pip install 'protobuf==3.20.3' 'mypy-protobuf==3.3.0'
- name: Python code generation check
if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4'
run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi
- name: Install JavaScript linter dependencies
run: |
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import java.io.Closeable
import java.net.URI
import java.util.concurrent.TimeUnit._
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}

import scala.collection.JavaConverters._
import scala.reflect.runtime.universe.TypeTag
@@ -730,6 +730,23 @@ object SparkSession extends Logging {
override def load(c: Configuration): SparkSession = create(c)
})

/** The active SparkSession for the current thread. */
private val activeThreadSession = new InheritableThreadLocal[SparkSession]

/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]

/**
* Set the (global) default [[SparkSession]] and the (thread-local) active [[SparkSession]] if
* they are not already set.
*/
private def setDefaultAndActiveSession(session: SparkSession): Unit = {
defaultSession.compareAndSet(null, session)
if (getActiveSession.isEmpty) {
setActiveSession(session)
}
}

/**
* Create a new [[SparkSession]] based on the connect client [[Configuration]].
*/
@@ -742,8 +759,17 @@
*/
private[sql] def onSessionClose(session: SparkSession): Unit = {
sessions.invalidate(session.client.configuration)
defaultSession.compareAndSet(session, null)
if (getActiveSession.contains(session)) {
clearActiveSession()
}
}

/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
*
* @since 3.4.0
*/
def builder(): Builder = new Builder()

private[sql] lazy val cleaner = {
@@ -799,10 +825,15 @@
*
* This will always return a newly created session.
*
* This method will update the default and/or active session if they are not set.
*
* @since 3.5.0
*/
def create(): SparkSession = {
tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration))
val session = tryCreateSessionFromClient()
.getOrElse(SparkSession.this.create(builder.configuration))
setDefaultAndActiveSession(session)
session
}

/**
@@ -811,30 +842,79 @@
* If a session exists with the same configuration, that session is returned instead of
* creating a new one.
*
* This method will update the default and/or active session if they are not set.
*
* @since 3.5.0
*/
def getOrCreate(): SparkSession = {
tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration))
val session = tryCreateSessionFromClient()
.getOrElse(sessions.get(builder.configuration))
setDefaultAndActiveSession(session)
session
}
}

def getActiveSession: Option[SparkSession] = {
throw new UnsupportedOperationException("getActiveSession is not supported")
/**
* Returns the default SparkSession.
*
* @since 3.5.0
*/
def getDefaultSession: Option[SparkSession] = Option(defaultSession.get())

/**
* Sets the default SparkSession.
*
* @since 3.5.0
*/
def setDefaultSession(session: SparkSession): Unit = {
defaultSession.set(session)
}

def getDefaultSession: Option[SparkSession] = {
throw new UnsupportedOperationException("getDefaultSession is not supported")
/**
* Clears the default SparkSession.
*
* @since 3.5.0
*/
def clearDefaultSession(): Unit = {
defaultSession.set(null)
}

/**
* Returns the active SparkSession for the current thread.
*
* @since 3.5.0
*/
def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get())

/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* an isolated SparkSession.
*
* @since 3.5.0
*/
def setActiveSession(session: SparkSession): Unit = {
throw new UnsupportedOperationException("setActiveSession is not supported")
activeThreadSession.set(session)
}

/**
* Clears the active SparkSession for the current thread.
*
* @since 3.5.0
*/
def clearActiveSession(): Unit = {
throw new UnsupportedOperationException("clearActiveSession is not supported")
activeThreadSession.remove()
}

/**
* Returns the currently active SparkSession, otherwise the default one. If there is no default
* SparkSession, throws an exception.
*
* @since 3.5.0
*/
def active: SparkSession = {
throw new UnsupportedOperationException("active is not supported")
getActiveSession
.orElse(getDefaultSession)
.getOrElse(throw new IllegalStateException("No active or default Spark session found"))
}
}
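
Note: the companion-object API above mirrors the classic SparkSession semantics: the default session is a process-wide AtomicReference, while the active session lives in an InheritableThreadLocal and is therefore inherited by child threads. The following is a minimal sketch of the intended resolution order; it is not part of this commit, and the connection string is hypothetical.

import org.apache.spark.sql.SparkSession

object SessionResolutionSketch extends App {
  // getOrCreate() now also seeds the default and active sessions when unset.
  val session = SparkSession.builder().remote("sc://localhost").getOrCreate()

  // SparkSession.active resolves in order: thread-local active session,
  // then the global default, otherwise it throws IllegalStateException.
  assert(SparkSession.active eq session)

  // close() triggers onSessionClose, which removes the session from the
  // default and active slots it occupies.
  session.close()
  assert(SparkSession.getDefaultSession.isEmpty)
  assert(SparkSession.getActiveSession.isEmpty)
}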
@@ -16,6 +16,10 @@
*/
package org.apache.spark.sql

import java.util.concurrent.{Executors, Phaser}

import scala.util.control.NonFatal

import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}

import org.apache.spark.sql.connect.client.util.ConnectFunSuite
@@ -24,6 +28,10 @@ import org.apache.spark.sql.connect.client.util.ConnectFunSuite
* Tests for non-dataframe related SparkSession operations.
*/
class SparkSessionSuite extends ConnectFunSuite {
private val connectionString1: String = "sc://test.it:17845"
private val connectionString2: String = "sc://test.me:14099"
private val connectionString3: String = "sc://doit:16845"

test("default") {
val session = SparkSession.builder().getOrCreate()
assert(session.client.configuration.host == "localhost")
@@ -32,16 +40,15 @@ class SparkSessionSuite extends ConnectFunSuite {
}

test("remote") {
val session = SparkSession.builder().remote("sc://test.me:14099").getOrCreate()
val session = SparkSession.builder().remote(connectionString2).getOrCreate()
assert(session.client.configuration.host == "test.me")
assert(session.client.configuration.port == 14099)
session.close()
}

test("getOrCreate") {
val connectionString = "sc://test.it:17865"
val session1 = SparkSession.builder().remote(connectionString).getOrCreate()
val session2 = SparkSession.builder().remote(connectionString).getOrCreate()
val session1 = SparkSession.builder().remote(connectionString1).getOrCreate()
val session2 = SparkSession.builder().remote(connectionString1).getOrCreate()
try {
assert(session1 eq session2)
} finally {
@@ -51,9 +58,8 @@
}

test("create") {
val connectionString = "sc://test.it:17845"
val session1 = SparkSession.builder().remote(connectionString).create()
val session2 = SparkSession.builder().remote(connectionString).create()
val session1 = SparkSession.builder().remote(connectionString1).create()
val session2 = SparkSession.builder().remote(connectionString1).create()
try {
assert(session1 ne session2)
assert(session1.client.configuration == session2.client.configuration)
@@ -64,8 +70,7 @@
}

test("newSession") {
val connectionString = "sc://doit:16845"
val session1 = SparkSession.builder().remote(connectionString).create()
val session1 = SparkSession.builder().remote(connectionString3).create()
val session2 = session1.newSession()
try {
assert(session1 ne session2)
@@ -92,5 +97,126 @@
assertThrows[RuntimeException] {
session.range(10).count()
}
session.close()
}

test("Default/Active session") {
// Make sure we start with a clean slate.
SparkSession.clearDefaultSession()
SparkSession.clearActiveSession()
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.isEmpty)
intercept[IllegalStateException](SparkSession.active)

// Create a session
val session1 = SparkSession.builder().remote(connectionString1).getOrCreate()
assert(SparkSession.getDefaultSession.contains(session1))
assert(SparkSession.getActiveSession.contains(session1))
assert(SparkSession.active == session1)

// Create another session...
val session2 = SparkSession.builder().remote(connectionString2).create()
assert(SparkSession.getDefaultSession.contains(session1))
assert(SparkSession.getActiveSession.contains(session1))
SparkSession.setActiveSession(session2)
assert(SparkSession.getDefaultSession.contains(session1))
assert(SparkSession.getActiveSession.contains(session2))

// Clear sessions
SparkSession.clearDefaultSession()
assert(SparkSession.getDefaultSession.isEmpty)
SparkSession.clearActiveSession()
assert(SparkSession.getActiveSession.isEmpty)

// Flip sessions
SparkSession.setActiveSession(session1)
SparkSession.setDefaultSession(session2)
assert(SparkSession.getDefaultSession.contains(session2))
assert(SparkSession.getActiveSession.contains(session1))

// Close session1
session1.close()
assert(SparkSession.getDefaultSession.contains(session2))
assert(SparkSession.getActiveSession.isEmpty)

// Close session2
session2.close()
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.isEmpty)
}

test("active session in multiple threads") {
SparkSession.clearDefaultSession()
SparkSession.clearActiveSession()
val session1 = SparkSession.builder().remote(connectionString1).create()
val session2 = SparkSession.builder().remote(connectionString1).create()
SparkSession.setActiveSession(session2)
assert(SparkSession.getDefaultSession.contains(session1))
assert(SparkSession.getActiveSession.contains(session2))

val phaser = new Phaser(2)
val executor = Executors.newFixedThreadPool(2)
def execute(block: Phaser => Unit): java.util.concurrent.Future[Boolean] = {
executor.submit[Boolean] { () =>
try {
block(phaser)
true
} catch {
case NonFatal(e) =>
phaser.forceTermination()
throw e
}
}
}

try {
val script1 = execute { phaser =>
phaser.arriveAndAwaitAdvance()
assert(SparkSession.getDefaultSession.contains(session1))
assert(SparkSession.getActiveSession.contains(session2))

phaser.arriveAndAwaitAdvance()
assert(SparkSession.getDefaultSession.contains(session1))
assert(SparkSession.getActiveSession.contains(session2))
session1.close()

phaser.arriveAndAwaitAdvance()
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.contains(session2))
SparkSession.clearActiveSession()

phaser.arriveAndAwaitAdvance()
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.isEmpty)
}
val script2 = execute { phaser =>
phaser.arriveAndAwaitAdvance()
assert(SparkSession.getDefaultSession.contains(session1))
assert(SparkSession.getActiveSession.contains(session2))
SparkSession.clearActiveSession()
val internalSession = SparkSession.builder().remote(connectionString3).getOrCreate()

phaser.arriveAndAwaitAdvance()
assert(SparkSession.getDefaultSession.contains(session1))
assert(SparkSession.getActiveSession.contains(internalSession))

phaser.arriveAndAwaitAdvance()
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.contains(internalSession))

phaser.arriveAndAwaitAdvance()
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.contains(internalSession))
internalSession.close()
assert(SparkSession.getActiveSession.isEmpty)
}
assert(script1.get())
assert(script2.get())
assert(SparkSession.getActiveSession.contains(session2))
session2.close()
assert(SparkSession.getActiveSession.isEmpty)
} finally {
executor.shutdown()
}
}
}
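
Note: because the active session is stored in an InheritableThreadLocal, a newly started thread begins with its parent's active session and can override or clear it independently, which is exactly what the Phaser-based test above coordinates. A compact sketch of that inheritance behavior, assuming the same builder API (the connection string is made up):

import org.apache.spark.sql.SparkSession

object ActiveSessionInheritanceSketch extends App {
  val parent = SparkSession.builder().remote("sc://example:15002").create()
  SparkSession.setActiveSession(parent)

  val child = new Thread(() => {
    // The child thread starts with the parent's active session...
    assert(SparkSession.getActiveSession.contains(parent))
    // ...and clearing it here affects only this thread.
    SparkSession.clearActiveSession()
  })
  child.start()
  child.join()

  // The parent thread's active session is untouched by the child's clear.
  assert(SparkSession.getActiveSession.contains(parent))
  parent.close()
}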
@@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"),

// SparkSession
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.setDefaultSession"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"),
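
Note: the two removed entries are MiMa ProblemFilters exclusions that were only needed while the connect client did not implement setDefaultSession and clearDefaultSession; now that both methods exist, the binary-compatibility check covers them again. For reference, a minimal sketch of how such an exclusion is declared (assuming mima-core on the classpath, as this checker uses):

import com.typesafe.tools.mima.core.{Problem, ProblemFilter, ProblemFilters}

// Suppresses every MiMa problem reported for the named member, matching the
// style of the remaining exclusions above.
val filter: ProblemFilter =
  ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext")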