Merge pull request #1582 from apache/master

GulajavaMinistudio authored Nov 8, 2023
2 parents f2ff420 + d582b74 commit 05be60b
Showing 55 changed files with 2,129 additions and 828 deletions.
17 changes: 4 additions & 13 deletions common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -59,13 +59,6 @@ public final class Platform {
     // reflection to invoke it, which is not necessarily possible by default in Java 9+.
     // Code below can test for null to see whether to use it.
 
-    // The implementation of Cleaner changed from JDK 8 to 9
-    String cleanerClassName;
-    if (majorVersion < 9) {
-      cleanerClassName = "sun.misc.Cleaner";
-    } else {
-      cleanerClassName = "jdk.internal.ref.Cleaner";
-    }
     try {
       Class<?> cls = Class.forName("java.nio.DirectByteBuffer");
       Constructor<?> constructor = (majorVersion < 21) ?
@@ -84,7 +77,7 @@ public final class Platform {
 
     // no point continuing if the above failed:
     if (DBB_CONSTRUCTOR != null && DBB_CLEANER_FIELD != null) {
-      Class<?> cleanerClass = Class.forName(cleanerClassName);
+      Class<?> cleanerClass = Class.forName("jdk.internal.ref.Cleaner");
       Method createMethod = cleanerClass.getMethod("create", Object.class, Runnable.class);
       // Accessing jdk.internal.ref.Cleaner should actually fail by default in JDK 9+,
       // unfortunately, unless the user has allowed access with something like
@@ -314,7 +307,7 @@ public static void throwException(Throwable t) {
     }
   }
 
-  // This requires `majorVersion` and `_UNSAFE`.
+  // This requires `_UNSAFE`.
  static {
    boolean _unaligned;
    String arch = System.getProperty("os.arch", "");
@@ -326,10 +319,8 @@ public static void throwException(Throwable t) {
    try {
      Class<?> bitsClass =
        Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader());
-      if (_UNSAFE != null && majorVersion >= 9) {
-        // Java 9/10 and 11/12 have different field names.
-        Field unalignedField =
-          bitsClass.getDeclaredField(majorVersion >= 11 ? "UNALIGNED" : "unaligned");
+      if (_UNSAFE != null) {
+        Field unalignedField = bitsClass.getDeclaredField("UNALIGNED");
        _unaligned = _UNSAFE.getBoolean(
          _UNSAFE.staticFieldBase(unalignedField), _UNSAFE.staticFieldOffset(unalignedField));
      } else {

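With the JDK 8 branch gone, the Cleaner class name and the unaligned-access probe no longer depend on majorVersion. The following standalone sketch is illustrative only (it is not part of this commit and the object name is invented); it shows the same always-"UNALIGNED" lookup pattern the simplified static initializer uses, assuming sun.misc.Unsafe can still be obtained reflectively:

import java.lang.reflect.Field
import sun.misc.Unsafe

object UnalignedProbe {
  def main(args: Array[String]): Unit = {
    // Unsafe.getUnsafe() is caller-restricted, so fetch the singleton field reflectively.
    val theUnsafe: Field = classOf[Unsafe].getDeclaredField("theUnsafe")
    theUnsafe.setAccessible(true)
    val unsafe = theUnsafe.get(null).asInstanceOf[Unsafe]

    // Only the JDK 11+ field name "UNALIGNED" is needed now that older JDKs are unsupported.
    val bitsClass = Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader())
    val unalignedField = bitsClass.getDeclaredField("UNALIGNED")
    val unaligned = unsafe.getBoolean(
      unsafe.staticFieldBase(unalignedField), unsafe.staticFieldOffset(unalignedField))
    println(s"unaligned access supported: $unaligned")
  }
}
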
@@ -136,8 +136,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
       assert(
         ex.getStackTrace
           .find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis"))
-          .isDefined
-          == isServerStackTraceEnabled)
+          .isDefined)
     }
   }
 }

@@ -31,7 +31,6 @@ import org.apache.spark.{SparkException, SparkThrowable}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.test.ConnectFunSuite
 
@@ -210,19 +209,15 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
   }
 
   for ((name, constructor) <- GrpcExceptionConverter.errorFactory) {
-    test(s"error framework parameters - ${name}") {
+    test(s"error framework parameters - $name") {
       val testParams = GrpcExceptionConverter.ErrorParams(
-        message = "test message",
+        message = "Found duplicate keys `abc`",
         cause = None,
-        errorClass = Some("test error class"),
-        messageParameters = Map("key" -> "value"),
+        errorClass = Some("DUPLICATE_KEY"),
+        messageParameters = Map("keyColumn" -> "`abc`"),
         queryContext = Array.empty)
       val error = constructor(testParams)
-      if (!error.isInstanceOf[ParseException]) {
-        assert(error.getMessage == testParams.message)
-      } else {
-        assert(error.getMessage == s"\n${testParams.message}")
-      }
+      assert(error.getMessage.contains(testParams.message))
       assert(error.getCause == null)
       error match {
         case sparkThrowable: SparkThrowable =>

@@ -191,10 +191,9 @@ private[client] object GrpcExceptionConverter {
     errorConstructor(params =>
       new ParseException(
         None,
-        params.message,
         Origin(),
         Origin(),
-        errorClass = params.errorClass,
+        errorClass = params.errorClass.orNull,
         messageParameters = params.messageParameters,
         queryContext = params.queryContext)),
     errorConstructor(params =>

@@ -1,2 +1,3 @@
-Aggregate [count(if (((a#0 > 0) = false)) null else (a#0 > 0)) AS count_if((a > 0))#0L]
-+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
+Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L]
++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
+   +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

@@ -1,2 +1,3 @@
-Project [if ((regexp_extract(g#0, \d{2}(a|b|m), 0) = )) null else regexp_extract(g#0, \d{2}(a|b|m), 0) AS regexp_substr(g, \d{2}(a|b|m))#0]
-+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
+Project [if ((_common_expr_0#0 = )) null else _common_expr_0#0 AS regexp_substr(g, \d{2}(a|b|m))#0]
++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, \d{2}(a|b|m), 0) AS _common_expr_0#0]
+   +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

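Both golden plans above show the effect of the RewriteWithExpression rule: a sub-expression the old plan evaluated twice (the (a > 0) predicate inside count_if, the regexp_extract call inside regexp_substr) is now computed once as a _common_expr_0 column in a child Project and referenced from the parent. As a rough illustration only (not code from this PR; the session setup and object name are invented), a query of the second kind can be produced like this, and its analyzed plan should carry the extra Project:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object CommonExprDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("common-expr-demo").getOrCreate()
    import spark.implicits._

    val df = Seq((1L, "12a"), (2L, "zz")).toDF("id", "g")
    // regexp_substr expands to an if/else that mentions the same regexp_extract twice;
    // after the rewrite the plan carries a single _common_expr_0 alias instead.
    df.select(regexp_substr($"g", lit("""\d{2}(a|b|m)"""))).explain(true)

    spark.stop()
  }
}
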
@@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.FetchErrorDetailsResponse
-import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler
@@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler(
 
         ErrorUtils.throwableToFetchErrorDetailsResponse(
           st = error,
-          serverStackTraceEnabled = sessionHolder.session.conf.get(
-            Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get(
-            SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
+          serverStackTraceEnabled = true)
       }
       .getOrElse(FetchErrorDetailsResponse.newBuilder().build())
 

@@ -139,7 +139,13 @@ class SparkConnectSessionManager extends Logging {
   }
 
   private def newIsolatedSession(): SparkSession = {
-    SparkSession.active.newSession()
+    val active = SparkSession.active
+    if (active.sparkContext.isStopped) {
+      assert(SparkSession.getDefaultSession.nonEmpty)
+      SparkSession.getDefaultSession.get.newSession()
+    } else {
+      active.newSession()
+    }
   }
 
   private def validateSessionCreate(key: SessionKey): Unit = {

@@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging {
         "classes",
         JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))
 
+    // Add the SQL State and Error Class to the response metadata of the ErrorInfoObject.
+    st match {
+      case e: SparkThrowable =>
+        val state = e.getSqlState
+        if (state != null && state.nonEmpty) {
+          errorInfo.putMetadata("sqlState", state)
+        }
+        val errorClass = e.getErrorClass
+        if (errorClass != null && errorClass.nonEmpty) {
+          errorInfo.putMetadata("errorClass", errorClass)
+        }
+      case _ =>
+    }
+
     if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) {
       // Generate a new unique key for this exception.
       val errorId = UUID.randomUUID().toString

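The new sqlState and errorClass entries travel to the client inside the google.rpc.ErrorInfo detail of the gRPC status. Below is a hedged sketch of how a caller could read them back; it is not the client code path added in this PR, and the helper name is invented:

import scala.jdk.CollectionConverters._

import com.google.rpc.ErrorInfo
import io.grpc.StatusRuntimeException
import io.grpc.protobuf.StatusProto

object ErrorInfoMetadata {
  // Returns the (errorClass, sqlState) pair carried by the first ErrorInfo detail, if any.
  def errorClassAndSqlState(ex: StatusRuntimeException): Option[(String, String)] = {
    Option(StatusProto.fromThrowable(ex)).flatMap { status =>
      status.getDetailsList.asScala.collectFirst {
        case any if any.is(classOf[ErrorInfo]) =>
          val info = any.unpack(classOf[ErrorInfo])
          (info.getMetadataOrDefault("errorClass", ""), info.getMetadataOrDefault("sqlState", ""))
      }
    }
  }
}
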
@@ -29,7 +29,9 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
-import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
+import org.apache.spark.sql.catalyst.optimizer.{ReplaceExpressions, RewriteWithExpression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -181,8 +183,15 @@ class ProtoToParsedPlanTestSuite
     val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark))
     val catalystPlan =
       analyzer.executeAndCheck(planner.transformRelation(relation), new QueryPlanningTracker)
-    val actual =
-      removeMemoryAddress(normalizeExprIds(ReplaceExpressions(catalystPlan)).treeString)
+    val finalAnalyzedPlan = {
+      object Helper extends RuleExecutor[LogicalPlan] {
+        val batches =
+          Batch("Finish Analysis", Once, ReplaceExpressions) ::
+            Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
+      }
+      Helper.execute(catalystPlan)
+    }
+    val actual = removeMemoryAddress(normalizeExprIds(finalAnalyzedPlan).treeString)
     val goldenFile = goldenFilePath.resolve(relativePath).getParent.resolve(name + ".explain")
     Try(readGoldenFile(goldenFile)) match {
       case Success(expected) if expected == actual => // Test passes.

@@ -103,15 +103,11 @@ class FetchErrorDetailsHandlerSuite extends SharedSparkSession with ResourceHelp
       assert(response.getErrors(1).getErrorTypeHierarchy(1) == classOf[Throwable].getName)
       assert(response.getErrors(1).getErrorTypeHierarchy(2) == classOf[Object].getName)
       assert(!response.getErrors(1).hasCauseIdx)
-      if (serverStacktraceEnabled) {
-        assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length)
-        assert(
-          response.getErrors(1).getStackTraceCount ==
-            testError.getCause.getStackTrace.length)
-      } else {
-        assert(response.getErrors(0).getStackTraceCount == 0)
-        assert(response.getErrors(1).getStackTraceCount == 0)
-      }
+      assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length)
+      assert(
+        response.getErrors(1).getStackTraceCount ==
+          testError.getCause.getStackTrace.length)
+
     } finally {
       sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key)
       sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key)

@@ -169,6 +169,56 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       accumulator = null)
   }
 
+  test("python listener process: process terminates after listener is removed") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
+
+      val id1 = "listener_removeListener_test_1"
+      val id2 = "listener_removeListener_test_2"
+      val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+      val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+
+      sessionHolder.cacheListenerById(id1, listener1)
+      spark.streams.addListener(listener1)
+      sessionHolder.cacheListenerById(id2, listener2)
+      spark.streams.addListener(listener2)
+
+      val (runner1, runner2) = (listener1.runner, listener2.runner)
+
+      // assert both python processes are running
+      assert(!runner1.isWorkerStopped().get)
+      assert(!runner2.isWorkerStopped().get)
+
+      // remove listener1
+      spark.streams.removeListener(listener1)
+      sessionHolder.removeCachedListener(id1)
+      // assert listener1's python process is not running
+      eventually(timeout(30.seconds)) {
+        assert(runner1.isWorkerStopped().get)
+        assert(!runner2.isWorkerStopped().get)
+      }
+
+      // remove listener2
+      spark.streams.removeListener(listener2)
+      sessionHolder.removeCachedListener(id2)
+      eventually(timeout(30.seconds)) {
+        // assert listener2's python process is not running
+        assert(runner2.isWorkerStopped().get)
+        // all listeners are removed
+        assert(spark.streams.listListeners().isEmpty)
+      }
+    } finally {
+      SparkConnectService.stop()
+    }
+  }
+
   test("python foreachBatch process: process terminates after query is stopped") {
     // scalastyle:off assume
     assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
@@ -232,58 +282,10 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       assert(spark.streams.listListeners().length == 1) // only process termination listener
     } finally {
       SparkConnectService.stop()
+      // Wait for things to calm down.
+      Thread.sleep(4.seconds.toMillis)
      // remove process termination listener
      spark.streams.listListeners().foreach(spark.streams.removeListener)
     }
   }
-
-  test("python listener process: process terminates after listener is removed") {
-    // scalastyle:off assume
-    assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
-    // scalastyle:on assume
-
-    val sessionHolder = SessionHolder.forTesting(spark)
-    try {
-      SparkConnectService.start(spark.sparkContext)
-
-      val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
-
-      val id1 = "listener_removeListener_test_1"
-      val id2 = "listener_removeListener_test_2"
-      val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
-      val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
-
-      sessionHolder.cacheListenerById(id1, listener1)
-      spark.streams.addListener(listener1)
-      sessionHolder.cacheListenerById(id2, listener2)
-      spark.streams.addListener(listener2)
-
-      val (runner1, runner2) = (listener1.runner, listener2.runner)
-
-      // assert both python processes are running
-      assert(!runner1.isWorkerStopped().get)
-      assert(!runner2.isWorkerStopped().get)
-
-      // remove listener1
-      spark.streams.removeListener(listener1)
-      sessionHolder.removeCachedListener(id1)
-      // assert listener1's python process is not running
-      eventually(timeout(30.seconds)) {
-        assert(runner1.isWorkerStopped().get)
-        assert(!runner2.isWorkerStopped().get)
-      }
-
-      // remove listener2
-      spark.streams.removeListener(listener2)
-      sessionHolder.removeCachedListener(id2)
-      eventually(timeout(30.seconds)) {
-        // assert listener2's python process is not running
-        assert(runner2.isWorkerStopped().get)
-        // all listeners are removed
-        assert(spark.streams.listListeners().isEmpty)
-      }
-    } finally {
-      SparkConnectService.stop()
-    }
-  }
 }

@@ -238,6 +238,8 @@ private[deploy] object DeployMessages {
   case class DriverStatusResponse(found: Boolean, state: Option[DriverState],
     workerId: Option[String], workerHostPort: Option[String], exception: Option[Exception])
 
+  case object RequestClearCompletedDriversAndApps extends DeployMessage
+
   // Internal message in AppClient
 
   case object StopAppClient

@@ -460,6 +460,14 @@ private[deploy] class Master(
         }
       }
 
+    case RequestClearCompletedDriversAndApps =>
+      val numDrivers = completedDrivers.length
+      val numApps = completedApps.length
+      logInfo(s"Asked to clear $numDrivers completed drivers and $numApps completed apps.")
+      completedDrivers.clear()
+      completedApps.clear()
+      context.reply(true)
+
     case RequestDriverStatus(driverId) =>
       if (state != RecoveryState.ALIVE) {
        val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " +

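The reply to the new message is the Boolean passed to context.reply above. A minimal sketch of a caller follows, assuming code compiled inside Spark itself (RpcEndpointRef and askSync are private[spark] APIs) and an already resolved reference to the Master endpoint; the object name is invented:

package org.apache.spark.deploy

import org.apache.spark.deploy.DeployMessages.RequestClearCompletedDriversAndApps
import org.apache.spark.rpc.RpcEndpointRef

object MasterAdmin {
  // Asks the Master to drop its completed driver/application history; blocks until it replies.
  def clearCompletedDriversAndApps(master: RpcEndpointRef): Boolean =
    master.askSync[Boolean](RequestClearCompletedDriversAndApps)
}
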