[SPARK-48775][SQL][STS] Replace SQLContext with SparkSession in STS
### What changes were proposed in this pull request?

Remove the exposed `SQLContext` that was added in SPARK-46575, and migrate the STS-internal usage of `SQLContext` to `SparkSession`.

### Why are the changes needed?

`SQLContext` has been discouraged since Spark 2.0; the suggested replacement is `SparkSession`. We should avoid exposing the deprecated class in new Developer APIs.
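
For illustration only (not part of this patch), a minimal sketch of how an embedding application could start the Thrift server through the new API; the session builder settings and application name below are assumptions:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2

// Build or reuse a Hive-enabled SparkSession; the app name and settings are illustrative.
val spark = SparkSession.builder()
  .appName("embedded-sts")
  .enableHiveSupport()
  .getOrCreate()

// 4.0.0 Developer API: pass the SparkSession directly instead of a SQLContext.
val server = HiveThriftServer2.startWithSparkSession(spark, exitOnError = false)
```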

### Does this PR introduce _any_ user-facing change?

No. It touches the Developer API added in SPARK-46575, which has not been released yet.
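
The deprecated Spark 2.0 entry point is kept and now forwards to `startWithSparkSession`, so existing callers such as the hypothetical one sketched below keep compiling (the `spark` session is assumed to exist already):

```scala
// Pre-existing 2.0.0-style caller: still works, but emits a deprecation warning in 4.0.0
// and internally delegates to startWithSparkSession(sqlContext.sparkSession, exitOnError = true).
val server = HiveThriftServer2.startWithContext(spark.sqlContext)
```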

### How was this patch tested?

Pass GHA and `dev/mima` (no breaking changes involved).

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47176 from pan3793/SPARK-48775.

Lead-authored-by: Cheng Pan <chengpan@apache.org>
Co-authored-by: Cheng Pan <pan3793@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
2 people authored and dongjoon-hyun committed Jul 11, 2024
1 parent 297a9d2 commit 4501285
Showing 18 changed files with 150 additions and 148 deletions.
@@ -30,7 +30,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.thriftserver.ui._
@@ -49,30 +49,32 @@ object HiveThriftServer2 extends Logging {

/**
* :: DeveloperApi ::
* Starts a new thrift server with the given context.
* Starts a new thrift server with the given SparkSession.
*
* @param sqlContext SQLContext to use for the server
* @param sparkSession SparkSession to use for the server
* @param exitOnError Whether to exit the JVM if HiveThriftServer2 fails to initialize. When true,
* the call logs the error and exits the JVM with exit code -1. When false, the
* call throws an exception instead.
*/
@Since("4.0.0")
@DeveloperApi
def startWithContext(sqlContext: SQLContext, exitOnError: Boolean): HiveThriftServer2 = {
def startWithSparkSession(
sparkSession: SparkSession,
exitOnError: Boolean): HiveThriftServer2 = {
systemExitOnError.set(exitOnError)

val executionHive = HiveUtils.newClientForExecution(
sqlContext.sparkContext.conf,
sqlContext.sessionState.newHadoopConf())
sparkSession.sparkContext.conf,
sparkSession.sessionState.newHadoopConf())

// Cleanup the scratch dir before starting
ServerUtils.cleanUpScratchDir(executionHive.conf)
val server = new HiveThriftServer2(sqlContext)
val server = new HiveThriftServer2(sparkSession)

server.init(executionHive.conf)
server.start()
logInfo("HiveThriftServer2 started")
createListenerAndUI(server, sqlContext.sparkContext)
createListenerAndUI(server, sparkSession.sparkContext)
server
}

@@ -82,10 +84,11 @@ object HiveThriftServer2 extends Logging {
*
* @param sqlContext SQLContext to use for the server
*/
@deprecated("Use startWithSparkSession instead", since = "4.0.0")
@Since("2.0.0")
@DeveloperApi
def startWithContext(sqlContext: SQLContext): HiveThriftServer2 = {
startWithContext(sqlContext, exitOnError = true)
startWithSparkSession(sqlContext.sparkSession, exitOnError = true)
}

private def createListenerAndUI(server: HiveThriftServer2, sc: SparkContext): Unit = {
@@ -122,7 +125,7 @@ object HiveThriftServer2 extends Logging {
}

try {
startWithContext(SparkSQLEnv.sqlContext)
startWithContext(SparkSQLEnv.sparkSession.sqlContext)
// If application was killed before HiveThriftServer2 start successfully then SparkSubmit
// process can not exit, so check whether if SparkContext was stopped.
if (SparkSQLEnv.sparkContext.stopped.get()) {
@@ -142,15 +145,15 @@ object HiveThriftServer2 extends Logging {
}
}

private[hive] class HiveThriftServer2(sqlContext: SQLContext)
private[hive] class HiveThriftServer2(sparkSession: SparkSession)
extends HiveServer2
with ReflectedCompositeService {
// state is tracked internally so that the server only attempts to shut down if it successfully
// started, and then once only.
private val started = new AtomicBoolean(false)

override def init(hiveConf: HiveConf): Unit = {
val sparkSqlCliService = new SparkSQLCLIService(this, sqlContext)
val sparkSqlCliService = new SparkSQLCLIService(this, sparkSession)
setSuperField(this, "cliService", sparkSqlCliService)
addService(sparkSqlCliService)

@@ -32,15 +32,15 @@ import org.apache.hive.service.rpc.thrift.{TCLIServiceConstants, TColumnDesc, TP

import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND
import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution}
import org.apache.spark.sql.types._
import org.apache.spark.util.{Utils => SparkUtils}

private[hive] class SparkExecuteStatementOperation(
val sqlContext: SQLContext,
val session: SparkSession,
parentSession: HiveSession,
statement: String,
confOverlay: JMap[String, String],
@@ -50,8 +50,6 @@ private[hive] class SparkExecuteStatementOperation(
with SparkOperation
with Logging {

val session = sqlContext.sparkSession

// If a timeout value `queryTimeout` is specified by users and it is smaller than
// a global timeout value, we use the user-specified value.
// This code follows the Hive timeout behaviour (See #29933 for details).
@@ -90,10 +88,10 @@ private[hive] class SparkExecuteStatementOperation(

def getNextRowSet(order: FetchOrientation, maxRowsL: Long): TRowSet = withLocalProperties {
try {
sqlContext.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
session.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
getNextRowSetInternal(order, maxRowsL)
} finally {
sqlContext.sparkContext.clearJobGroup()
session.sparkContext.clearJobGroup()
}
}

@@ -224,20 +222,20 @@ private[hive] class SparkExecuteStatementOperation(
}
}
// Always use the latest class loader provided by executionHive's state.
val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
val executionHiveClassLoader = session.sharedState.jarClassLoader
Thread.currentThread().setContextClassLoader(executionHiveClassLoader)

// Always set the session state classloader to `executionHiveClassLoader` even for sync mode
if (!runInBackground) {
parentSession.getSessionState.getConf.setClassLoader(executionHiveClassLoader)
}

sqlContext.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
result = sqlContext.sql(statement)
session.sparkContext.setJobGroup(statementId, redactedStatement, forceCancel)
result = session.sql(statement)
logDebug(result.queryExecution.toString())
HiveThriftServer2.eventManager.onStatementParsed(statementId,
result.queryExecution.toString())
iter = if (sqlContext.getConf(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
iter = if (session.conf.get(SQLConf.THRIFTSERVER_INCREMENTAL_COLLECT.key).toBoolean) {
new IterableFetchIterator[Row](new Iterable[Row] {
override def iterator: Iterator[Row] = result.toLocalIterator().asScala
})
@@ -254,7 +252,7 @@ private[hive] class SparkExecuteStatementOperation(
// task interrupted, it may have start some spark job, so we need to cancel again to
// make sure job was cancelled when background thread was interrupted
if (statementId != null) {
sqlContext.sparkContext.cancelJobGroup(statementId)
session.sparkContext.cancelJobGroup(statementId)
}
val currentState = getStatus().getState()
if (currentState.isTerminal) {
@@ -271,7 +269,7 @@ private[hive] class SparkExecuteStatementOperation(
e match {
case _: HiveSQLException => throw e
case _ => throw HiveThriftServerErrors.runningQueryError(
e, sqlContext.sparkSession.sessionState.conf.errorMessageFormat)
e, session.sessionState.conf.errorMessageFormat)
}
}
} finally {
Expand All @@ -281,7 +279,7 @@ private[hive] class SparkExecuteStatementOperation(
HiveThriftServer2.eventManager.onStatementFinish(statementId)
}
}
sqlContext.sparkContext.clearJobGroup()
session.sparkContext.clearJobGroup()
}
}

@@ -318,7 +316,7 @@ private[hive] class SparkExecuteStatementOperation(
}
// RDDs will be cleaned automatically upon garbage collection.
if (statementId != null) {
sqlContext.sparkContext.cancelJobGroup(statementId)
session.sparkContext.cancelJobGroup(statementId)
}
}
}
@@ -24,16 +24,16 @@ import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SparkSession

/**
* Spark's own GetCatalogsOperation
*
* @param sqlContext SQLContext to use
* @param session SparkSession to use
* @param parentSession a HiveSession from SessionManager
*/
private[hive] class SparkGetCatalogsOperation(
val sqlContext: SQLContext,
val session: SparkSession,
parentSession: HiveSession)
extends GetCatalogsOperation(parentSession)
with SparkOperation
@@ -44,7 +44,7 @@ private[hive] class SparkGetCatalogsOperation(
logInfo(log"Listing catalogs with ${MDC(STATEMENT_ID, statementId)}")
setState(OperationState.RUNNING)
// Always use the latest class loader provided by executionHive's state.
val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
val executionHiveClassLoader = session.sharedState.jarClassLoader
Thread.currentThread().setContextClassLoader(executionHiveClassLoader)

HiveThriftServer2.eventManager.onStatementStart(
@@ -29,23 +29,23 @@ import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.types._

/**
* Spark's own SparkGetColumnsOperation
*
* @param sqlContext SQLContext to use
* @param session SparkSession to use
* @param parentSession a HiveSession from SessionManager
* @param catalogName catalog name. NULL if not applicable.
* @param schemaName database name, NULL or a concrete database name
* @param tableName table name
* @param columnName column name
*/
private[hive] class SparkGetColumnsOperation(
val sqlContext: SQLContext,
val session: SparkSession,
parentSession: HiveSession,
catalogName: String,
schemaName: String,
@@ -55,7 +55,7 @@ private[hive] class SparkGetColumnsOperation(
with SparkOperation
with Logging {

val catalog: SessionCatalog = sqlContext.sessionState.catalog
val catalog: SessionCatalog = session.sessionState.catalog

override def runInternal(): Unit = {
// Do not change cmdStr. It's used for Hive auditing and authorization.
@@ -72,7 +72,7 @@ private[hive] class SparkGetColumnsOperation(

setState(OperationState.RUNNING)
// Always use the latest class loader provided by executionHive's state.
val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
val executionHiveClassLoader = session.sharedState.jarClassLoader
Thread.currentThread().setContextClassLoader(executionHiveClassLoader)

HiveThriftServer2.eventManager.onStatementStart(
@@ -28,19 +28,19 @@ import org.apache.hive.service.cli.operation.MetadataOperation.DEFAULT_HIVE_CATA
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SparkSession

/**
* Spark's own GetFunctionsOperation
*
* @param sqlContext SQLContext to use
* @param session SparkSession to use
* @param parentSession a HiveSession from SessionManager
* @param catalogName catalog name. null if not applicable
* @param schemaName database name, null or a concrete database name
* @param functionName function name pattern
*/
private[hive] class SparkGetFunctionsOperation(
val sqlContext: SQLContext,
val session: SparkSession,
parentSession: HiveSession,
catalogName: String,
schemaName: String,
@@ -58,10 +58,10 @@ private[hive] class SparkGetFunctionsOperation(
logInfo(logMDC + log" with ${MDC(LogKeys.STATEMENT_ID, statementId)}")
setState(OperationState.RUNNING)
// Always use the latest class loader provided by executionHive's state.
val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
val executionHiveClassLoader = session.sharedState.jarClassLoader
Thread.currentThread().setContextClassLoader(executionHiveClassLoader)

val catalog = sqlContext.sessionState.catalog
val catalog = session.sessionState.catalog
// get databases for schema pattern
val schemaPattern = convertSchemaPattern(schemaName)
val matchingDbs = catalog.listDatabases(schemaPattern)
@@ -27,18 +27,18 @@ import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SparkSession

/**
* Spark's own GetSchemasOperation
*
* @param sqlContext SQLContext to use
* @param session SparkSession to use
* @param parentSession a HiveSession from SessionManager
* @param catalogName catalog name. null if not applicable.
* @param schemaName database name, null or a concrete database name
*/
private[hive] class SparkGetSchemasOperation(
val sqlContext: SQLContext,
val session: SparkSession,
parentSession: HiveSession,
catalogName: String,
schemaName: String)
@@ -59,7 +59,7 @@ private[hive] class SparkGetSchemasOperation(

setState(OperationState.RUNNING)
// Always use the latest class loader provided by executionHive's state.
val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
val executionHiveClassLoader = session.sharedState.jarClassLoader
Thread.currentThread().setContextClassLoader(executionHiveClassLoader)

if (isAuthV2Enabled) {
@@ -75,11 +75,11 @@ private[hive] class SparkGetSchemasOperation(

try {
val schemaPattern = convertSchemaPattern(schemaName)
sqlContext.sessionState.catalog.listDatabases(schemaPattern).foreach { dbName =>
session.sessionState.catalog.listDatabases(schemaPattern).foreach { dbName =>
rowSet.addRow(Array[AnyRef](dbName, DEFAULT_HIVE_CATALOG))
}

val globalTempViewDb = sqlContext.sessionState.catalog.globalTempDatabase
val globalTempViewDb = session.sessionState.catalog.globalTempDatabase
val databasePattern = Pattern.compile(CLIServiceUtils.patternToRegex(schemaName))
if (schemaName == null || schemaName.isEmpty ||
databasePattern.matcher(globalTempViewDb).matches()) {
@@ -26,17 +26,17 @@ import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.CatalogTableType

/**
* Spark's own GetTableTypesOperation
*
* @param sqlContext SQLContext to use
* @param session SparkSession to use
* @param parentSession a HiveSession from SessionManager
*/
private[hive] class SparkGetTableTypesOperation(
val sqlContext: SQLContext,
val session: SparkSession,
parentSession: HiveSession)
extends GetTableTypesOperation(parentSession)
with SparkOperation
@@ -48,7 +48,7 @@ private[hive] class SparkGetTableTypesOperation(
logInfo(log"Listing table types with ${MDC(STATEMENT_ID, statementId)}")
setState(OperationState.RUNNING)
// Always use the latest class loader provided by executionHive's state.
val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader
val executionHiveClassLoader = session.sharedState.jarClassLoader
Thread.currentThread().setContextClassLoader(executionHiveClassLoader)

if (isAuthV2Enabled) {