diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 5825185f34450..eded5da5c1ddd 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -1063,7 +1063,7 @@ jobs: export PVC_TESTS_VM_PATH=$PVC_TMP_DIR minikube mount ${PVC_TESTS_HOST_PATH}:${PVC_TESTS_VM_PATH} --gid=0 --uid=185 & kubectl create clusterrolebinding serviceaccounts-cluster-admin --clusterrole=cluster-admin --group=system:serviceaccounts || true - kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml || true + kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml || true eval $(minikube docker-env) build/sbt -Psparkr -Pkubernetes -Pvolcano -Pkubernetes-integration-tests -Dspark.kubernetes.test.driverRequestCores=0.5 -Dspark.kubernetes.test.executorRequestCores=0.2 -Dspark.kubernetes.test.volcanoMaxConcurrencyJobNum=1 -Dtest.exclude.tags=local "kubernetes-integration-tests/test" - name: Upload Spark on K8S integration tests log files diff --git a/common/utils/src/main/java/org/apache/spark/QueryContext.java b/common/utils/src/main/java/org/apache/spark/QueryContext.java index de5b29d02951d..45e38c8cfe0f3 100644 --- a/common/utils/src/main/java/org/apache/spark/QueryContext.java +++ b/common/utils/src/main/java/org/apache/spark/QueryContext.java @@ -27,6 +27,9 @@ */ @Evolving public interface QueryContext { + // The type of this query context. + QueryContextType contextType(); + // The object type of the query which throws the exception. // If the exception is directly from the main query, it should be an empty string. // Otherwise, it should be the exact object type in upper case. For example, a "VIEW". @@ -45,4 +48,10 @@ public interface QueryContext { // The corresponding fragment of the query which throws the exception. String fragment(); + + // The user code (call site of the API) that caused throwing the exception. + String callSite(); + + // Summary of the exception cause. + String summary(); } diff --git a/common/utils/src/main/java/org/apache/spark/QueryContextType.java b/common/utils/src/main/java/org/apache/spark/QueryContextType.java new file mode 100644 index 0000000000000..171833162bafa --- /dev/null +++ b/common/utils/src/main/java/org/apache/spark/QueryContextType.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.apache.spark.annotation.Evolving; + +/** + * The type of {@link QueryContext}. 
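For context (not part of the patch), a minimal sketch of how a consumer might branch on the new `contextType()`: SQL contexts carry positional information about the offending fragment, while DataFrame contexts carry the call site of the user code instead, mirroring the `SparkThrowableHelper` change further down. The `describe` helper is hypothetical.

```scala
import org.apache.spark.{QueryContext, QueryContextType}

// Hypothetical helper: pick the fields that are meaningful for each context type.
def describe(ctx: QueryContext): String = ctx.contextType() match {
  case QueryContextType.SQL =>
    s"SQL fragment '${ctx.fragment()}' in ${ctx.objectType()} '${ctx.objectName()}' " +
      s"at [${ctx.startIndex()}, ${ctx.stopIndex()}]"
  case QueryContextType.DataFrame =>
    s"DataFrame API '${ctx.fragment()}' called from ${ctx.callSite()}"
}
```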
+ * + * @since 4.0.0 + */ +@Evolving +public enum QueryContextType { + SQL, + DataFrame +} diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 278011b8cc8f4..af32bcf129c08 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -1737,6 +1737,11 @@ "Session already exists." ] }, + "SESSION_CLOSED" : { + "message" : [ + "Session was closed." + ] + }, "SESSION_NOT_FOUND" : { "message" : [ "Session not found." diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index b312a1a7e2278..a44d36ff85b55 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -114,13 +114,19 @@ private[spark] object SparkThrowableHelper { g.writeArrayFieldStart("queryContext") e.getQueryContext.foreach { c => g.writeStartObject() - g.writeStringField("objectType", c.objectType()) - g.writeStringField("objectName", c.objectName()) - val startIndex = c.startIndex() + 1 - if (startIndex > 0) g.writeNumberField("startIndex", startIndex) - val stopIndex = c.stopIndex() + 1 - if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex) - g.writeStringField("fragment", c.fragment()) + c.contextType() match { + case QueryContextType.SQL => + g.writeStringField("objectType", c.objectType()) + g.writeStringField("objectName", c.objectName()) + val startIndex = c.startIndex() + 1 + if (startIndex > 0) g.writeNumberField("startIndex", startIndex) + val stopIndex = c.stopIndex() + 1 + if (stopIndex > 0) g.writeNumberField("stopIndex", stopIndex) + g.writeStringField("fragment", c.fragment()) + case QueryContextType.DataFrame => + g.writeStringField("fragment", c.fragment()) + g.writeStringField("callSite", c.callSite()) + } g.writeEndObject() } g.writeEndArray() diff --git a/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala new file mode 100644 index 0000000000000..08997a800c957 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/ArrayImplicits.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.immutable + +/** + * Implicit methods related to Scala Array. + */ +private[spark] object ArrayImplicits { + + implicit class SparkArrayOps[T](xs: Array[T]) { + + /** + * Wraps an Array[T] as an immutable.ArraySeq[T] without copying. 
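Since the rest of this patch swaps verbose `immutable.ArraySeq.unsafeWrapArray(...)` calls for this helper, a quick usage sketch (illustrative values only):

```scala
import scala.collection.immutable
import org.apache.spark.util.ArrayImplicits._

val xs: Array[Int] = Array(1, 2, 3)
val seq: immutable.ArraySeq[Int] = xs.toImmutableArraySeq
// No copy is made: the ArraySeq wraps `xs` directly, so the caller must treat the
// original array as frozen after wrapping. A null array maps to null, not an empty seq.
```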
+ */ + def toImmutableArraySeq: immutable.ArraySeq[T] = + if (xs eq null) null + else immutable.ArraySeq.unsafeWrapArray(xs) + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 969ac017ecb1d..34756f9a440bb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,6 @@ import java.net.URI import java.util.concurrent.TimeUnit._ import java.util.concurrent.atomic.{AtomicLong, AtomicReference} -import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag @@ -45,6 +44,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SqlApiConf} import org.apache.spark.sql.streaming.DataStreamReader import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ArrayImplicits._ /** * The entry point to programming Spark with the Dataset and DataFrame API. @@ -248,7 +248,7 @@ class SparkSession private[sql] ( proto.SqlCommand .newBuilder() .setSql(sqlText) - .addAllPosArguments(immutable.ArraySeq.unsafeWrapArray(args.map(lit(_).expr)).asJava))) + .addAllPosArguments(args.map(lit(_).expr).toImmutableArraySeq.asJava))) val plan = proto.Plan.newBuilder().setCommand(cmd) // .toBuffer forces that the iterator is consumed and closed val responseSeq = client.execute(plan.build()).toBuffer.toSeq @@ -665,6 +665,9 @@ class SparkSession private[sql] ( * @since 3.4.0 */ override def close(): Unit = { + if (releaseSessionOnClose) { + client.releaseSession() + } client.shutdown() allocator.close() SparkSession.onSessionClose(this) @@ -735,6 +738,11 @@ class SparkSession private[sql] ( * We null out the instance for now. */ private def writeReplace(): Any = null + + /** + * Set to false to prevent client.releaseSession on close() (testing only) + */ + private[sql] var releaseSessionOnClose = true } // The minimal builder needed to create a spark session. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index cf287088b59fb..5cc63bc45a04a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -120,7 +120,9 @@ class PlanGenerationTestSuite } override protected def afterAll(): Unit = { - session.close() + // Don't call client.releaseSession on close(), because the connection details are dummy. 
+ session.releaseSessionOnClose = false + session.stop() if (cleanOrphanedGoldenFiles) { cleanOrphanedGoldenFile() } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 4c858262c6ef5..8abc41639fdd2 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -33,18 +33,24 @@ class SparkSessionSuite extends ConnectFunSuite { private val connectionString2: String = "sc://test.me:14099" private val connectionString3: String = "sc://doit:16845" + private def closeSession(session: SparkSession): Unit = { + // Don't call client.releaseSession on close(), because the connection details are dummy. + session.releaseSessionOnClose = false + session.close() + } + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") assert(session.client.configuration.port == 15002) - session.close() + closeSession(session) } test("remote") { val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) - session.close() + closeSession(session) } test("getOrCreate") { @@ -53,8 +59,8 @@ class SparkSessionSuite extends ConnectFunSuite { try { assert(session1 eq session2) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -65,8 +71,8 @@ class SparkSessionSuite extends ConnectFunSuite { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -77,8 +83,8 @@ class SparkSessionSuite extends ConnectFunSuite { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) } finally { - session1.close() - session2.close() + closeSession(session1) + closeSession(session2) } } @@ -98,7 +104,7 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } - session.close() + closeSession(session) } test("Default/Active session") { @@ -136,12 +142,12 @@ class SparkSessionSuite extends ConnectFunSuite { assert(SparkSession.getActiveSession.contains(session1)) // Close session1 - session1.close() + closeSession(session1) assert(SparkSession.getDefaultSession.contains(session2)) assert(SparkSession.getActiveSession.isEmpty) // Close session2 - session2.close() + closeSession(session2) assert(SparkSession.getDefaultSession.isEmpty) assert(SparkSession.getActiveSession.isEmpty) } @@ -187,7 +193,7 @@ class SparkSessionSuite extends ConnectFunSuite { // Step 3 - close session 1, no more default session in both scripts phaser.arriveAndAwaitAdvance() - session1.close() + closeSession(session1) // Step 4 - no default session, same active session. 
phaser.arriveAndAwaitAdvance() @@ -240,13 +246,13 @@ class SparkSessionSuite extends ConnectFunSuite { // Step 7 - close active session in script2 phaser.arriveAndAwaitAdvance() - internalSession.close() + closeSession(internalSession) assert(SparkSession.getActiveSession.isEmpty) } assert(script1.get()) assert(script2.get()) assert(SparkSession.getActiveSession.contains(session2)) - session2.close() + closeSession(session2) assert(SparkSession.getActiveSession.isEmpty) } finally { executor.shutdown() @@ -254,13 +260,13 @@ class SparkSessionSuite extends ConnectFunSuite { } test("deprecated methods") { - SparkSession + val session = SparkSession .builder() .master("yayay") .appName("bob") .enableHiveSupport() .create() - .close() + closeSession(session) } test("serialize as null") { diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 5b94c6d663cca..19a94a5a429f0 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -784,6 +784,30 @@ message ReleaseExecuteResponse { optional string operation_id = 2; } +message ReleaseSessionRequest { + // (Required) + // + // The session_id of the request to reattach to. + // This must be an id of existing session. + string session_id = 1; + + // (Required) User context + // + // user_context.user_id and session+id both identify a unique remote spark session on the + // server side. + UserContext user_context = 2; + + // Provides optional information about the client sending the request. This field + // can be used for language or version specific information and is only intended for + // logging purposes and will not be interpreted by the server. + optional string client_type = 3; +} + +message ReleaseSessionResponse { + // Session id of the session on which the release executed. + string session_id = 1; +} + message FetchErrorDetailsRequest { // (Required) @@ -823,6 +847,13 @@ message FetchErrorDetailsResponse { // QueryContext defines the schema for the query context of a SparkThrowable. // It helps users understand where the error occurs while executing queries. message QueryContext { + // The type of this query context. + enum ContextType { + SQL = 0; + DATAFRAME = 1; + } + ContextType context_type = 10; + // The object type of the query which throws the exception. // If the exception is directly from the main query, it should be an empty string. // Otherwise, it should be the exact object type in upper case. For example, a "VIEW". @@ -841,6 +872,12 @@ message FetchErrorDetailsResponse { // The corresponding fragment of the query which throws the exception. string fragment = 5; + + // The user code (call site of the API) that caused throwing the exception. + string callSite = 6; + + // Summary of the exception cause. + string summary = 7; } // SparkThrowable defines the schema for SparkThrowable exceptions. @@ -921,6 +958,12 @@ service SparkConnectService { // RPC and ReleaseExecute may not be used. rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {} + // Release a session. + // All the executions in the session will be released. Any further requests for the session with + // that session_id for the given user_id will fail. If the session didn't exist or was already + // released, this is a noop. 
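To make the new RPC concrete, here is a sketch of how a JVM client could build and send a `ReleaseSessionRequest`. It mirrors the `SparkConnectClient.releaseSession()` helper added later in this patch and assumes the generated proto and blocking-stub classes; the helper name and arguments are illustrative.

```scala
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub

// Illustrative wrapper around the generated blocking stub.
def releaseSession(
    stub: SparkConnectServiceBlockingStub,
    userId: String,
    sessionId: String): proto.ReleaseSessionResponse = {
  val request = proto.ReleaseSessionRequest
    .newBuilder()
    .setUserContext(proto.UserContext.newBuilder().setUserId(userId).build())
    .setSessionId(sessionId)
    .setClientType("example-client") // optional; used for logging only
    .build()
  // Releasing an unknown or already-released session is a no-op on the server.
  stub.releaseSession(request)
}
```

After a successful release, further RPCs carrying the same session_id fail with INVALID_HANDLE.SESSION_CLOSED, which is exactly what the E2E tests at the end of this diff assert.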
+ rpc ReleaseSession(ReleaseSessionRequest) returns (ReleaseSessionResponse) {} + // FetchErrorDetails retrieves the matched exception with details based on a provided error id. rpc FetchErrorDetails(FetchErrorDetailsRequest) returns (FetchErrorDetailsResponse) {} } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index f2efa26f6b609..e963b4136160f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -96,6 +96,17 @@ private[connect] class CustomSparkConnectBlockingStub( } } + def releaseSession(request: ReleaseSessionRequest): ReleaseSessionResponse = { + grpcExceptionConverter.convert( + request.getSessionId, + request.getUserContext, + request.getClientType) { + retryHandler.retry { + stub.releaseSession(request) + } + } + } + def artifactStatus(request: ArtifactStatusesRequest): ArtifactStatusesResponse = { grpcExceptionConverter.convert( request.getSessionId, diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index b2782442f4a53..652797bc2e40f 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.client import java.time.DateTimeException -import scala.collection.immutable import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -28,7 +27,7 @@ import io.grpc.protobuf.StatusProto import org.json4s.DefaultFormats import org.json4s.jackson.JsonMethods -import org.apache.spark.{QueryContext, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{QueryContext, QueryContextType, SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext} import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub import org.apache.spark.internal.Logging @@ -37,6 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.streaming.StreamingQueryException +import org.apache.spark.util.ArrayImplicits._ /** * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions. 
@@ -324,15 +324,18 @@ private[client] object GrpcExceptionConverter { val queryContext = error.getSparkThrowable.getQueryContextsList.asScala.map { queryCtx => new QueryContext { + override def contextType(): QueryContextType = queryCtx.getContextType match { + case FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME => + QueryContextType.DataFrame + case _ => QueryContextType.SQL + } override def objectType(): String = queryCtx.getObjectType - override def objectName(): String = queryCtx.getObjectName - override def startIndex(): Int = queryCtx.getStartIndex - override def stopIndex(): Int = queryCtx.getStopIndex - override def fragment(): String = queryCtx.getFragment + override def callSite(): String = queryCtx.getCallSite + override def summary(): String = queryCtx.getSummary } }.toArray @@ -372,7 +375,7 @@ private[client] object GrpcExceptionConverter { FetchErrorDetailsResponse.Error .newBuilder() .setMessage(message) - .addAllErrorTypeHierarchy(immutable.ArraySeq.unsafeWrapArray(classes).asJava) + .addAllErrorTypeHierarchy(classes.toImmutableArraySeq.asJava) .build())) } } diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 42ace003da89f..6d3d9420e2263 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -243,6 +243,16 @@ private[sql] class SparkConnectClient( bstub.interrupt(request) } + private[sql] def releaseSession(): proto.ReleaseSessionResponse = { + val builder = proto.ReleaseSessionRequest.newBuilder() + val request = builder + .setUserContext(userContext) + .setSessionId(sessionId) + .setClientType(userAgent) + .build() + bstub.releaseSession(request) + } + private[this] val tags = new InheritableThreadLocal[mutable.Set[String]] { override def childValue(parent: mutable.Set[String]): mutable.Set[String] = { // Note: make a clone such that changes in the parent tags aren't reflected in diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 2b3f218362cd3..1a5944676f5fb 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -74,6 +74,24 @@ object Connect { .intConf .createWithDefault(1024) + val CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT = + buildStaticConf("spark.connect.session.manager.defaultSessionTimeout") + .internal() + .doc("Timeout after which sessions without any new incoming RPC will be removed.") + .version("4.0.0") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("60m") + + val CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE = + buildStaticConf("spark.connect.session.manager.closedSessionsTombstonesSize") + .internal() + .doc( + "Maximum size of the cache of sessions after which sessions that did not receive any " + + "requests will be removed.") + .version("4.0.0") + .intConf + .createWithDefaultString("1000") + val CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT = buildStaticConf("spark.connect.execute.manager.detachedTimeout") .internal() diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index ec57909ad144e..018e293795e9d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connect.planner -import scala.collection.immutable import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.util.Try @@ -80,6 +79,7 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.CacheId +import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils final case class InvalidCommandInput( @@ -3184,9 +3184,9 @@ class SparkConnectPlanner( case StreamingQueryManagerCommand.CommandCase.ACTIVE => val active_queries = session.streams.active respBuilder.getActiveBuilder.addAllActiveQueries( - immutable.ArraySeq - .unsafeWrapArray(active_queries - .map(query => buildStreamingQueryInstance(query))) + active_queries + .map(query => buildStreamingQueryInstance(query)) + .toImmutableArraySeq .asJava) case StreamingQueryManagerCommand.CommandCase.GET_QUERY => @@ -3265,15 +3265,16 @@ class SparkConnectPlanner( .setGetResourcesCommandResult( proto.GetResourcesCommandResult .newBuilder() - .putAllResources(session.sparkContext.resources.view - .mapValues(resource => - proto.ResourceInformation - .newBuilder() - .setName(resource.name) - .addAllAddresses(immutable.ArraySeq.unsafeWrapArray(resource.addresses).asJava) - .build()) - .toMap - .asJava) + .putAllResources( + session.sparkContext.resources.view + .mapValues(resource => + proto.ResourceInformation + .newBuilder() + .setName(resource.name) + .addAllAddresses(resource.addresses.toImmutableArraySeq.asJava) + .build()) + .toMap + .asJava) .build()) .build()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index dcced21f37148..792012a682b28 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder -import org.apache.spark.{JobArtifactSet, SparkException} +import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException} import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession @@ -40,12 +40,19 @@ import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils +// Unique key identifying session by combination of user, and session id +case class SessionKey(userId: String, sessionId: String) + /** * Object used to hold the Spark Connect session state. 
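A short note on the new `SessionKey` above: unlike the previous `(String, String)` cache-key alias, it names both components, which keeps signatures and log lines readable while still working as a plain value key in maps and caches. A tiny sketch (values made up):

```scala
import java.util.concurrent.ConcurrentHashMap

final case class SessionKey(userId: String, sessionId: String)

val sessions = new ConcurrentHashMap[SessionKey, String]()
sessions.put(SessionKey("userA", "session-1"), "holder for userA")

// Case classes give structural equals/hashCode, so a freshly constructed key
// with the same fields finds the existing entry.
assert(sessions.containsKey(SessionKey("userA", "session-1")))
```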
*/ case class SessionHolder(userId: String, sessionId: String, session: SparkSession) extends Logging { + @volatile private var lastRpcAccessTime: Option[Long] = None + + @volatile private var isClosing: Boolean = false + private val executions: ConcurrentMap[String, ExecuteHolder] = new ConcurrentHashMap[String, ExecuteHolder]() @@ -73,8 +80,21 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private[connect] lazy val streamingForeachBatchRunnerCleanerCache = new StreamingForeachBatchHelper.CleanerCache(this) - /** Add ExecuteHolder to this session. Called only by SparkConnectExecutionManager. */ + def key: SessionKey = SessionKey(userId, sessionId) + + /** + * Add ExecuteHolder to this session. + * + * Called only by SparkConnectExecutionManager under executionsLock. + */ private[service] def addExecuteHolder(executeHolder: ExecuteHolder): Unit = { + if (isClosing) { + // Do not accept new executions if the session is closing. + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> sessionId)) + } + val oldExecute = executions.putIfAbsent(executeHolder.operationId, executeHolder) if (oldExecute != null) { // the existence of this should alrady be checked by SparkConnectExecutionManager @@ -160,21 +180,55 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ def classloader: ClassLoader = artifactManager.classloader + private[connect] def updateAccessTime(): Unit = { + lastRpcAccessTime = Some(System.currentTimeMillis()) + } + + /** + * Initialize the session. + * + * Called only by SparkConnectSessionManager. + */ private[connect] def initializeSession(): Unit = { + updateAccessTime() eventManager.postStarted() } /** * Expire this session and trigger state cleanup mechanisms. + * + * Called only by SparkConnectSessionManager. */ - private[connect] def expireSession(): Unit = { - logDebug(s"Expiring session with userId: $userId and sessionId: $sessionId") + private[connect] def close(): Unit = { + logInfo(s"Closing session with userId: $userId and sessionId: $sessionId") + + // After isClosing=true, SessionHolder.addExecuteHolder() will not allow new executions for + // this session. Because both SessionHolder.addExecuteHolder() and + // SparkConnectExecutionManager.removeAllExecutionsForSession() are executed under + // executionsLock, this guarantees that removeAllExecutionsForSession triggered below will + // remove all executions and no new executions will be added in the meanwhile. + isClosing = true + + // Note on the below notes about concurrency: + // While closing the session can potentially race with operations started on the session, the + // intended use is that the client session will get closed when it's really not used anymore, + // or that it expires due to inactivity, in which case there should be no races. + + // Clean up all artifacts. + // Note: there can be concurrent AddArtifact calls still adding something. artifactManager.cleanUpResources() - eventManager.postClosed() - // Clean up running queries + + // Clean up running streaming queries. + // Note: there can be concurrent streaming queries being started. SparkConnectService.streamingSessionManager.cleanupRunningQueries(this) streamingForeachBatchRunnerCleanerCache.cleanUpAll() // Clean up any streaming workers. removeAllListeners() // removes all listener and stop python listener processes if necessary. 
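The comments above spell out the guarantee this close() relies on: rejecting new executions (via isClosing) and draining existing ones both happen under the same executionsLock, so no execution can slip in between the two steps. A stripped-down sketch of that idiom, with illustrative names rather than the real Spark Connect classes:

```scala
import scala.collection.mutable

// Illustrative only: one lock serializes "reject new work" and "drain existing work".
class ClosableRegistry {
  private val lock = new Object
  private var closing = false
  private val running = mutable.Set.empty[String]

  def add(id: String): Unit = lock.synchronized {
    if (closing) throw new IllegalStateException(s"registry closed, rejecting $id")
    running += id
  }

  def closeAndDrain(): Seq[String] = lock.synchronized {
    closing = true              // nothing new can be added past this point
    val drained = running.toSeq
    running.clear()
    drained                     // caller interrupts/cleans these up outside the lock
  }
}
```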
+ + // Clean up all executions + // It is guaranteed at this point that no new addExecuteHolder are getting started. + SparkConnectService.executionManager.removeAllExecutionsForSession(this.key) + + eventManager.postClosed() } /** @@ -204,6 +258,10 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } } + /** Get SessionInfo with information about this SessionHolder. */ + def getSessionHolderInfo: SessionHolderInfo = + SessionHolderInfo(userId, sessionId, eventManager.status, lastRpcAccessTime) + /** * Caches given DataFrame with the ID. The cache does not expire. The entry needs to be * explicitly removed by the owners of the DataFrame once it is not needed. @@ -291,7 +349,14 @@ object SessionHolder { userId = "testUser", sessionId = UUID.randomUUID().toString, session = session) - SparkConnectService.putSessionForTesting(ret) + SparkConnectService.sessionManager.putSessionForTesting(ret) ret } } + +/** Basic information about SessionHolder. */ +case class SessionHolderInfo( + userId: String, + sessionId: String, + status: SessionStatus, + lastRpcAccesTime: Option[Long]) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index 3c72548978222..c004358e1cf18 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -95,11 +95,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * Remove an ExecuteHolder from this global manager and from its session. Interrupt the * execution if still running, free all resources. */ - private[connect] def removeExecuteHolder(key: ExecuteKey): Unit = { + private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean = false): Unit = { var executeHolder: Option[ExecuteHolder] = None executionsLock.synchronized { executeHolder = executions.remove(key) - executeHolder.foreach(e => e.sessionHolder.removeExecuteHolder(e.operationId)) + executeHolder.foreach { e => + if (abandoned) { + abandonedTombstones.put(key, e.getExecuteInfo) + } + e.sessionHolder.removeExecuteHolder(e.operationId) + } if (executions.isEmpty) { lastExecutionTime = Some(System.currentTimeMillis()) } @@ -115,6 +120,17 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } } + private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = { + val sessionExecutionHolders = executionsLock.synchronized { + executions.filter(_._2.sessionHolder.key == key) + } + sessionExecutionHolders.foreach { case (_, executeHolder) => + val info = executeHolder.getExecuteInfo + logInfo(s"Execution $info removed in removeSessionExecutions.") + removeExecuteHolder(executeHolder.key, abandoned = true) + } + } + /** Get info about abandoned execution, if there is one. 
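The `abandoned` flag above feeds a tombstone cache, which is what lets the server later answer "this operation existed but was abandoned" instead of a plain not-found. A small self-contained sketch of the pattern using a bounded Guava cache (key type, size, and messages are illustrative):

```scala
import com.google.common.cache.CacheBuilder

final case class OpKey(sessionId: String, operationId: String)

// Bounded cache of summaries for operations that were removed as abandoned.
val abandonedTombstones = CacheBuilder.newBuilder()
  .maximumSize(1000)
  .build[OpKey, String]()

def explainMissing(key: OpKey): String =
  Option(abandonedTombstones.getIfPresent(key)) match {
    case Some(info) => s"operation was abandoned: $info"
    case None       => "operation not found (never existed, or its tombstone was evicted)"
  }
```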
*/ private[connect] def getAbandonedTombstone(key: ExecuteKey): Option[ExecuteInfo] = { Option(abandonedTombstones.getIfPresent(key)) @@ -204,8 +220,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { toRemove.foreach { executeHolder => val info = executeHolder.getExecuteInfo logInfo(s"Found execution $info that was abandoned and expired and will be removed.") - removeExecuteHolder(executeHolder.key) - abandonedTombstones.put(executeHolder.key, info) + removeExecuteHolder(executeHolder.key, abandoned = true) } } logInfo("Finished periodic run of SparkConnectExecutionManager maintenance.") diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala index a3a7815609e40..1ca886960d536 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala @@ -28,8 +28,8 @@ class SparkConnectReleaseExecuteHandler( extends Logging { def handle(v: proto.ReleaseExecuteRequest): Unit = { - val sessionHolder = SparkConnectService - .getIsolatedSession(v.getUserContext.getUserId, v.getSessionId) + val sessionHolder = SparkConnectService.sessionManager + .getIsolatedSession(SessionKey(v.getUserContext.getUserId, v.getSessionId)) val responseBuilder = proto.ReleaseExecuteResponse .newBuilder() diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala new file mode 100644 index 0000000000000..a32852bac45ea --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseSessionHandler.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto +import org.apache.spark.internal.Logging + +class SparkConnectReleaseSessionHandler( + responseObserver: StreamObserver[proto.ReleaseSessionResponse]) + extends Logging { + + def handle(v: proto.ReleaseSessionRequest): Unit = { + val responseBuilder = proto.ReleaseSessionResponse.newBuilder() + responseBuilder.setSessionId(v.getSessionId) + + // If the session doesn't exist, this will just be a noop. 
+ val key = SessionKey(v.getUserContext.getUserId, v.getSessionId) + SparkConnectService.sessionManager.closeSession(key) + + responseObserver.onNext(responseBuilder.build()) + responseObserver.onCompleted() + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index e82c9cba56264..e4b60eeeff0d6 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -18,13 +18,10 @@ package org.apache.spark.sql.connect.service import java.net.InetSocketAddress -import java.util.UUID -import java.util.concurrent.{Callable, TimeUnit} +import java.util.concurrent.TimeUnit import scala.jdk.CollectionConverters._ -import com.google.common.base.Ticker -import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification} import com.google.protobuf.MessageLite import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition} import io.grpc.MethodDescriptor.PrototypeMarshaller @@ -34,13 +31,12 @@ import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.commons.lang3.StringUtils -import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException} +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc} import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService import org.apache.spark.internal.Logging import org.apache.spark.internal.config.UI.UI_ENABLED -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE} import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab} import org.apache.spark.sql.connect.utils.ErrorUtils @@ -201,6 +197,22 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ sessionId = request.getSessionId) } + /** + * Release session. + */ + override def releaseSession( + request: proto.ReleaseSessionRequest, + responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = { + try { + new SparkConnectReleaseSessionHandler(responseObserver).handle(request) + } catch + ErrorUtils.handleError( + "releaseSession", + observer = responseObserver, + userId = request.getUserContext.getUserId, + sessionId = request.getSessionId) + } + override def fetchErrorDetails( request: proto.FetchErrorDetailsRequest, responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit = { @@ -268,14 +280,6 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ */ object SparkConnectService extends Logging { - private val CACHE_SIZE = 100 - - private val CACHE_TIMEOUT_SECONDS = 3600 - - // Type alias for the SessionCacheKey. Right now this is a String but allows us to switch to a - // different or complex type easily. 
- private type SessionCacheKey = (String, String) - private[connect] var server: Server = _ private[connect] var uiTab: Option[SparkConnectServerTab] = None @@ -289,77 +293,18 @@ object SparkConnectService extends Logging { server.getPort } - private val userSessionMapping = - cacheBuilder(CACHE_SIZE, CACHE_TIMEOUT_SECONDS).build[SessionCacheKey, SessionHolder]() - private[connect] lazy val executionManager = new SparkConnectExecutionManager() + private[connect] lazy val sessionManager = new SparkConnectSessionManager() + private[connect] val streamingSessionManager = new SparkConnectStreamingQueryCache() - private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] { - override def onRemoval( - notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = { - notification.getValue.expireSession() - } - } - - // Simple builder for creating the cache of Sessions. - private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int): CacheBuilder[Object, Object] = { - var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker()) - if (cacheSize >= 0) { - cacheBuilder = cacheBuilder.maximumSize(cacheSize) - } - if (timeoutSeconds >= 0) { - cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS) - } - cacheBuilder.removalListener(new RemoveSessionListener) - cacheBuilder - } - /** * Based on the userId and sessionId, find or create a new SparkSession. */ def getOrCreateIsolatedSession(userId: String, sessionId: String): SessionHolder = { - getSessionOrDefault( - userId, - sessionId, - () => { - val holder = SessionHolder(userId, sessionId, newIsolatedSession()) - holder.initializeSession() - holder - }) - } - - /** - * Based on the userId and sessionId, find an existing SparkSession or throw error. - */ - def getIsolatedSession(userId: String, sessionId: String): SessionHolder = { - getSessionOrDefault( - userId, - sessionId, - () => { - logDebug(s"Session not found: ($userId, $sessionId)") - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND", - messageParameters = Map("handle" -> sessionId)) - }) - } - - private def getSessionOrDefault( - userId: String, - sessionId: String, - default: Callable[SessionHolder]): SessionHolder = { - // Validate that sessionId is formatted like UUID before creating session. - try { - UUID.fromString(sessionId).toString - } catch { - case _: IllegalArgumentException => - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.FORMAT", - messageParameters = Map("handle" -> sessionId)) - } - userSessionMapping.get((userId, sessionId), default) + sessionManager.getOrCreateIsolatedSession(SessionKey(userId, sessionId)) } /** @@ -368,24 +313,6 @@ object SparkConnectService extends Logging { */ def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = executionManager.listActiveExecutions - /** - * Used for testing - */ - private[connect] def invalidateAllSessions(): Unit = { - userSessionMapping.invalidateAll() - } - - /** - * Used for testing. 
- */ - private[connect] def putSessionForTesting(sessionHolder: SessionHolder): Unit = { - userSessionMapping.put((sessionHolder.userId, sessionHolder.sessionId), sessionHolder) - } - - private def newIsolatedSession(): SparkSession = { - SparkSession.active.newSession() - } - private def createListenerAndUI(sc: SparkContext): Unit = { val kvStore = sc.statusStore.store.asInstanceOf[ElementTrackingStore] listener = new SparkConnectServerListener(kvStore, sc.conf) @@ -445,7 +372,7 @@ object SparkConnectService extends Logging { } streamingSessionManager.shutdown() executionManager.shutdown() - userSessionMapping.invalidateAll() + sessionManager.shutdown() uiTab.foreach(_.detach()) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala new file mode 100644 index 0000000000000..5c8e3c611586c --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.service + +import java.util.UUID +import java.util.concurrent.{Callable, TimeUnit} + +import com.google.common.base.Ticker +import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification} + +import org.apache.spark.{SparkEnv, SparkSQLException} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.config.Connect.{CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE, CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT} + +/** + * Global tracker of all SessionHolders holding Spark Connect sessions. + */ +class SparkConnectSessionManager extends Logging { + + private val sessionsLock = new Object + + private val sessionStore = + CacheBuilder + .newBuilder() + .ticker(Ticker.systemTicker()) + .expireAfterAccess( + SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_DEFAULT_SESSION_TIMEOUT), + TimeUnit.MILLISECONDS) + .removalListener(new RemoveSessionListener) + .build[SessionKey, SessionHolder]() + + private val closedSessionsCache = + CacheBuilder + .newBuilder() + .maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE)) + .build[SessionKey, SessionHolderInfo]() + + /** + * Based on the userId and sessionId, find or create a new SparkSession. + */ + private[connect] def getOrCreateIsolatedSession(key: SessionKey): SessionHolder = { + // Lock to guard against concurrent removal and insertion into closedSessionsCache. 
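The session store above combines two Guava cache features: `expireAfterAccess` implements the idle-session timeout, and the removal listener turns every eviction, whether an explicit `invalidate` or an expiry, into a session close. A minimal standalone sketch of that combination (value type and timeout are placeholders):

```scala
import java.util.concurrent.TimeUnit
import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification}

// Placeholder value type standing in for SessionHolder.
final class Holder(val name: String) { def close(): Unit = println(s"closed $name") }

val store = CacheBuilder
  .newBuilder()
  .expireAfterAccess(60, TimeUnit.MINUTES) // idle timeout
  .removalListener(new RemovalListener[String, Holder] {
    override def onRemoval(n: RemovalNotification[String, Holder]): Unit =
      n.getValue.close() // fires for invalidate() and for expiry alike
  })
  .build[String, Holder]()

store.put("session-1", new Holder("session-1"))
store.invalidate("session-1") // triggers the removal listener
```

One Guava caveat worth keeping in mind: expiry is enforced lazily during cache operations, so an entirely idle process will not proactively close timed-out sessions until the cache is touched again.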
+ sessionsLock.synchronized { + getSession( + key, + Some(() => { + validateSessionCreate(key) + val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession()) + holder.initializeSession() + holder + })) + } + } + + /** + * Based on the userId and sessionId, find an existing SparkSession or throw error. + */ + private[connect] def getIsolatedSession(key: SessionKey): SessionHolder = { + getSession( + key, + Some(() => { + logDebug(s"Session not found: $key") + if (closedSessionsCache.getIfPresent(key) != null) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> key.sessionId)) + } else { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_NOT_FOUND", + messageParameters = Map("handle" -> key.sessionId)) + } + })) + } + + /** + * Based on the userId and sessionId, get an existing SparkSession if present. + */ + private[connect] def getIsolatedSessionIfPresent(key: SessionKey): Option[SessionHolder] = { + Option(getSession(key, None)) + } + + private def getSession( + key: SessionKey, + default: Option[Callable[SessionHolder]]): SessionHolder = { + val session = default match { + case Some(callable) => sessionStore.get(key, callable) + case None => sessionStore.getIfPresent(key) + } + // record access time before returning + session match { + case null => + null + case s: SessionHolder => + s.updateAccessTime() + s + } + } + + def closeSession(key: SessionKey): Unit = { + // Invalidate will trigger RemoveSessionListener + sessionStore.invalidate(key) + } + + private class RemoveSessionListener extends RemovalListener[SessionKey, SessionHolder] { + override def onRemoval(notification: RemovalNotification[SessionKey, SessionHolder]): Unit = { + val sessionHolder = notification.getValue + sessionsLock.synchronized { + // First put into closedSessionsCache, so that it cannot get accidentally recreated by + // getOrCreateIsolatedSession. + closedSessionsCache.put(sessionHolder.key, sessionHolder.getSessionHolderInfo) + } + // Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by + // getOrCreateIsolatedSession. + sessionHolder.close() + } + } + + def shutdown(): Unit = { + sessionsLock.synchronized { + sessionStore.invalidateAll() + closedSessionsCache.invalidateAll() + } + } + + private def newIsolatedSession(): SparkSession = { + SparkSession.active.newSession() + } + + private def validateSessionCreate(key: SessionKey): Unit = { + // Validate that sessionId is formatted like UUID before creating session. + try { + UUID.fromString(key.sessionId).toString + } catch { + case _: IllegalArgumentException => + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.FORMAT", + messageParameters = Map("handle" -> key.sessionId)) + } + // Validate that session with that key has not been already closed. + if (closedSessionsCache.getIfPresent(key) != null) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.SESSION_CLOSED", + messageParameters = Map("handle" -> key.sessionId)) + } + } + + /** + * Used for testing + */ + private[connect] def invalidateAllSessions(): Unit = { + sessionStore.invalidateAll() + closedSessionsCache.invalidateAll() + } + + /** + * Used for testing. 
+ */ + private[connect] def putSessionForTesting(sessionHolder: SessionHolder): Unit = { + sessionStore.put(sessionHolder.key, sessionHolder) + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 741fa97f17878..744fa3c8aa1a4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.connect.utils import java.util.UUID import scala.annotation.tailrec -import scala.collection.immutable import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ @@ -41,8 +40,9 @@ import org.apache.spark.api.python.PythonException import org.apache.spark.connect.proto.FetchErrorDetailsResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect -import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SparkConnectService} +import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SessionKey, SparkConnectService} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.ArrayImplicits._ private[connect] object ErrorUtils extends Logging { @@ -91,21 +91,21 @@ private[connect] object ErrorUtils extends Logging { if (serverStackTraceEnabled) { builder.addAllStackTrace( - immutable.ArraySeq - .unsafeWrapArray(currentError.getStackTrace - .map { stackTraceElement => - val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement - .newBuilder() - .setDeclaringClass(stackTraceElement.getClassName) - .setMethodName(stackTraceElement.getMethodName) - .setLineNumber(stackTraceElement.getLineNumber) - - if (stackTraceElement.getFileName != null) { - stackTraceBuilder.setFileName(stackTraceElement.getFileName) - } - - stackTraceBuilder.build() - }) + currentError.getStackTrace + .map { stackTraceElement => + val stackTraceBuilder = FetchErrorDetailsResponse.StackTraceElement + .newBuilder() + .setDeclaringClass(stackTraceElement.getClassName) + .setMethodName(stackTraceElement.getMethodName) + .setLineNumber(stackTraceElement.getLineNumber) + + if (stackTraceElement.getFileName != null) { + stackTraceBuilder.setFileName(stackTraceElement.getFileName) + } + + stackTraceBuilder.build() + } + .toImmutableArraySeq .asJava) } @@ -153,7 +153,9 @@ private[connect] object ErrorUtils extends Logging { .build() } - private def buildStatusFromThrowable(st: Throwable, sessionHolder: SessionHolder): RPCStatus = { + private def buildStatusFromThrowable( + st: Throwable, + sessionHolderOpt: Option[SessionHolder]): RPCStatus = { val errorInfo = ErrorInfo .newBuilder() .setReason(st.getClass.getName) @@ -162,20 +164,20 @@ private[connect] object ErrorUtils extends Logging { "classes", JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) - if (sessionHolder.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)) { + if (sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED))) { // Generate a new unique key for this exception. 
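For orientation, the enrichment flow above works in two hops: when enrichment is enabled and a SessionHolder is available, the server parks the full Throwable under a fresh UUID and ships only that id in the ErrorInfo metadata; the client can then call FetchErrorDetails with the id to pull stack traces and query contexts. Below is a simplified stand-in for the session-scoped errorIdToError registry (the real one lives on SessionHolder):

```scala
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap

// Simplified stand-in for the per-session error registry.
val errorIdToError = new ConcurrentHashMap[String, Throwable]()

def registerError(st: Throwable): String = {
  val errorId = UUID.randomUUID().toString
  errorIdToError.put(errorId, st)
  errorId // returned to the client inside ErrorInfo metadata under "errorId"
}

def lookupError(errorId: String): Option[Throwable] =
  Option(errorIdToError.get(errorId)) // consulted when the client asks for details
```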
val errorId = UUID.randomUUID().toString errorInfo.putMetadata("errorId", errorId) - sessionHolder.errorIdToError + sessionHolderOpt.get.errorIdToError .put(errorId, st) } lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st)) val withStackTrace = - if (sessionHolder.session.conf.get( - SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty) { + if (sessionHolderOpt.exists( + _.session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty)) { val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE) errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) } else { @@ -215,19 +217,22 @@ private[connect] object ErrorUtils extends Logging { sessionId: String, events: Option[ExecuteEventsManager] = None, isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = { - val sessionHolder = - SparkConnectService - .getOrCreateIsolatedSession(userId, sessionId) + + // SessionHolder may not be present, e.g. if the session was already closed. + // When SessionHolder is not present error details will not be available for FetchErrorDetails. + val sessionHolderOpt = + SparkConnectService.sessionManager.getIsolatedSessionIfPresent( + SessionKey(userId, sessionId)) val partial: PartialFunction[Throwable, (Throwable, Throwable)] = { case se: SparkException if isPythonExecutionException(se) => ( se, StatusProto.toStatusRuntimeException( - buildStatusFromThrowable(se.getCause, sessionHolder))) + buildStatusFromThrowable(se.getCause, sessionHolderOpt))) case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => - (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolder))) + (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolderOpt))) case e: Throwable => ( diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index 7b02377f4847c..120126f20ec24 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -59,10 +59,6 @@ trait SparkConnectServerTest extends SharedSparkSession { withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, serverPort.toString)) { SparkConnectService.start(spark.sparkContext) } - // register udf directly on the server, we're not testing client UDFs here... 
- val serverSession = - SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session - serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms })) } override def afterAll(): Unit = { @@ -84,6 +80,7 @@ trait SparkConnectServerTest extends SharedSparkSession { protected def clearAllExecutions(): Unit = { SparkConnectService.executionManager.listExecuteHolders.foreach(_.close()) SparkConnectService.executionManager.periodicMaintenance(0) + SparkConnectService.sessionManager.invalidateAllSessions() assertNoActiveExecutions() } @@ -215,12 +212,24 @@ trait SparkConnectServerTest extends SharedSparkSession { } } + protected def withClient(sessionId: String = defaultSessionId, userId: String = defaultUserId)( + f: SparkConnectClient => Unit): Unit = { + withClient(f, sessionId, userId) + } + protected def withClient(f: SparkConnectClient => Unit): Unit = { + withClient(f, defaultSessionId, defaultUserId) + } + + protected def withClient( + f: SparkConnectClient => Unit, + sessionId: String, + userId: String): Unit = { val client = SparkConnectClient .builder() .port(serverPort) - .sessionId(defaultSessionId) - .userId(defaultUserId) + .sessionId(sessionId) + .userId(userId) .enableReattachableExecute() .build() try f(client) diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala index 0e29a07b719af..784b978f447df 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala @@ -347,6 +347,10 @@ class ReattachableExecuteSuite extends SparkConnectServerTest { } test("long sleeping query") { + // register udf directly on the server, we're not testing client UDFs here... 
+ val serverSession = + SparkConnectService.getOrCreateIsolatedSession(defaultUserId, defaultSessionId).session + serverSession.udf.register("sleep", ((ms: Int) => { Thread.sleep(ms); ms })) // query will be sleeping and not returning results, while having multiple reattach withSparkEnvConfs( (Connect.CONNECT_EXECUTE_REATTACHABLE_SENDER_MAX_STREAM_DURATION.key, "1s")) { diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index ce452623e6b84..b314e7d8d4834 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -841,12 +841,12 @@ class SparkConnectServiceSuite spark.sparkContext.addSparkListener(verifyEvents.listener) Utils.tryWithSafeFinally({ f(verifyEvents) - SparkConnectService.invalidateAllSessions() + SparkConnectService.sessionManager.invalidateAllSessions() verifyEvents.onSessionClosed() }) { verifyEvents.waitUntilEmpty() spark.sparkContext.removeSparkListener(verifyEvents.listener) - SparkConnectService.invalidateAllSessions() + SparkConnectService.sessionManager.invalidateAllSessions() SparkConnectPluginRegistry.reset() } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala index 14ecc9a2e95e4..cc0481dab0f4f 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala @@ -16,13 +16,171 @@ */ package org.apache.spark.sql.connect.service +import java.util.UUID + import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ +import org.apache.spark.SparkException import org.apache.spark.sql.connect.SparkConnectServerTest class SparkConnectServiceE2ESuite extends SparkConnectServerTest { + // Make the results of these queries large enough that they do not all fit in the buffers + // and are not pushed out immediately even when the client doesn't consume them. Otherwise, + // even if the connection got closed, the client could still see the query as succeeded + // because all the results were already sitting in its buffer. + val BIG_ENOUGH_QUERY = "select * from range(1000000)" + + test("ReleaseSession releases all queries and does not allow more requests in the session") { + withClient { client => + val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY)) + val query2 = client.execute(buildPlan(BIG_ENOUGH_QUERY)) + val query3 = client.execute(buildPlan("select 1")) + // just creating the iterator is lazy, trigger query1 and query2 to be sent. + query1.hasNext + query2.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 2 + } + + // Close session + client.releaseSession() + + // Check that queries get cancelled + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 0
+ } + + // query1 and query2 could get either an: + // OPERATION_CANCELED if it happens fast - when closing the session interrupted the queries, + // and that error got pushed to the client buffers before the client got disconnected. + // OPERATION_ABANDONED if it happens slow - when closing the session interrupted the client + // RPCs before it pushed out the error above. The client would then get an + // INVALID_CURSOR.DISCONNECTED, which it will retry with a ReattachExecute, and then get an + // INVALID_HANDLE.OPERATION_ABANDONED. + val query1Error = intercept[SparkException] { + while (query1.hasNext) query1.next() + } + assert( + query1Error.getMessage.contains("OPERATION_CANCELED") || + query1Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + val query2Error = intercept[SparkException] { + while (query2.hasNext) query2.next() + } + assert( + query2Error.getMessage.contains("OPERATION_CANCELED") || + query2Error.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + + // query3 has not been submitted before, so it should now fail with SESSION_CLOSED + val query3Error = intercept[SparkException] { + query3.hasNext + } + assert(query3Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + + // No other requests should be allowed in the session, failing with SESSION_CLOSED + val requestError = intercept[SparkException] { + client.interruptAll() + } + assert(requestError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + + private def testReleaseSessionTwoSessions( + sessionIdA: String, + userIdA: String, + sessionIdB: String, + userIdB: String): Unit = { + withClient(sessionId = sessionIdA, userId = userIdA) { clientA => + withClient(sessionId = sessionIdB, userId = userIdB) { clientB => + val queryA = clientA.execute(buildPlan(BIG_ENOUGH_QUERY)) + val queryB = clientB.execute(buildPlan(BIG_ENOUGH_QUERY)) + // just creating the iterator is lazy, trigger query1 and query2 to be sent. + queryA.hasNext + queryB.hasNext + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 2 + } + // Close session A + clientA.releaseSession() + + // A's query gets kicked out. + Eventually.eventually(timeout(eventuallyTimeout)) { + SparkConnectService.executionManager.listExecuteHolders.length == 1 + } + val queryAError = intercept[SparkException] { + while (queryA.hasNext) queryA.next() + } + assert( + queryAError.getMessage.contains("OPERATION_CANCELED") || + queryAError.getMessage.contains("INVALID_HANDLE.OPERATION_ABANDONED")) + + // B's query can run. + while (queryB.hasNext) queryB.next() + + // B can submit more queries. + val queryB2 = clientB.execute(buildPlan("SELECT 1")) + while (queryB2.hasNext) queryB2.next() + // A can't submit more queries. 
+ val queryA2 = clientA.execute(buildPlan("SELECT 1")) + val queryA2Error = intercept[SparkException] { + clientA.interruptAll() + } + assert(queryA2Error.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + } + + test("ReleaseSession for different user_id with same session_id do not affect each other") { + testReleaseSessionTwoSessions(defaultSessionId, "A", defaultSessionId, "B") + } + + test("ReleaseSession for different session_id with same user_id do not affect each other") { + val sessionIdA = UUID.randomUUID.toString() + val sessionIdB = UUID.randomUUID.toString() + testReleaseSessionTwoSessions(sessionIdA, "X", sessionIdB, "X") + } + + test("ReleaseSession: can't create a new session with the same id and user after release") { + val sessionId = UUID.randomUUID.toString() + val userId = "Y" + withClient(sessionId = sessionId, userId = userId) { client => + // this will create the session, and then ReleaseSession at the end of withClient. + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = sessionId, userId = userId) { client => + // shall not be able to create a new session with the same id and user. + val query = client.execute(buildPlan("SELECT 1")) + val queryError = intercept[SparkException] { + while (query.hasNext) query.next() + } + assert(queryError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED")) + } + } + + test("ReleaseSession: session with different session_id or user_id allowed after release") { + val sessionId = UUID.randomUUID.toString() + val userId = "Y" + withClient(sessionId = sessionId, userId = userId) { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = UUID.randomUUID.toString, userId = userId) { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + withClient(sessionId = sessionId, userId = "YY") { client => + val query = client.execute(buildPlan("SELECT 1")) + query.hasNext // trigger execution + client.releaseSession() + } + } + test("SPARK-45133 query should reach FINISHED state when results are not consumed") { withRawBlockingStub { stub => val iter = diff --git a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt index 0915e73385686..19bd2faa9a5a5 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt @@ -2,26 +2,26 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 21+35 on Linux 5.15.0-1046-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1050-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 2649 2653 6 0.0 264919.3 1.0X -Compression 10000 times at level 2 without buffer pool 2785 2785 1 0.0 278455.9 1.0X -Compression 10000 times at level 3 without buffer pool 3078 3082 6 0.0 307845.7 0.9X -Compression 10000 times at level 1 with buffer pool 2353 2378 35 0.0 235340.6 1.1X -Compression 
10000 times at level 2 with buffer pool 2462 2466 6 0.0 246194.4 1.1X -Compression 10000 times at level 3 with buffer pool 2761 2765 6 0.0 276095.6 1.0X +Compression 10000 times at level 1 without buffer pool 1817 1819 3 0.0 181709.6 1.0X +Compression 10000 times at level 2 without buffer pool 2081 2083 3 0.0 208053.7 0.9X +Compression 10000 times at level 3 without buffer pool 2288 2290 3 0.0 228795.4 0.8X +Compression 10000 times at level 1 with buffer pool 1997 1998 1 0.0 199686.9 0.9X +Compression 10000 times at level 2 with buffer pool 2062 2063 1 0.0 206209.3 0.9X +Compression 10000 times at level 3 with buffer pool 2243 2243 1 0.0 224271.8 0.8X -OpenJDK 64-Bit Server VM 21+35 on Linux 5.15.0-1046-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 21.0.1+12-LTS on Linux 5.15.0-1050-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 2644 2644 0 0.0 264430.7 1.0X -Decompression 10000 times from level 2 without buffer pool 2597 2608 15 0.0 259715.2 1.0X -Decompression 10000 times from level 3 without buffer pool 2586 2600 20 0.0 258633.7 1.0X -Decompression 10000 times from level 1 with buffer pool 2290 2296 9 0.0 228957.6 1.2X -Decompression 10000 times from level 2 with buffer pool 2315 2319 6 0.0 231535.4 1.1X -Decompression 10000 times from level 3 with buffer pool 2283 2302 27 0.0 228308.3 1.2X +Decompression 10000 times from level 1 without buffer pool 1980 1981 1 0.0 197970.8 1.0X +Decompression 10000 times from level 2 without buffer pool 1978 1979 1 0.0 197813.4 1.0X +Decompression 10000 times from level 3 without buffer pool 1981 1983 2 0.0 198141.7 1.0X +Decompression 10000 times from level 1 with buffer pool 1825 1827 3 0.0 182475.2 1.1X +Decompression 10000 times from level 2 with buffer pool 1827 1827 0 0.0 182667.1 1.1X +Decompression 10000 times from level 3 with buffer pool 1826 1826 0 0.0 182579.8 1.1X diff --git a/core/benchmarks/ZStandardBenchmark-results.txt b/core/benchmarks/ZStandardBenchmark-results.txt index 5299a52dc7b8a..6f7efdd6c94c5 100644 --- a/core/benchmarks/ZStandardBenchmark-results.txt +++ b/core/benchmarks/ZStandardBenchmark-results.txt @@ -2,26 +2,26 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.8+7-LTS on Linux 5.15.0-1046-azure +OpenJDK 64-Bit Server VM 17.0.9+8-LTS on Linux 5.15.0-1050-azure Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 2800 2801 2 0.0 279995.2 1.0X -Compression 10000 times at level 2 without buffer pool 2832 2846 19 0.0 283227.4 1.0X -Compression 10000 times at level 3 without buffer pool 2978 3003 35 0.0 297782.6 0.9X -Compression 10000 times at level 1 with buffer pool 2650 2652 2 0.0 265042.4 1.1X -Compression 10000 times at level 2 with buffer pool 2684 2688 5 0.0 268419.4 1.0X -Compression 10000 times at level 3 with buffer pool 2811 2816 8 0.0 281069.3 1.0X +Compression 10000 times at level 
1 without buffer pool 2786 2787 2 0.0 278560.1 1.0X +Compression 10000 times at level 2 without buffer pool 2831 2833 3 0.0 283091.0 1.0X +Compression 10000 times at level 3 without buffer pool 2958 2959 2 0.0 295806.3 0.9X +Compression 10000 times at level 1 with buffer pool 211 214 4 0.0 21145.3 13.2X +Compression 10000 times at level 2 with buffer pool 253 255 1 0.0 25328.1 11.0X +Compression 10000 times at level 3 with buffer pool 370 371 1 0.0 37046.1 7.5X -OpenJDK 64-Bit Server VM 17.0.8+7-LTS on Linux 5.15.0-1046-azure +OpenJDK 64-Bit Server VM 17.0.9+8-LTS on Linux 5.15.0-1050-azure Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 2747 2752 7 0.0 274669.2 1.0X -Decompression 10000 times from level 2 without buffer pool 2743 2748 8 0.0 274265.3 1.0X -Decompression 10000 times from level 3 without buffer pool 2743 2750 10 0.0 274344.4 1.0X -Decompression 10000 times from level 1 with buffer pool 2608 2608 0 0.0 260803.0 1.1X -Decompression 10000 times from level 2 with buffer pool 2608 2608 0 0.0 260804.2 1.1X -Decompression 10000 times from level 3 with buffer pool 2605 2607 3 0.0 260514.9 1.1X +Decompression 10000 times from level 1 without buffer pool 2745 2748 5 0.0 274454.0 1.0X +Decompression 10000 times from level 2 without buffer pool 2744 2745 1 0.0 274438.1 1.0X +Decompression 10000 times from level 3 without buffer pool 2746 2746 1 0.0 274586.0 1.0X +Decompression 10000 times from level 1 with buffer pool 2587 2588 1 0.0 258707.4 1.1X +Decompression 10000 times from level 2 with buffer pool 2586 2586 1 0.0 258566.8 1.1X +Decompression 10000 times from level 3 with buffer pool 2589 2589 0 0.0 258870.6 1.1X diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index a7e4c1fbab295..b0ee6018970ab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -90,7 +90,7 @@ private[spark] class StandaloneAppClient( case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - stop() + this.stop() } } @@ -168,7 +168,7 @@ private[spark] class StandaloneAppClient( case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - stop() + this.stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = s"$appId/$id" @@ -203,7 +203,7 @@ private[spark] class StandaloneAppClient( markDead("Application has been stopped.") sendToMaster(UnregisterApplication(appId.get)) context.reply(true) - stop() + this.stop() case r: RequestExecutors => master match { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 29022c7419b4b..058b944c591ad 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -53,6 +53,9 @@ private[deploy] class Master( private val forwardMessageThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + private val 
driverIdPattern = conf.get(DRIVER_ID_PATTERN) + private val appIdPattern = conf.get(APP_ID_PATTERN) + // For application IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) @@ -1150,7 +1153,7 @@ private[deploy] class Master( /** Generate a new app ID given an app's submission date */ private def newApplicationId(submitDate: Date): String = { - val appId = "app-%s-%04d".format(createDateFormat.format(submitDate), nextAppNumber) + val appId = appIdPattern.format(createDateFormat.format(submitDate), nextAppNumber) nextAppNumber += 1 appId } @@ -1175,7 +1178,7 @@ private[deploy] class Master( } private def newDriverId(submitDate: Date): String = { - val appId = "driver-%s-%04d".format(createDateFormat.format(submitDate), nextDriverNumber) + val appId = driverIdPattern.format(createDateFormat.format(submitDate), nextDriverNumber) nextDriverNumber += 1 appId } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 48c0c9601c14b..cb325b37958ec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -98,10 +98,15 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { val state = getMasterState - val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory", "Resources") + val showResourceColumn = state.workers.filter(_.resourcesInfoUsed.nonEmpty).nonEmpty + val workerHeaders = if (showResourceColumn) { + Seq("Worker Id", "Address", "State", "Cores", "Memory", "Resources") + } else { + Seq("Worker Id", "Address", "State", "Cores", "Memory") + } val workers = state.workers.sortBy(_.id) val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) - val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) + val workerTable = UIUtils.listingTable(workerHeaders, workerRow(showResourceColumn), workers) val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Executor", "Resources Per Executor", "Submitted Time", "User", "State", "Duration") @@ -256,7 +261,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { UIUtils.basicSparkPage(request, content, "Spark Master at " + state.uri) } - private def workerRow(worker: WorkerInfo): Seq[Node] = { + private def workerRow(showResourceColumn: Boolean): WorkerInfo => Seq[Node] = worker => { { @@ -276,7 +281,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {Utils.megabytesToString(worker.memory)} ({Utils.megabytesToString(worker.memoryUsed)} Used) - {formatWorkerResourcesDetails(worker)} + {if (showResourceColumn) { + {formatWorkerResourcesDetails(worker)} + }} } diff --git a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala index aaeb37a17249a..c6ccf9550bc91 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/Deploy.scala @@ -82,4 +82,21 @@ private[spark] object Deploy { .checkValue(_ > 0, "The maximum number of running drivers should be positive.") .createWithDefault(Int.MaxValue) + val DRIVER_ID_PATTERN = ConfigBuilder("spark.deploy.driverIdPattern") + .doc("The pattern for driver ID generation based on Java `String.format` method. 
" + + "The default value is `driver-%s-%04d` which represents the existing driver id string " + + ", e.g., `driver-20231031224459-0019`. Please be careful to generate unique IDs") + .version("4.0.0") + .stringConf + .checkValue(!_.format("20231101000000", 0).exists(_.isWhitespace), "Whitespace is not allowed.") + .createWithDefault("driver-%s-%04d") + + val APP_ID_PATTERN = ConfigBuilder("spark.deploy.appIdPattern") + .doc("The pattern for app ID generation based on Java `String.format` method.. " + + "The default value is `app-%s-%04d` which represents the existing app id string, " + + "e.g., `app-20231031224509-0008`. Plesae be careful to generate unique IDs.") + .version("4.0.0") + .stringConf + .checkValue(!_.format("20231101000000", 0).exists(_.isWhitespace), "Whitespace is not allowed.") + .createWithDefault("app-%s-%04d") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index c49b2411e7635..e02dd27937062 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -319,7 +319,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case StopDriver => context.reply(true) - stop() + this.stop() case UpdateExecutorsLogLevel(logLevel) => currentLogLevel = Some(logLevel) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index f8bd73e65617f..2096da2fca02d 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -63,7 +63,7 @@ private[spark] class DiskBlockObjectWriter( */ private trait ManualCloseOutputStream extends OutputStream { abstract override def close(): Unit = { - flush() + this.flush() } def manualClose(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index e3792eb0d237b..518c0592488fc 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -342,7 +342,7 @@ abstract class SparkFunSuite sqlState: Option[String] = None, parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, - queryContext: Array[QueryContext] = Array.empty): Unit = { + queryContext: Array[ExpectedContext] = Array.empty): Unit = { assert(exception.getErrorClass === errorClass) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala @@ -364,16 +364,25 @@ abstract class SparkFunSuite val actualQueryContext = exception.getQueryContext() assert(actualQueryContext.length === queryContext.length, "Invalid length of the query context") actualQueryContext.zip(queryContext).foreach { case (actual, expected) => - assert(actual.objectType() === expected.objectType(), - "Invalid objectType of a query context Actual:" + actual.toString) - assert(actual.objectName() === expected.objectName(), - "Invalid objectName of a query context. Actual:" + actual.toString) - assert(actual.startIndex() === expected.startIndex(), - "Invalid startIndex of a query context. 
Actual:" + actual.toString) - assert(actual.stopIndex() === expected.stopIndex(), - "Invalid stopIndex of a query context. Actual:" + actual.toString) - assert(actual.fragment() === expected.fragment(), - "Invalid fragment of a query context. Actual:" + actual.toString) + assert(actual.contextType() === expected.contextType, + "Invalid contextType of a query context Actual:" + actual.toString) + if (actual.contextType() == QueryContextType.SQL) { + assert(actual.objectType() === expected.objectType, + "Invalid objectType of a query context Actual:" + actual.toString) + assert(actual.objectName() === expected.objectName, + "Invalid objectName of a query context. Actual:" + actual.toString) + assert(actual.startIndex() === expected.startIndex, + "Invalid startIndex of a query context. Actual:" + actual.toString) + assert(actual.stopIndex() === expected.stopIndex, + "Invalid stopIndex of a query context. Actual:" + actual.toString) + assert(actual.fragment() === expected.fragment, + "Invalid fragment of a query context. Actual:" + actual.toString) + } else if (actual.contextType() == QueryContextType.DataFrame) { + assert(actual.fragment() === expected.fragment, + "Invalid code fragment of a query context. Actual:" + actual.toString) + assert(actual.callSite().matches(expected.callSitePattern), + "Invalid callSite of a query context. Actual:" + actual.toString) + } } } @@ -389,21 +398,21 @@ abstract class SparkFunSuite errorClass: String, sqlState: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, Some(sqlState), parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, parameters, false, Array(context)) protected def checkError( exception: SparkThrowable, errorClass: String, sqlState: String, - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, None, Map.empty, false, Array(context)) protected def checkError( @@ -411,7 +420,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, false, Array(context)) @@ -426,7 +435,7 @@ abstract class SparkFunSuite errorClass: String, sqlState: Option[String], parameters: Map[String, String], - context: QueryContext): Unit = + context: ExpectedContext): Unit = checkError(exception, errorClass, sqlState, parameters, matchPVals = true, Array(context)) @@ -453,16 +462,33 @@ abstract class SparkFunSuite parameters = Map("relationName" -> tableName)) case class ExpectedContext( + contextType: QueryContextType, objectType: String, objectName: String, startIndex: Int, stopIndex: Int, - fragment: String) extends QueryContext + fragment: String, + callSitePattern: String + ) object ExpectedContext { def apply(fragment: String, start: Int, stop: Int): ExpectedContext = { ExpectedContext("", "", start, stop, fragment) } + + def apply( + objectType: String, + objectName: String, + startIndex: Int, + stopIndex: Int, + fragment: String): ExpectedContext = { + new ExpectedContext(QueryContextType.SQL, objectType, objectName, startIndex, stopIndex, + fragment, "") + } + + def apply(fragment: String, callSitePattern: String): ExpectedContext = { + new 
ExpectedContext(QueryContextType.DataFrame, "", "", -1, -1, fragment, callSitePattern) + } } class LogAppender(msg: String = "", maxEvents: Int = 1000) diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 9f32d81f1ae3d..0206205c353a1 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -503,11 +503,14 @@ class SparkThrowableSuite extends SparkFunSuite { test("Get message in the specified format") { import ErrorMessageFormat._ class TestQueryContext extends QueryContext { + override val contextType = QueryContextType.SQL override val objectName = "v1" override val objectType = "VIEW" override val startIndex = 2 override val stopIndex = -1 override val fragment = "1 / 0" + override def callSite: String = throw new UnsupportedOperationException + override val summary = "" } val e = new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", @@ -577,6 +580,54 @@ class SparkThrowableSuite extends SparkFunSuite { | "message" : "Test message" | } |}""".stripMargin) + + class TestQueryContext2 extends QueryContext { + override val contextType = QueryContextType.DataFrame + override def objectName: String = throw new UnsupportedOperationException + override def objectType: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + override val fragment: String = "div" + override val callSite: String = "SimpleApp$.main(SimpleApp.scala:9)" + override val summary = "" + } + val e4 = new SparkArithmeticException( + errorClass = "DIVIDE_BY_ZERO", + messageParameters = Map("config" -> "CONFIG"), + context = Array(new TestQueryContext2), + summary = "Query summary") + + assert(SparkThrowableHelper.getMessage(e4, PRETTY) === + "[DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 " + + "and return NULL instead. If necessary set CONFIG to \"false\" to bypass this error." + + " SQLSTATE: 22012\nQuery summary") + // scalastyle:off line.size.limit + assert(SparkThrowableHelper.getMessage(e4, MINIMAL) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "fragment" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + assert(SparkThrowableHelper.getMessage(e4, STANDARD) === + """{ + | "errorClass" : "DIVIDE_BY_ZERO", + | "messageTemplate" : "Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. 
If necessary set to \"false\" to bypass this error.", + | "sqlState" : "22012", + | "messageParameters" : { + | "config" : "CONFIG" + | }, + | "queryContext" : [ { + | "fragment" : "div", + | "callSite" : "SimpleApp$.main(SimpleApp.scala:9)" + | } ] + |}""".stripMargin) + // scalastyle:on line.size.limit } test("overwrite error classes") { diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index fc6c7d267e6a5..e8615cdbdd559 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -802,6 +802,8 @@ class MasterSuite extends SparkFunSuite private val _waitingDrivers = PrivateMethod[mutable.ArrayBuffer[DriverInfo]](Symbol("waitingDrivers")) private val _state = PrivateMethod[RecoveryState.Value](Symbol("state")) + private val _newDriverId = PrivateMethod[String](Symbol("newDriverId")) + private val _newApplicationId = PrivateMethod[String](Symbol("newApplicationId")) private val workerInfo = makeWorkerInfo(4096, 10) private val workerInfos = Array(workerInfo, workerInfo, workerInfo) @@ -1236,6 +1238,34 @@ class MasterSuite extends SparkFunSuite private def getState(master: Master): RecoveryState.Value = { master.invokePrivate(_state()) } + + test("SPARK-45753: Support driver id pattern") { + val master = makeMaster(new SparkConf().set(DRIVER_ID_PATTERN, "my-driver-%2$05d")) + val submitDate = new Date() + assert(master.invokePrivate(_newDriverId(submitDate)) === "my-driver-00000") + assert(master.invokePrivate(_newDriverId(submitDate)) === "my-driver-00001") + } + + test("SPARK-45753: Prevent invalid driver id patterns") { + val m = intercept[IllegalArgumentException] { + makeMaster(new SparkConf().set(DRIVER_ID_PATTERN, "my driver")) + }.getMessage + assert(m.contains("Whitespace is not allowed")) + } + + test("SPARK-45754: Support app id pattern") { + val master = makeMaster(new SparkConf().set(APP_ID_PATTERN, "my-app-%2$05d")) + val submitDate = new Date() + assert(master.invokePrivate(_newApplicationId(submitDate)) === "my-app-00000") + assert(master.invokePrivate(_newApplicationId(submitDate)) === "my-app-00001") + } + + test("SPARK-45754: Prevent invalid app id patterns") { + val m = intercept[IllegalArgumentException] { + makeMaster(new SparkConf().set(APP_ID_PATTERN, "my app")) + }.getMessage + assert(m.contains("Whitespace is not allowed")) + } } private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 3ef4da6d3d3f1..28af0656869b3 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -326,11 +326,11 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = { new executor.TaskRunner(backend, taskDescription, None) { override def run(): Unit = { - logInfo(s"task ${taskDescription.taskId} runs.") + logInfo(s"task ${this.taskDescription.taskId} runs.") } override def kill(interruptThread: Boolean, reason: String): Unit = { - logInfo(s"task ${taskDescription.taskId} killed.") + logInfo(s"task ${this.taskDescription.taskId} killed.") } } } @@ -434,13 
+434,13 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = { new executor.TaskRunner(backend, taskDescription, None) { override def run(): Unit = { - tasksExecuted.put(taskDescription.taskId, true) - logInfo(s"task ${taskDescription.taskId} runs.") + tasksExecuted.put(this.taskDescription.taskId, true) + logInfo(s"task ${this.taskDescription.taskId} runs.") } override def kill(interruptThread: Boolean, reason: String): Unit = { - logInfo(s"task ${taskDescription.taskId} killed.") - tasksKilled.put(taskDescription.taskId, true) + logInfo(s"task ${this.taskDescription.taskId} killed.") + tasksKilled.put(this.taskDescription.taskId, true) } } } @@ -523,13 +523,13 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite def getFakeTaskRunner(taskDescription: TaskDescription): Executor#TaskRunner = { new executor.TaskRunner(backend, taskDescription, None) { override def run(): Unit = { - tasksExecuted.put(taskDescription.taskId, true) - logInfo(s"task ${taskDescription.taskId} runs.") + tasksExecuted.put(this.taskDescription.taskId, true) + logInfo(s"task ${this.taskDescription.taskId} runs.") } override def kill(interruptThread: Boolean, reason: String): Unit = { - logInfo(s"task ${taskDescription.taskId} killed.") - tasksKilled.put(taskDescription.taskId, true) + logInfo(s"task ${this.taskDescription.taskId} killed.") + tasksKilled.put(this.taskDescription.taskId, true) } } } diff --git a/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala new file mode 100644 index 0000000000000..135af550c4b39 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/ArrayImplicitsSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import scala.collection.immutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.ArrayImplicits._ + +class ArrayImplicitsSuite extends SparkFunSuite { + + test("Int Array") { + val data = Array(1, 2, 3) + val arraySeq = data.toImmutableArraySeq + assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofInt]) + assert(arraySeq.length === 3) + assert(arraySeq.unsafeArray.sameElements(data)) + } + + test("TestClass Array") { + val data = Array(TestClass(1), TestClass(2), TestClass(3)) + val arraySeq = data.toImmutableArraySeq + assert(arraySeq.getClass === classOf[immutable.ArraySeq.ofRef[TestClass]]) + assert(arraySeq.length === 3) + assert(arraySeq.unsafeArray.sameElements(data)) + } + + test("Null Array") { + val data: Array[Int] = null + val arraySeq = data.toImmutableArraySeq + assert(arraySeq == null) + } + + case class TestClass(i: Int) +} diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index b5e345ecf0eb8..6364ec48fb664 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -43,7 +43,7 @@ commons-compiler/3.1.9//commons-compiler-3.1.9.jar commons-compress/1.24.0//commons-compress-1.24.0.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar -commons-io/2.14.0//commons-io-2.14.0.jar +commons-io/2.15.0//commons-io-2.15.0.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.13.0//commons-lang3-3.13.0.jar commons-logging/1.1.3//commons-logging-1.1.3.jar @@ -181,11 +181,11 @@ log4j-core/2.21.0//log4j-core-2.21.0.jar log4j-slf4j2-impl/2.21.0//log4j-slf4j2-impl-2.21.0.jar logging-interceptor/3.12.12//logging-interceptor-3.12.12.jar lz4-java/1.8.0//lz4-java-1.8.0.jar -metrics-core/4.2.19//metrics-core-4.2.19.jar -metrics-graphite/4.2.19//metrics-graphite-4.2.19.jar -metrics-jmx/4.2.19//metrics-jmx-4.2.19.jar -metrics-json/4.2.19//metrics-json-4.2.19.jar -metrics-jvm/4.2.19//metrics-jvm-4.2.19.jar +metrics-core/4.2.21//metrics-core-4.2.21.jar +metrics-graphite/4.2.21//metrics-graphite-4.2.21.jar +metrics-jmx/4.2.21//metrics-jmx-4.2.21.jar +metrics-json/4.2.21//metrics-json-4.2.21.jar +metrics-jvm/4.2.21//metrics-jvm-4.2.21.jar minlog/1.3.0//minlog-1.3.0.jar netty-all/4.1.100.Final//netty-all-4.1.100.Final.jar netty-buffer/4.1.100.Final//netty-buffer-4.1.100.Final.jar @@ -264,4 +264,4 @@ xz/1.9//xz-1.9.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.6.3//zookeeper-jute-3.6.3.jar zookeeper/3.6.3//zookeeper-3.6.3.jar -zstd-jni/1.5.5-5//zstd-jni-1.5.5-5.jar +zstd-jni/1.5.5-7//zstd-jni-1.5.5-7.jar diff --git a/docs/configuration.md b/docs/configuration.md index bd908f3b34d78..60cad24e71c44 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -666,7 +666,7 @@ Apart from these, the following properties are also available, and may be useful spark.executor.logs.rolling.maxRetainedFiles - (none) + -1 Sets the number of latest rolling log files that are going to be retained by the system. Older log files will be deleted. Disabled by default. diff --git a/docs/sql-error-conditions-invalid-handle-error-class.md b/docs/sql-error-conditions-invalid-handle-error-class.md index c4cbb48035ff5..14526cd53724f 100644 --- a/docs/sql-error-conditions-invalid-handle-error-class.md +++ b/docs/sql-error-conditions-invalid-handle-error-class.md @@ -45,6 +45,10 @@ Operation not found. Session already exists. +## SESSION_CLOSED + +Session was closed. + ## SESSION_NOT_FOUND Session not found. 
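For context on the ArrayImplicits helper exercised in ArrayImplicitsSuite above and adopted by the mllib wrappers below: toImmutableArraySeq wraps the existing array via immutable.ArraySeq.unsafeWrapArray, so no element copy is made. A minimal, self-contained sketch of the same pattern (ArraySeqWrapSketch, ArrayWrapOps and wrapImmutable are illustrative names, not Spark's private[spark] API):

import scala.collection.immutable

// Illustrative sketch: expose an Array[T] as an immutable.ArraySeq[T] without copying,
// mirroring the idea behind org.apache.spark.util.ArrayImplicits.toImmutableArraySeq.
object ArraySeqWrapSketch {
  implicit class ArrayWrapOps[T](xs: Array[T]) {
    def wrapImmutable: immutable.ArraySeq[T] =
      if (xs eq null) null else immutable.ArraySeq.unsafeWrapArray(xs)
  }

  def main(args: Array[String]): Unit = {
    val data = Array(1, 2, 3)
    val seq = data.wrapImmutable
    // The wrapper is backed by the original array, so no copy takes place.
    assert(seq.unsafeArray eq data)
    println(seq.mkString(", ")) // prints: 1, 2, 3
  }
}

Because the ArraySeq shares the backing array, later mutations of that array remain visible through the wrapper, the usual caveat of unsafeWrapArray.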
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index 1eed97a8d4f65..2f3f396730be2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -17,12 +17,12 @@ package org.apache.spark.mllib.api.python -import scala.collection.immutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkContext import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.util.ArrayImplicits._ /** * Wrapper around GaussianMixtureModel to provide helper methods in Python @@ -38,7 +38,7 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { val modelGaussians = model.gaussians.map { gaussian => Array[Any](gaussian.mu, gaussian.sigma) } - SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(modelGaussians).asJava) + SerDe.dumps(modelGaussians.toImmutableArraySeq.asJava) } def predictSoft(point: Vector): Vector = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala index b919b0a8c3f2e..6a6c6cf6bcfb3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -16,12 +16,12 @@ */ package org.apache.spark.mllib.api.python -import scala.collection.immutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkContext import org.apache.spark.mllib.clustering.LDAModel import org.apache.spark.mllib.linalg.Matrix +import org.apache.spark.util.ArrayImplicits._ /** * Wrapper around LDAModel to provide helper methods in Python @@ -36,11 +36,11 @@ private[python] class LDAModelWrapper(model: LDAModel) { def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => - val jTerms = immutable.ArraySeq.unsafeWrapArray(terms).asJava - val jTermWeights = immutable.ArraySeq.unsafeWrapArray(termWeights).asJava + val jTerms = terms.toImmutableArraySeq.asJava + val jTermWeights = termWeights.toImmutableArraySeq.asJava Array[Any](jTerms, jTermWeights) } - SerDe.dumps(immutable.ArraySeq.unsafeWrapArray(topics).asJava) + SerDe.dumps(topics.toImmutableArraySeq.asJava) } def save(sc: SparkContext, path: String): Unit = model.save(sc, path) diff --git a/pom.xml b/pom.xml index 9f550456cbcbe..2e0c95516c177 100644 --- a/pom.xml +++ b/pom.xml @@ -156,7 +156,7 @@ If you change codahale.metrics.version, you also need to change the link to metrics.dropwizard.io in docs/monitoring.md. 
--> - 4.2.19 + 4.2.21 1.11.3 1.12.0 @@ -192,7 +192,7 @@ 3.0.3 1.16.0 1.24.0 - 2.14.0 + 2.15.0 2.6 @@ -799,7 +799,7 @@ com.github.luben zstd-jni - 1.5.5-5 + 1.5.5-7 com.clearspring.analytics @@ -2662,6 +2662,10 @@ javax.annotation javax.annotation-api + + com.github.luben + zstd-jni + @@ -2985,13 +2989,6 @@ SPARK-40497 Upgrade Scala to 2.13.11 and suppress `Implicit definition should have explicit type` --> -Wconf:msg=Implicit definition should have explicit type:s - - -Wconf:msg=legacy-binding:s diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 10864390e3fc7..c0275e162722a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,14 @@ object MimaExcludes { // [SPARK-45427][CORE] Add RPC SSL settings to SSLOptions and SparkTransportConf ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"), // [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$"), + // [SPARK-45022][SQL] Provide context for dataset API errors + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.contextType"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.code"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.callSite"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.QueryContext.summary"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI$default$3"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.types.Decimal.fromStringANSI") ) // Default exclude rules diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d9d4a836ab5d4..d76af6a06cfd5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -249,11 +249,6 @@ object SparkBuild extends PomBuild { "-Wconf:cat=deprecation&msg=procedure syntax is deprecated:e", // SPARK-40497 Upgrade Scala to 2.13.11 and suppress `Implicit definition should have explicit type` "-Wconf:msg=Implicit definition should have explicit type:s", - // SPARK-45331 Upgrade Scala to 2.13.12 and suppress "In Scala 2, symbols inherited - // from a superclass shadow symbols defined in an outer scope. Such references are - // ambiguous in Scala 3. To continue using the inherited symbol, write `this.stop`. - // Or use `-Wconf:msg=legacy-binding:s` to silence this warning. [quickfixable]" - "-Wconf:msg=legacy-binding:s", // SPARK-45627 Symbol literals are deprecated in Scala 2.13 and it's a compile error in Scala 3. 
"-Wconf:cat=deprecation&msg=symbol literal is deprecated:e", // SPARK-45627 `enum`, `export` and `given` will become keywords in Scala 3, diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 318f7d7ade4a2..11a1112ad1fe7 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -19,7 +19,7 @@ "SparkConnectClient", ] -from pyspark.loose_version import LooseVersion + from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) @@ -61,6 +61,7 @@ from google.protobuf import text_format from google.rpc import error_details_pb2 +from pyspark.loose_version import LooseVersion from pyspark.version import __version__ from pyspark.resource.information import ResourceInformation from pyspark.sql.connect.client.artifact import ArtifactManager @@ -1471,6 +1472,26 @@ def interrupt_operation(self, op_id: str) -> Optional[List[str]]: except Exception as error: self._handle_error(error) + def release_session(self) -> None: + req = pb2.ReleaseSessionRequest() + req.session_id = self._session_id + req.client_type = self._builder.userAgent + if self._user_id: + req.user_context.user_id = self._user_id + try: + for attempt in self._retrying(): + with attempt: + resp = self._stub.ReleaseSession(req, metadata=self._builder.metadata()) + if resp.session_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request:" + f"{resp.session_id} != {self._session_id}" + ) + return + raise SparkConnectException("Invalid state during retry exception handling.") + except Exception as error: + self._handle_error(error) + def add_tag(self, tag: str) -> None: self._throw_if_invalid_tag(tag) if not hasattr(self.thread_local, "tags"): diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 05040d8135017..0e374e7aa2ccb 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 
\x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 
\x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t 
\x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 
\x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 
\x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xfb\t\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xac\x01\n\x0cQueryContext\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xd1\x06\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x12\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b 
\x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relationB\t\n\x07\x61nalyzeB\x0e\n\x0c_client_type"\x99\r\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t 
\x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevelB\x08\n\x06result"\xa0\x04\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\xa5\x01\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\xe6\x0f\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\\\n\x0fresult_complete\x18\x0e 
\x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1av\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x42\x0f\n\r_start_offset\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1at\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x1a\x10\n\x0eResultCompleteB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\x84\x08\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\x34\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 
\x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB\x0e\n\x0c_client_type"z\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xe7\x06\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x01R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB\x0e\n\x0c_client_type"\xbc\x01\n\x14\x41\x64\x64\x41rtifactsResponse\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc3\x01\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB\x0e\n\x0c_client_type"\x8c\x02\n\x18\x41rtifactStatusesResponse\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01"\xd8\x03\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB\x0e\n\x0c_client_type"[\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 
\x01(\tR\tsessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\x93\x02\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x01R\x0elastResponseId\x88\x01\x01\x42\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc6\x03\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB\x0e\n\x0c_client_type"p\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xab\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"7\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId"\xc9\x01\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x00R\nclientType\x88\x01\x01\x42\x0e\n\x0c_client_type"\xbe\x0b\n\x19\x46\x65tchErrorDetailsResponse\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xef\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1a\n\x08\x63\x61llSite\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\x99\x03\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 
\x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_state\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx2\xb2\x07\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -199,20 +199,26 @@ _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11234 _RELEASEEXECUTERESPONSE._serialized_start = 11263 _RELEASEEXECUTERESPONSE._serialized_end = 11375 - _FETCHERRORDETAILSREQUEST._serialized_start = 11378 - _FETCHERRORDETAILSREQUEST._serialized_end = 11579 - _FETCHERRORDETAILSRESPONSE._serialized_start = 11582 - _FETCHERRORDETAILSRESPONSE._serialized_end = 12857 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11727 - _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11901 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 11904 - _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12076 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12079 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12488 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12390 - _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12458 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12491 - _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 12838 - _SPARKCONNECTSERVICE._serialized_start = 12860 - 
_SPARKCONNECTSERVICE._serialized_end = 13709 + _RELEASESESSIONREQUEST._serialized_start = 11378 + _RELEASESESSIONREQUEST._serialized_end = 11549 + _RELEASESESSIONRESPONSE._serialized_start = 11551 + _RELEASESESSIONRESPONSE._serialized_end = 11606 + _FETCHERRORDETAILSREQUEST._serialized_start = 11609 + _FETCHERRORDETAILSREQUEST._serialized_end = 11810 + _FETCHERRORDETAILSRESPONSE._serialized_start = 11813 + _FETCHERRORDETAILSRESPONSE._serialized_end = 13283 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11958 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 12132 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_start = 12135 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT._serialized_end = 12502 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_start = 12465 + _FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE._serialized_end = 12502 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_start = 12505 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE._serialized_end = 12914 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_start = 12816 + _FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY._serialized_end = 12884 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 12917 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 13264 + _SPARKCONNECTSERVICE._serialized_start = 13286 + _SPARKCONNECTSERVICE._serialized_end = 14232 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 5d2ebeb573990..20abbcb348bdd 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -2763,6 +2763,84 @@ class ReleaseExecuteResponse(google.protobuf.message.Message): global___ReleaseExecuteResponse = ReleaseExecuteResponse +class ReleaseSessionRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SESSION_ID_FIELD_NUMBER: builtins.int + USER_CONTEXT_FIELD_NUMBER: builtins.int + CLIENT_TYPE_FIELD_NUMBER: builtins.int + session_id: builtins.str + """(Required) + + The session_id of the request to reattach to. + This must be an id of existing session. + """ + @property + def user_context(self) -> global___UserContext: + """(Required) User context + + user_context.user_id and session+id both identify a unique remote spark session on the + server side. + """ + client_type: builtins.str + """Provides optional information about the client sending the request. This field + can be used for language or version specific information and is only intended for + logging purposes and will not be interpreted by the server. + """ + def __init__( + self, + *, + session_id: builtins.str = ..., + user_context: global___UserContext | None = ..., + client_type: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_type", + b"client_type", + "user_context", + b"user_context", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_client_type", + b"_client_type", + "client_type", + b"client_type", + "session_id", + b"session_id", + "user_context", + b"user_context", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_client_type", b"_client_type"] + ) -> typing_extensions.Literal["client_type"] | None: ... 
+ +global___ReleaseSessionRequest = ReleaseSessionRequest + +class ReleaseSessionResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SESSION_ID_FIELD_NUMBER: builtins.int + session_id: builtins.str + """Session id of the session on which the release executed.""" + def __init__( + self, + *, + session_id: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["session_id", b"session_id"] + ) -> None: ... + +global___ReleaseSessionResponse = ReleaseSessionResponse + class FetchErrorDetailsRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -2885,11 +2963,35 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + class _ContextType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _ContextTypeEnumTypeWrapper( + google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[ + FetchErrorDetailsResponse.QueryContext._ContextType.ValueType + ], + builtins.type, + ): # noqa: F821 + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + SQL: FetchErrorDetailsResponse.QueryContext._ContextType.ValueType # 0 + DATAFRAME: FetchErrorDetailsResponse.QueryContext._ContextType.ValueType # 1 + + class ContextType(_ContextType, metaclass=_ContextTypeEnumTypeWrapper): + """The type of this query context.""" + + SQL: FetchErrorDetailsResponse.QueryContext.ContextType.ValueType # 0 + DATAFRAME: FetchErrorDetailsResponse.QueryContext.ContextType.ValueType # 1 + + CONTEXT_TYPE_FIELD_NUMBER: builtins.int OBJECT_TYPE_FIELD_NUMBER: builtins.int OBJECT_NAME_FIELD_NUMBER: builtins.int START_INDEX_FIELD_NUMBER: builtins.int STOP_INDEX_FIELD_NUMBER: builtins.int FRAGMENT_FIELD_NUMBER: builtins.int + CALLSITE_FIELD_NUMBER: builtins.int + SUMMARY_FIELD_NUMBER: builtins.int + context_type: global___FetchErrorDetailsResponse.QueryContext.ContextType.ValueType object_type: builtins.str """The object type of the query which throws the exception. If the exception is directly from the main query, it should be an empty string. @@ -2906,18 +3008,29 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): """The stopping index in the query which throws the exception. The index starts from 0.""" fragment: builtins.str """The corresponding fragment of the query which throws the exception.""" + callSite: builtins.str + """The user code (call site of the API) that caused throwing the exception.""" + summary: builtins.str + """Summary of the exception cause.""" def __init__( self, *, + context_type: global___FetchErrorDetailsResponse.QueryContext.ContextType.ValueType = ..., object_type: builtins.str = ..., object_name: builtins.str = ..., start_index: builtins.int = ..., stop_index: builtins.int = ..., fragment: builtins.str = ..., + callSite: builtins.str = ..., + summary: builtins.str = ..., ) -> None: ... def ClearField( self, field_name: typing_extensions.Literal[ + "callSite", + b"callSite", + "context_type", + b"context_type", "fragment", b"fragment", "object_name", @@ -2928,6 +3041,8 @@ class FetchErrorDetailsResponse(google.protobuf.message.Message): b"start_index", "stop_index", b"stop_index", + "summary", + b"summary", ], ) -> None: ... 
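For illustration only (not part of the generated stubs above): a minimal Python sketch of how a client might construct the new ReleaseSessionRequest and the extended FetchErrorDetailsResponse.QueryContext messages, assuming the module is importable as pyspark.sql.connect.proto.base_pb2 per the file paths in this diff. All ids and strings below are placeholders, not values produced by Spark.

    from pyspark.sql.connect.proto import base_pb2 as pb2

    # Placeholder values for illustration only.
    req = pb2.ReleaseSessionRequest(
        session_id="00112233-4455-6677-8899-aabbccddeeff",
        user_context=pb2.UserContext(user_id="alice"),
        client_type="example-client",
    )

    # QueryContext now carries a context_type plus DataFrame-specific fields.
    ctx = pb2.FetchErrorDetailsResponse.QueryContext(
        context_type=pb2.FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME,
        fragment="divide",
        callSite="MyApp.main(MyApp.py:42)",
        summary='"divide" was called from MyApp.main(MyApp.py:42)',
    )
    assert ctx.context_type == pb2.FetchErrorDetailsResponse.QueryContext.ContextType.DATAFRAME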
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py index f6c5573ded6b5..12675747e0f92 100644 --- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py +++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py @@ -70,6 +70,11 @@ def __init__(self, channel): request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString, ) + self.ReleaseSession = channel.unary_unary( + "/spark.connect.SparkConnectService/ReleaseSession", + request_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString, + ) self.FetchErrorDetails = channel.unary_unary( "/spark.connect.SparkConnectService/FetchErrorDetails", request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString, @@ -141,6 +146,16 @@ def ReleaseExecute(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def ReleaseSession(self, request, context): + """Release a session. + All the executions in the session will be released. Any further requests for the session with + that session_id for the given user_id will fail. If the session didn't exist or was already + released, this is a noop. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def FetchErrorDetails(self, request, context): """FetchErrorDetails retrieves the matched exception with details based on a provided error id.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -190,6 +205,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server): request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.FromString, response_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.SerializeToString, ), + "ReleaseSession": grpc.unary_unary_rpc_method_handler( + servicer.ReleaseSession, + request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.FromString, + response_serializer=spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.SerializeToString, + ), "FetchErrorDetails": grpc.unary_unary_rpc_method_handler( servicer.FetchErrorDetails, request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString, @@ -438,6 +458,35 @@ def ReleaseExecute( metadata, ) + @staticmethod + def ReleaseSession( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/spark.connect.SparkConnectService/ReleaseSession", + spark_dot_connect_dot_base__pb2.ReleaseSessionRequest.SerializeToString, + spark_dot_connect_dot_base__pb2.ReleaseSessionResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + @staticmethod def FetchErrorDetails( request, diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 09bd60606c769..1aa857b4f6175 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -254,6 +254,9 @@ def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] 
self._client = SparkConnectClient(connection=connection, user_id=userId) self._session_id = self._client._session_id + # Set to false to prevent client.release_session on close() (testing only) + self.release_session_on_close = True + @classmethod def _set_default_and_active_session(cls, session: "SparkSession") -> None: """ @@ -645,15 +648,16 @@ def clearTags(self) -> None: clearTags.__doc__ = PySparkSession.clearTags.__doc__ def stop(self) -> None: - # Stopping the session will only close the connection to the current session (and - # the life cycle of the session is maintained by the server), - # whereas the regular PySpark session immediately terminates the Spark Context - # itself, meaning that stopping all Spark sessions. + # Whereas the regular PySpark session immediately terminates the Spark Context + # itself, meaning that stopping all Spark sessions, this will only stop this one session + # on the server. # It is controversial to follow the existing the regular Spark session's behavior # specifically in Spark Connect the Spark Connect server is designed for # multi-tenancy - the remote client side cannot just stop the server and stop # other remote clients being used from other users. with SparkSession._lock: + if not self.is_stopped and self.release_session_on_close: + self.client.release_session() self.client.close() if self is SparkSession._default_session: SparkSession._default_session = None diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 34bd314c76f7c..f024a03c2686c 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3437,6 +3437,7 @@ def test_can_create_multiple_sessions_to_different_remotes(self): # Gets currently active session. same = PySparkSession.builder.remote("sc://other.remote.host:114/").getOrCreate() self.assertEquals(other, same) + same.release_session_on_close = False # avoid sending release to dummy connection same.stop() # Make sure the environment is clean. diff --git a/resource-managers/kubernetes/integration-tests/README.md b/resource-managers/kubernetes/integration-tests/README.md index d39fdfbfd3966..d5ccd3fe756b7 100644 --- a/resource-managers/kubernetes/integration-tests/README.md +++ b/resource-managers/kubernetes/integration-tests/README.md @@ -329,11 +329,11 @@ You can also specify your specific dockerfile to build JVM/Python/R based image ## Requirements - A minimum of 6 CPUs and 9G of memory is required to complete all Volcano test cases. -- Volcano v1.8.0. +- Volcano v1.8.1. 
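Returning to the python/pyspark/sql/connect/session.py change above: a rough usage sketch of the new stop() behavior, assuming a reachable Spark Connect server at the example URL (the URL is illustrative, not taken from the diff).

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
    spark.range(10).count()

    # stop() now asks the server to release this session via the ReleaseSession RPC
    # before closing the client connection. Setting release_session_on_close = False
    # skips that RPC; the diff uses this only in tests against a dummy remote.
    spark.stop()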
## Installation - kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml + kubectl apply -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml ## Run tests @@ -354,5 +354,5 @@ You can also specify `volcano` tag to only run Volcano test: ## Cleanup Volcano - kubectl delete -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.0/installer/volcano-development.yaml + kubectl delete -f https://raw.githubusercontent.com/volcano-sh/volcano/v1.8.1/installer/volcano-development.yaml diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index 51d2b4beab227..22e6c67090b4d 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ -26,7 +26,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl import org.apache.spark.{QueryContext, SparkThrowable, SparkThrowableHelper} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, WithOrigin} +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin, SQLQueryContext, WithOrigin} import org.apache.spark.sql.catalyst.util.SparkParserUtils import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SqlApiConf @@ -229,7 +229,7 @@ class ParseException( val builder = new StringBuilder builder ++= "\n" ++= message start match { - case Origin(Some(l), Some(p), _, _, _, _, _) => + case Origin(Some(l), Some(p), _, _, _, _, _, _) => builder ++= s" (line $l, pos $p)\n" command.foreach { cmd => val (above, below) = cmd.split("\n").splitAt(l) @@ -262,8 +262,7 @@ class ParseException( object ParseException { def getQueryContext(): Array[QueryContext] = { - val context = CurrentOrigin.get.context - if (context.isValid) Array(context) else Array.empty + Some(CurrentOrigin.get.context).collect { case b: SQLQueryContext if b.isValid => b }.toArray } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala similarity index 78% rename from sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala index 5b29cb3dde74f..b8288b24535e8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/SQLQueryContext.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/QueryContexts.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.trees -import org.apache.spark.QueryContext +import org.apache.spark.{QueryContext, QueryContextType} /** The class represents error context of a SQL query. 
*/ case class SQLQueryContext( @@ -28,6 +28,7 @@ case class SQLQueryContext( sqlText: Option[String], originObjectType: Option[String], originObjectName: Option[String]) extends QueryContext { + override val contextType = QueryContextType.SQL override val objectType = originObjectType.getOrElse("") override val objectName = originObjectName.getOrElse("") @@ -40,7 +41,7 @@ case class SQLQueryContext( * SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i * ^^^^^^^^^^^^^^^ */ - lazy val summary: String = { + override lazy val summary: String = { // If the query context is missing or incorrect, simply return an empty string. if (!isValid) { "" @@ -116,7 +117,7 @@ case class SQLQueryContext( } /** Gets the textual fragment of a SQL query. */ - override lazy val fragment: String = { + lazy val fragment: String = { if (!isValid) { "" } else { @@ -128,6 +129,45 @@ case class SQLQueryContext( sqlText.isDefined && originStartIndex.isDefined && originStopIndex.isDefined && originStartIndex.get >= 0 && originStopIndex.get < sqlText.get.length && originStartIndex.get <= originStopIndex.get + } + + override def callSite: String = throw new UnsupportedOperationException +} + +case class DataFrameQueryContext( + override val fragment: String, + override val callSite: String) extends QueryContext { + override val contextType = QueryContextType.DataFrame + + override def objectType: String = throw new UnsupportedOperationException + override def objectName: String = throw new UnsupportedOperationException + override def startIndex: Int = throw new UnsupportedOperationException + override def stopIndex: Int = throw new UnsupportedOperationException + + override lazy val summary: String = { + val builder = new StringBuilder + builder ++= "== DataFrame ==\n" + builder ++= "\"" + + builder ++= fragment + builder ++= "\"" + builder ++= " was called from " + builder ++= callSite + builder += '\n' + builder.result() + } +} + +object DataFrameQueryContext { + def apply(elements: Array[StackTraceElement]): DataFrameQueryContext = { + val methodName = elements(0).getMethodName + val code = if (methodName.length > 1 && methodName(0) == '$') { + methodName.substring(1) + } else { + methodName + } + val callSite = elements(1).toString + DataFrameQueryContext(code, callSite) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala index ec3e627ac9585..dd24dae16ba8c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/trees/origin.scala @@ -30,15 +30,21 @@ case class Origin( stopIndex: Option[Int] = None, sqlText: Option[String] = None, objectType: Option[String] = None, - objectName: Option[String] = None) { + objectName: Option[String] = None, + stackTrace: Option[Array[StackTraceElement]] = None) { - lazy val context: SQLQueryContext = SQLQueryContext( - line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) - - def getQueryContext: Array[QueryContext] = if (context.isValid) { - Array(context) + lazy val context: QueryContext = if (stackTrace.isDefined) { + DataFrameQueryContext(stackTrace.get) } else { - Array.empty + SQLQueryContext( + line, startPosition, startIndex, stopIndex, sqlText, objectType, objectName) + } + + def getQueryContext: Array[QueryContext] = { + Some(context).filter { + case s: SQLQueryContext => s.isValid + case _ => true + }.toArray } } diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 7c1b37e9e5815..99caef978bb4a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.ExecutionErrors /** @@ -27,37 +27,37 @@ object MathUtils { def addExact(a: Int, b: Int): Int = withOverflow(Math.addExact(a, b)) - def addExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def addExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def addExact(a: Long, b: Long): Long = withOverflow(Math.addExact(a, b)) - def addExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def addExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.addExact(a, b), hint = "try_add", context) } def subtractExact(a: Int, b: Int): Int = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def subtractExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def subtractExact(a: Long, b: Long): Long = withOverflow(Math.subtractExact(a, b)) - def subtractExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def subtractExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.subtractExact(a, b), hint = "try_subtract", context) } def multiplyExact(a: Int, b: Int): Int = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Int, b: Int, context: SQLQueryContext): Int = { + def multiplyExact(a: Int, b: Int, context: QueryContext): Int = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } def multiplyExact(a: Long, b: Long): Long = withOverflow(Math.multiplyExact(a, b)) - def multiplyExact(a: Long, b: Long, context: SQLQueryContext): Long = { + def multiplyExact(a: Long, b: Long, context: QueryContext): Long = { withOverflow(Math.multiplyExact(a, b), hint = "try_multiply", context) } @@ -78,7 +78,7 @@ object MathUtils { def withOverflow[A]( f: => A, hint: String = "", - context: SQLQueryContext = null): A = { + context: QueryContext = null): A = { try { f } catch { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala index 698e7b37a9ef0..f8a9274a5646c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala @@ -25,7 +25,7 @@ import scala.util.control.NonFatal import sun.util.calendar.ZoneInfo -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.RebaseDateTime.{rebaseGregorianToJulianDays, rebaseGregorianToJulianMicros, rebaseJulianToGregorianDays, rebaseJulianToGregorianMicros} import org.apache.spark.sql.errors.ExecutionErrors @@ -355,7 +355,7 @@ trait SparkDateTimeUtils { def stringToDateAnsi( s: UTF8String, - context: SQLQueryContext = null): Int = { + context: QueryContext = null): Int = 
{ stringToDate(s).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, DateType, context) } @@ -567,7 +567,7 @@ trait SparkDateTimeUtils { def stringToTimestampAnsi( s: UTF8String, timeZoneId: ZoneId, - context: SQLQueryContext = null): Long = { + context: QueryContext = null): Long = { stringToTimestamp(s, timeZoneId).getOrElse { throw ExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampType, context) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala index 5e52e283338d3..b30f7b7a00e91 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrors.scala @@ -16,9 +16,9 @@ */ package org.apache.spark.sql.errors -import org.apache.spark.{SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.catalyst.util.QuotingUtils import org.apache.spark.sql.catalyst.util.QuotingUtils.toSQLSchema import org.apache.spark.sql.types.{DataType, Decimal, StringType} @@ -191,7 +191,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -199,7 +199,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { numericValueOutOfRange(value, decimalPrecision, decimalScale, context) } @@ -207,7 +207,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -222,7 +222,7 @@ private[sql] object DataTypeErrors extends DataTypeErrorsBase { def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { val convertedValueStr = "'" + s.toString.replace("\\", "\\\\").replace("'", "\\'") + "'" new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala index 911d900053cf9..d1d9dd806b3b8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/DataTypeErrorsBase.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.errors import java.util.Locale import org.apache.spark.QueryContext -import 
org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils} import org.apache.spark.sql.types.{AbstractDataType, DataType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -89,11 +88,11 @@ private[sql] trait DataTypeErrorsBase { "\"" + elem + "\"" } - def getSummary(sqlContext: SQLQueryContext): String = { + def getSummary(sqlContext: QueryContext): String = { if (sqlContext == null) "" else sqlContext.summary } - def getQueryContext(sqlContext: SQLQueryContext): Array[QueryContext] = { - if (sqlContext == null) Array.empty else Array(sqlContext.asInstanceOf[QueryContext]) + def getQueryContext(context: QueryContext): Array[QueryContext] = { + if (context == null) Array.empty else Array(context) } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index c8321e81027ba..394e56062071b 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -21,9 +21,8 @@ import java.time.temporal.ChronoField import org.apache.arrow.vector.types.pojo.ArrowType -import org.apache.spark.{SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{QueryContext, SparkArithmeticException, SparkBuildInfo, SparkDateTimeException, SparkException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.sql.catalyst.WalkedTypePath -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{DataType, DoubleType, StringType, UserDefinedType} import org.apache.spark.unsafe.types.UTF8String @@ -83,14 +82,14 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def invalidInputInCastToDatetimeError( value: UTF8String, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), StringType, to, context) } def invalidInputInCastToDatetimeError( value: Double, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { invalidInputInCastToDatetimeErrorInternal(toSQLValue(value), DoubleType, to, context) } @@ -98,7 +97,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { sqlValue: String, from: DataType, to: DataType, - context: SQLQueryContext): SparkDateTimeException = { + context: QueryContext): SparkDateTimeException = { new SparkDateTimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -113,7 +112,7 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { def arithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." 
} else "" diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 3c386b20a7912..5652e5adda9d4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -21,8 +21,8 @@ import java.math.{BigDecimal => JavaBigDecimal, BigInteger, MathContext, Roundin import scala.util.Try +import org.apache.spark.QueryContext import org.apache.spark.annotation.Unstable -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.unsafe.types.UTF8String @@ -341,7 +341,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { scale: Int, roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP, nullOnOverflow: Boolean = true, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { val copy = clone() if (copy.changePrecision(precision, scale, roundMode)) { copy @@ -617,7 +617,7 @@ object Decimal { def fromStringANSI( str: UTF8String, to: DecimalType = DecimalType.USER_DEFAULT, - context: SQLQueryContext = null): Decimal = { + context: QueryContext = null): Decimal = { try { val bigDecimal = stringToJavaBigDecimal(str) // We fast fail because constructing a very large JavaBigDecimal to Decimal is very slow. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 99117d81b34ad..62295fe260535 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -21,13 +21,13 @@ import java.time.{ZoneId, ZoneOffset} import java.util.Locale import java.util.concurrent.TimeUnit._ -import org.apache.spark.SparkArithmeticException +import org.apache.spark.{QueryContext, SparkArithmeticException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag} +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.{PhysicalFractionalType, PhysicalIntegralType, PhysicalNumericType} import org.apache.spark.sql.catalyst.util._ @@ -527,7 +527,7 @@ case class Cast( } } - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -945,7 +945,7 @@ case class Cast( private[this] def toPrecision( value: Decimal, decimalType: DecimalType, - context: SQLQueryContext): Decimal = + context: QueryContext): Decimal = value.toPrecision( decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, !ansiEnabled, context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 46c36ab8e3c31..0dc70c6c3947c 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.SparkException +import org.apache.spark.{QueryContext, SparkException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.truncatedString @@ -613,11 +613,11 @@ abstract class UnaryExpression extends Expression with UnaryLike[Expression] { * to executors. It will also be kept after rule transforms. */ trait SupportQueryContext extends Expression with Serializable { - protected var queryContext: Option[SQLQueryContext] = initQueryContext() + protected var queryContext: Option[QueryContext] = initQueryContext() - def initQueryContext(): Option[SQLQueryContext] + def initQueryContext(): Option[QueryContext] - def getContextOrNull(): SQLQueryContext = queryContext.orNull + def getContextOrNull(): QueryContext = queryContext.orNull def getContextOrNullCode(ctx: CodegenContext, withErrorContext: Boolean = true): String = { if (withErrorContext && queryContext.isDefined) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index fd6131f185606..fe30e2ea6f3ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -134,7 +135,7 @@ case class Average( override protected def withNewChildInternal(newChild: Expression): Average = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index e3881520e4902..dfd41ad12a280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{EvalMode, _} -import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern} import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors @@ -186,7 +187,7 @@ case class Sum( // The flag `evalMode` won't be shown in the `toString` or `toAggString` methods override def flatArguments: Iterator[Any] = Iterator(child) - override def initQueryContext(): Option[SQLQueryContext] = if (evalMode == EvalMode.ANSI) { + override def initQueryContext(): Option[QueryContext] = if (evalMode == EvalMode.ANSI) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a556ac9f12947..e3c5184c5acc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions import scala.math.{max, min} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern, UNARY_POSITIVE} import org.apache.spark.sql.catalyst.types.{PhysicalDecimalType, PhysicalFractionalType, PhysicalIntegerType, PhysicalIntegralType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.{IntervalMathUtils, IntervalUtils, MathUtils, TypeUtils} @@ -266,7 +266,7 @@ abstract class BinaryArithmetic extends BinaryOperator final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index adaed0ff819be..25da787b8874f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -22,13 +22,14 @@ import java.util.Comparator import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalIntegralType} import org.apache.spark.sql.catalyst.util._ @@ -2526,7 +2527,7 @@ case class ElementAt( override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): ElementAt = copy(left = newLeft, right = newRight) - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (failOnError && left.resolved && left.dataType.isInstanceOf[ArrayType]) { Some(origin.context) } else { @@ -5046,7 +5047,7 @@ case class ArrayInsert( newSrcArrayExpr: Expression, newPosExpr: Expression, newItemExpr: Expression): ArrayInsert = copy(srcArrayExpr = newSrcArrayExpr, posExpr = newPosExpr, itemExpr = newItemExpr) - override def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + override def initQueryContext(): Option[QueryContext] = Some(origin.context) } @ExpressionDescription( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3885a5b9f5b32..a801d0367080d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.trees.TreePattern.{EXTRACT_VALUE, TreePattern} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -316,7 +316,7 @@ case class GetArrayItem( newLeft: Expression, newRight: Expression): GetArrayItem = copy(child = newLeft, ordinal = newRight) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 
378920856eb11..5f13d397d1bf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.types.PhysicalDecimalType import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors @@ -146,7 +146,7 @@ case class CheckOverflow( override protected def withNewChildInternal(newChild: Expression): CheckOverflow = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = if (!nullOnOverflow) { + override def initQueryContext(): Option[QueryContext] = if (!nullOnOverflow) { Some(origin.context) } else { None @@ -158,7 +158,7 @@ case class CheckOverflowInSum( child: Expression, dataType: DecimalType, nullOnOverflow: Boolean, - context: SQLQueryContext) extends UnaryExpression with SupportQueryContext { + context: QueryContext) extends UnaryExpression with SupportQueryContext { override def nullable: Boolean = true @@ -210,7 +210,7 @@ case class CheckOverflowInSum( override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum = copy(child = newChild) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) } /** @@ -256,12 +256,12 @@ case class DecimalDivideWithOverflowCheck( left: Expression, right: Expression, override val dataType: DecimalType, - context: SQLQueryContext, + context: QueryContext, nullOnOverflow: Boolean) extends BinaryExpression with ExpectsInputTypes with SupportQueryContext { override def nullable: Boolean = nullOnOverflow override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType) - override def initQueryContext(): Option[SQLQueryContext] = Option(context) + override def initQueryContext(): Option[QueryContext] = Option(context) def decimalMethod: String = "$div" override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 3750a9271cff0..aa1f6159def8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, Un import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike} +import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ @@ -200,9 +200,11 @@ trait 
HigherOrderFunction extends Expression with ExpectsInputTypes { */ final def bind( f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction = { - val res = bindInternal(f) - res.copyTagsFrom(this) - res + CurrentOrigin.withOrigin(origin) { + val res = bindInternal(f) + res.copyTagsFrom(this) + res + } } protected def bindInternal( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 5378639e6838b..13676733a9bad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -22,8 +22,8 @@ import java.util.Locale import com.google.common.math.{DoubleMath, IntMath, LongMath} +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants.MONTHS_PER_YEAR import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils._ @@ -604,7 +604,7 @@ trait IntervalDivide { minValue: Any, num: Expression, numValue: Any, - context: SQLQueryContext): Unit = { + context: QueryContext): Unit = { if (value == minValue && num.dataType.isInstanceOf[IntegralType]) { if (numValue.asInstanceOf[Number].longValue() == -1) { throw QueryExecutionErrors.intervalArithmeticOverflowError( @@ -616,7 +616,7 @@ trait IntervalDivide { def divideByZeroCheck( dataType: DataType, num: Any, - context: SQLQueryContext): Unit = dataType match { + context: QueryContext): Unit = dataType match { case _: DecimalType => if (num.asInstanceOf[Decimal].isZero) { throw QueryExecutionErrors.intervalDividedByZeroError(context) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c320d98d9fd1c..0c09e9be12e94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} import java.util.Locale +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.trees.SQLQueryContext import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -480,7 +480,7 @@ case class Conv( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird) - override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) { + override def 
initQueryContext(): Option[QueryContext] = if (ansiEnabled) { Some(origin.context) } else { None @@ -1523,7 +1523,7 @@ abstract class RoundBase(child: Expression, scale: Expression, private lazy val scaleV: Any = scale.eval(EmptyRow) protected lazy val _scale: Int = scaleV.asInstanceOf[Int] - override def initQueryContext(): Option[SQLQueryContext] = { + override def initQueryContext(): Option[QueryContext] = { if (ansiEnabled) { Some(origin.context) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 9eefcef8e17d2..761bd3f33586e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -1312,7 +1312,7 @@ object IsUnknown { def apply(child: Expression): Predicate = { new IsNull(child) with ExpectsInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType) - override def sql: String = s"(${child.sql} IS UNKNOWN)" + override def sql: String = s"(${this.child.sql} IS UNKNOWN)" } } } @@ -1321,7 +1321,7 @@ object IsNotUnknown { def apply(child: Expression): Predicate = { new IsNotNull(child) with ExpectsInputTypes { override def inputTypes: Seq[DataType] = Seq(BooleanType) - override def sql: String = s"(${child.sql} IS NOT UNKNOWN)" + override def sql: String = s"(${this.child.sql} IS NOT UNKNOWN)" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index f5c34d67d4db1..cf6c7780cc82f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import scala.collection.mutable.ArrayBuffer +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext} +import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -411,7 +412,7 @@ case class Elt( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Elt = copy(children = newChildren) - override def initQueryContext(): Option[SQLQueryContext] = if (failOnError) { + override def initQueryContext(): Option[QueryContext] = if (failOnError) { Some(origin.context) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 23bbc91c16d54..8fabb44876208 
100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit._ import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{Decimal, DoubleExactNumeric, TimestampNTZType, TimestampType} @@ -70,7 +70,7 @@ object DateTimeUtils extends SparkDateTimeUtils { // the "GMT" string. For example, it returns 2000-01-01T00:00+01:00 for 2000-01-01T00:00GMT+01:00. def cleanLegacyTimestampStr(s: UTF8String): UTF8String = s.replace(gmtUtf8, UTF8String.EMPTY_UTF8) - def doubleToTimestampAnsi(d: Double, context: SQLQueryContext): Long = { + def doubleToTimestampAnsi(d: Double, context: QueryContext): Long = { if (d.isNaN || d.isInfinite) { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(d, TimestampType, context) } else { @@ -91,7 +91,7 @@ object DateTimeUtils extends SparkDateTimeUtils { def stringToTimestampWithoutTimeZoneAnsi( s: UTF8String, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { stringToTimestampWithoutTimeZone(s, true).getOrElse { throw QueryExecutionErrors.invalidInputInCastToDatetimeError(s, TimestampNTZType, context) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala index 59765cde1f926..2730ab8f4b890 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.unsafe.types.UTF8String @@ -54,7 +54,7 @@ object NumberConverter { fromPos: Int, value: Array[Byte], ansiEnabled: Boolean, - context: SQLQueryContext): Long = { + context: QueryContext): Long = { var v: Long = 0L // bound will always be positive since radix >= 2 // Note that: -1 is equivalent to 11111111...1111 which is the largest unsigned long value @@ -134,7 +134,7 @@ object NumberConverter { fromBase: Int, toBase: Int, ansiEnabled: Boolean, - context: SQLQueryContext): UTF8String = { + context: QueryContext): UTF8String = { if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX || Math.abs(toBase) < Character.MIN_RADIX || Math.abs(toBase) > Character.MAX_RADIX) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala index f7800469c3528..1c3a5075dab2c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UTF8StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import org.apache.spark.sql.catalyst.trees.SQLQueryContext +import org.apache.spark.QueryContext import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, LongType, ShortType} import 
org.apache.spark.unsafe.types.UTF8String @@ -27,21 +27,21 @@ import org.apache.spark.unsafe.types.UTF8String */ object UTF8StringUtils { - def toLongExact(s: UTF8String, context: SQLQueryContext): Long = + def toLongExact(s: UTF8String, context: QueryContext): Long = withException(s.toLongExact, context, LongType, s) - def toIntExact(s: UTF8String, context: SQLQueryContext): Int = + def toIntExact(s: UTF8String, context: QueryContext): Int = withException(s.toIntExact, context, IntegerType, s) - def toShortExact(s: UTF8String, context: SQLQueryContext): Short = + def toShortExact(s: UTF8String, context: QueryContext): Short = withException(s.toShortExact, context, ShortType, s) - def toByteExact(s: UTF8String, context: SQLQueryContext): Byte = + def toByteExact(s: UTF8String, context: QueryContext): Byte = withException(s.toByteExact, context, ByteType, s) private def withException[A]( f: => A, - context: SQLQueryContext, + context: QueryContext, to: DataType, s: UTF8String): A = { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index afc244509c41d..30dfe8eebe6cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ValueInterval -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext, TreeNode} +import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} import org.apache.spark.sql.catalyst.util.{sideBySide, BadRecordException, DateTimeUtils, FailFastMode, MapData} import org.apache.spark.sql.connector.catalog.{CatalogNotFoundException, Table, TableProvider} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -104,7 +104,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE value: Decimal, decimalPrecision: Int, decimalScale: Int, - context: SQLQueryContext = null): ArithmeticException = { + context: QueryContext = null): ArithmeticException = { new SparkArithmeticException( errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", messageParameters = Map( @@ -118,7 +118,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputSyntaxForBooleanError( s: UTF8String, - context: SQLQueryContext): SparkRuntimeException = { + context: QueryContext): SparkRuntimeException = { new SparkRuntimeException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -133,7 +133,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidInputInCastToNumberError( to: DataType, s: UTF8String, - context: SQLQueryContext): SparkNumberFormatException = { + context: QueryContext): SparkNumberFormatException = { new SparkNumberFormatException( errorClass = "CAST_INVALID_INPUT", messageParameters = Map( @@ -194,15 +194,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = e) } - def divideByZeroError(context: SQLQueryContext): ArithmeticException = { + def divideByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "DIVIDE_BY_ZERO", messageParameters = Map("config" -> 
toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = getQueryContext(context), + context = Array(context), summary = getSummary(context)) } - def intervalDividedByZeroError(context: SQLQueryContext): ArithmeticException = { + def intervalDividedByZeroError(context: QueryContext): ArithmeticException = { new SparkArithmeticException( errorClass = "INTERVAL_DIVIDED_BY_ZERO", messageParameters = Map.empty, @@ -213,7 +213,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidArrayIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX", messageParameters = Map( @@ -227,7 +227,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def invalidElementAtIndexError( index: Int, numElements: Int, - context: SQLQueryContext): ArrayIndexOutOfBoundsException = { + context: QueryContext): ArrayIndexOutOfBoundsException = { new SparkArrayIndexOutOfBoundsException( errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", messageParameters = Map( @@ -292,15 +292,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE ansiIllegalArgumentError(e.getMessage) } - def overflowInSumOfDecimalError(context: SQLQueryContext): ArithmeticException = { + def overflowInSumOfDecimalError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in sum of decimals", context = context) } - def overflowInIntegralDivideError(context: SQLQueryContext): ArithmeticException = { + def overflowInIntegralDivideError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in integral divide", "try_divide", context) } - def overflowInConvError(context: SQLQueryContext): ArithmeticException = { + def overflowInConvError(context: QueryContext): ArithmeticException = { arithmeticOverflowError("Overflow in function conv()", context = context) } @@ -625,7 +625,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def intervalArithmeticOverflowError( message: String, hint: String = "", - context: SQLQueryContext): ArithmeticException = { + context: QueryContext): ArithmeticException = { val alternative = if (hint.nonEmpty) { s" Use '$hint' to tolerate overflow and return NULL instead." 
} else "" @@ -1391,7 +1391,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "functionName" -> toSQLId(prettyName))) } - def invalidIndexOfZeroError(context: SQLQueryContext): RuntimeException = { + def invalidIndexOfZeroError(context: QueryContext): RuntimeException = { new SparkRuntimeException( errorClass = "INVALID_INDEX_OF_ZERO", cause = null, @@ -2556,7 +2556,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def multipleRowScalarSubqueryError(context: SQLQueryContext): Throwable = { + def multipleRowScalarSubqueryError(context: QueryContext): Throwable = { new SparkException( errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", messageParameters = Map.empty, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 997308c6ef44f..ba4e7b279f512 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -35,8 +35,6 @@ import org.apache.spark.sql.types.StructType trait AnalysisTest extends PlanTest { - import org.apache.spark.QueryContext - protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil protected def createTempView( @@ -177,7 +175,7 @@ trait AnalysisTest extends PlanTest { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { val analyzer = getAnalyzer diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index d91a080d8fe89..3fd0c1ee5de4b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.QueryContext import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, CreateNamedStruct, GetStructField, If, IsNull, LessThanOrEqual, Literal} @@ -159,7 +158,7 @@ abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.ANSI.toString) { super.assertAnalysisErrorClass( @@ -196,7 +195,7 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { inputPlan: LogicalPlan, expectedErrorClass: String, expectedMessageParameters: Map[String, String], - queryContext: Array[QueryContext] = Array.empty, + queryContext: Array[ExpectedContext] = Array.empty, caseSensitive: Boolean = true): Unit = { withSQLConf(SQLConf.STORE_ASSIGNMENT_POLICY.key -> StoreAssignmentPolicy.STRICT.toString) { 
super.assertAnalysisErrorClass( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 7765bc26741be..cd7f7295d5cb9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -452,7 +452,7 @@ abstract class InMemoryBaseTable( val matchingKeys = values.map { value => if (value != null) value.toString else null }.toSet - data = data.filter(partition => { + this.data = this.data.filter(partition => { val rows = partition.asInstanceOf[BufferedRows] rows.key match { // null partitions are represented as Seq(null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 9bb35a8b0b3d1..0ca55ef67fd38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -70,8 +70,10 @@ private[sql] object Column { name: String, isDistinct: Boolean, ignoreNulls: Boolean, - inputs: Column*): Column = Column { - UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + inputs: Column*): Column = withOrigin { + Column { + UnresolvedFunction(Seq(name), inputs.map(_.expr), isDistinct, ignoreNulls = ignoreNulls) + } } } @@ -148,12 +150,14 @@ class TypedColumn[-T, U]( @Stable class Column(val expr: Expression) extends Logging { - def this(name: String) = this(name match { - case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => - val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) - UnresolvedStar(Some(parts)) - case _ => UnresolvedAttribute.quotedString(name) + def this(name: String) = this(withOrigin { + name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => + val parts = UnresolvedAttribute.parseAttributeName(name.substring(0, name.length - 2)) + UnresolvedStar(Some(parts)) + case _ => UnresolvedAttribute.quotedString(name) + } }) private def fn(name: String): Column = { @@ -180,7 +184,9 @@ class Column(val expr: Expression) extends Logging { } /** Creates a column based on the given expression. */ - private def withExpr(newExpr: Expression): Column = new Column(newExpr) + private def withExpr(newExpr: => Expression): Column = withOrigin { + new Column(newExpr) + } /** * Returns the expression for this column either with an existing or auto assigned name. @@ -1370,7 +1376,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def over(window: expressions.WindowSpec): Column = window.withAggregate(this) + def over(window: expressions.WindowSpec): Column = withOrigin { + window.withAggregate(this) + } /** * Defines an empty analytic clause. 
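From the Column constructor and over() above through the Dataset and DataFrameStatFunctions hunks that follow, the patch routes public entry points through withOrigin so that the call site of the DataFrame API, rather than an internal helper, can be attached to the query context. The stand-alone sketch below is not Spark's implementation; it only illustrates the general thread-local call-site-capture pattern such a wrapper can use, and every name in it (CallSiteOrigin, callSite) is invented for illustration.

object CallSiteOrigin {
  // Holds the outermost user-facing call site for the current thread, if any.
  private val current = new scala.util.DynamicVariable[Option[String]](None)

  // Runs `body`; if no origin is pinned yet, records the caller's stack frame so
  // that nested helpers reuse the outermost (user-visible) call site.
  def withOrigin[A](body: => A): A = {
    if (current.value.isDefined) {
      body
    } else {
      val frame = Thread.currentThread().getStackTrace
        .drop(2) // skip getStackTrace and this withOrigin frame
        .headOption
        .map(f => s"${f.getClassName}.${f.getMethodName}(${f.getFileName}:${f.getLineNumber})")
      current.withValue(frame)(body)
    }
  }

  // What an error path would read back when building its context.
  def callSite: String = current.value.getOrElse("<unknown>")
}

Because only the outermost wrapper pins the origin in this sketch, a public method called internally by another wrapped method still reports the user's original call site, which mirrors why only the outer API methods are wrapped in the hunks here.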
In this case the analytic function is applied diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index f3690773f6ddd..a8c4d4f8d2ba7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -72,7 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def approxQuantile( col: String, probabilities: Array[Double], - relativeError: Double): Array[Double] = { + relativeError: Double): Array[Double] = withOrigin { approxQuantile(Array(col), probabilities, relativeError).head } @@ -97,7 +97,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { def approxQuantile( cols: Array[String], probabilities: Array[Double], - relativeError: Double): Array[Array[Double]] = { + relativeError: Double): Array[Array[Double]] = withOrigin { StatFunctions.multipleApproxQuantiles( df.select(cols.map(col): _*), cols, @@ -132,7 +132,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def cov(col1: String, col2: String): Double = { + def cov(col1: String, col2: String): Double = withOrigin { StatFunctions.calculateCov(df, Seq(col1, col2)) } @@ -154,7 +154,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def corr(col1: String, col2: String, method: String): Double = { + def corr(col1: String, col2: String, method: String): Double = withOrigin { require(method == "pearson", "Currently only the calculation of the Pearson Correlation " + "coefficient is supported.") StatFunctions.pearsonCorrelation(df, Seq(col1, col2)) @@ -210,7 +210,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def crosstab(col1: String, col2: String): DataFrame = { + def crosstab(col1: String, col2: String): DataFrame = withOrigin { StatFunctions.crossTabulate(df, col1, col2) } @@ -257,7 +257,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Array[String], support: Double): DataFrame = { + def freqItems(cols: Array[String], support: Double): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, support) } @@ -276,7 +276,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Array[String]): DataFrame = { + def freqItems(cols: Array[String]): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, 0.01) } @@ -320,7 +320,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Seq[String], support: Double): DataFrame = { + def freqItems(cols: Seq[String], support: Double): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, support) } @@ -339,7 +339,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 1.4.0 */ - def freqItems(cols: Seq[String]): DataFrame = { + def freqItems(cols: Seq[String]): DataFrame = withOrigin { FrequentItems.singlePassFreqItems(df, cols, 0.01) } @@ -415,7 +415,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = { + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DataFrame = withOrigin { require(fractions.values.forall(p => p >= 0.0 && p <= 1.0), s"Fractions must be in 
[0, 1], but got $fractions.") import org.apache.spark.sql.functions.{rand, udf} @@ -497,7 +497,11 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @return a `CountMinSketch` over column `colName` * @since 2.0.0 */ - def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = { + def countMinSketch( + col: Column, + eps: Double, + confidence: Double, + seed: Int): CountMinSketch = withOrigin { val countMinSketchAgg = new CountMinSketchAgg( col.expr, Literal(eps, DoubleType), @@ -555,7 +559,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * @param numBits expected number of bits of the filter. * @since 2.0.0 */ - def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = { + def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = withOrigin { val bloomFilterAgg = new BloomFilterAggregate( col.expr, Literal(expectedNumItems, LongType), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4f07133bb7617..a567a915daf66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -508,9 +508,11 @@ class Dataset[T] private[sql]( * @group basic * @since 3.4.0 */ - def to(schema: StructType): DataFrame = withPlan { - val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf) + def to(schema: StructType): DataFrame = withOrigin { + withPlan { + val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] + Project.matchSchema(logicalPlan, replaced, sparkSession.sessionState.conf) + } } /** @@ -650,7 +652,7 @@ class Dataset[T] private[sql]( * @group basic * @since 2.4.0 */ - def isEmpty: Boolean = withAction("isEmpty", select().queryExecution) { plan => + def isEmpty: Boolean = withAction("isEmpty", select().limit(1).queryExecution) { plan => plan.executeTake(1).isEmpty } @@ -770,12 +772,14 @@ class Dataset[T] private[sql]( */ // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. 
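One locally relevant detail above: isEmpty now builds its probe as select().limit(1) before taking a single row, so at most one row ever needs to be produced. A hedged, observable-behaviour sketch of that shape using only the public API (the object and helper names, and the local SparkSession, are mine, not part of the patch):

import org.apache.spark.sql.{DataFrame, SparkSession}

object IsEmptyProbe {
  // Mirrors the limit(1)-then-take(1) shape of the patched Dataset.isEmpty,
  // but goes through the public API only.
  def probeIsEmpty(df: DataFrame): Boolean = df.select().limit(1).take(1).isEmpty

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("isEmpty-probe").getOrCreate()
    try {
      val df = spark.range(1000).toDF("id")
      assert(!probeIsEmpty(df))                     // non-empty input
      assert(probeIsEmpty(df.filter(df("id") < 0))) // filter removes every row
    } finally {
      spark.stop()
    }
  }
}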
- def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withTypedPlan { - val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) - require(!IntervalUtils.isNegative(parsedDelay), - s"delay threshold ($delayThreshold) should not be negative.") - EliminateEventTimeWatermark( - EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] = withOrigin { + withTypedPlan { + val parsedDelay = IntervalUtils.fromIntervalString(delayThreshold) + require(!IntervalUtils.isNegative(parsedDelay), + s"delay threshold ($delayThreshold) should not be negative.") + EliminateEventTimeWatermark( + EventTimeWatermark(UnresolvedAttribute(eventTime), parsedDelay, logicalPlan)) + } } /** @@ -947,8 +951,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + def join(right: Dataset[_]): DataFrame = withOrigin { + withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None, JoinHint.NONE) + } } /** @@ -1081,22 +1087,23 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = { - // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right - // by creating a new instance for one of the branch. - val joined = sparkSession.sessionState.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) - .analyzed.asInstanceOf[Join] - - withPlan { - Join( - joined.left, - joined.right, - UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), - None, - JoinHint.NONE) + def join(right: Dataset[_], usingColumns: Seq[String], joinType: String): DataFrame = + withOrigin { + // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right + // by creating a new instance for one of the branch. + val joined = sparkSession.sessionState.executePlan( + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None, JoinHint.NONE)) + .analyzed.asInstanceOf[Join] + + withPlan { + Join( + joined.left, + joined.right, + UsingJoin(JoinType(joinType), usingColumns.toIndexedSeq), + None, + JoinHint.NONE) + } } - } /** * Inner join with another `DataFrame`, using the given join expression. @@ -1177,7 +1184,7 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { + def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = withOrigin { withPlan { resolveSelfJoinCondition(right, Some(joinExprs), joinType) } @@ -1193,8 +1200,10 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.1.0 */ - def crossJoin(right: Dataset[_]): DataFrame = withPlan { - Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) + def crossJoin(right: Dataset[_]): DataFrame = withOrigin { + withPlan { + Join(logicalPlan, right.logicalPlan, joinType = Cross, None, JoinHint.NONE) + } } /** @@ -1218,27 +1227,28 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = { - // Creates a Join node and resolve it first, to get join condition resolved, self-join resolved, - // etc. 
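The using-columns join overload rewritten above (now evaluated under withOrigin) is easiest to read through a small usage sketch; the session, data and names below are assumptions for illustration, not part of the patch.

import org.apache.spark.sql.SparkSession

object UsingJoinExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("using-join").getOrCreate()
    import spark.implicits._

    val left  = Seq((1, "a"), (2, "b")).toDF("id", "tag")
    val right = Seq((1, 10.0)).toDF("id", "amount")

    // The using-column form keeps a single `id` column in the result, and the
    // joinType string picks the join flavour; here id = 2 gets a null amount.
    val joined = left.join(right, Seq("id"), "left_outer")
    joined.show()

    spark.stop()
  }
}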
- val joined = sparkSession.sessionState.executePlan( - Join( - this.logicalPlan, - other.logicalPlan, - JoinType(joinType), - Some(condition.expr), - JoinHint.NONE)).analyzed.asInstanceOf[Join] - - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) - - withTypedPlan(JoinWith.typedJoinWith( - joined, - sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity, - sparkSession.sessionState.analyzer.resolver, - this.exprEnc.isSerializedAsStructForTopLevel, - other.exprEnc.isSerializedAsStructForTopLevel)) - } + def joinWith[U](other: Dataset[U], condition: Column, joinType: String): Dataset[(T, U)] = + withOrigin { + // Creates a Join node and resolve it first, to get join condition resolved, self-join + // resolved, etc. + val joined = sparkSession.sessionState.executePlan( + Join( + this.logicalPlan, + other.logicalPlan, + JoinType(joinType), + Some(condition.expr), + JoinHint.NONE)).analyzed.asInstanceOf[Join] + + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) + + withTypedPlan(JoinWith.typedJoinWith( + joined, + sqlContext.conf.dataFrameSelfJoinAutoResolveAmbiguity, + sparkSession.sessionState.analyzer.resolver, + this.exprEnc.isSerializedAsStructForTopLevel, + other.exprEnc.isSerializedAsStructForTopLevel)) + } /** * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair @@ -1416,14 +1426,16 @@ class Dataset[T] private[sql]( * @since 2.2.0 */ @scala.annotation.varargs - def hint(name: String, parameters: Any*): Dataset[T] = withTypedPlan { - val exprs = parameters.map { - case c: Column => c.expr - case s: Symbol => Column(s.name).expr - case e: Expression => e - case literal => Literal(literal) - }.toSeq - UnresolvedHint(name, exprs, logicalPlan) + def hint(name: String, parameters: Any*): Dataset[T] = withOrigin { + withTypedPlan { + val exprs = parameters.map { + case c: Column => c.expr + case s: Symbol => Column(s.name).expr + case e: Expression => e + case literal => Literal(literal) + }.toSeq + UnresolvedHint(name, exprs, logicalPlan) + } } /** @@ -1499,8 +1511,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan) + def as(alias: String): Dataset[T] = withOrigin { + withTypedPlan { + SubqueryAlias(alias, logicalPlan) + } } /** @@ -1537,25 +1551,28 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def select(cols: Column*): DataFrame = withPlan { - val untypedCols = cols.map { - case typedCol: TypedColumn[_, _] => - // Checks if a `TypedColumn` has been inserted with - // specific input type and schema by `withInputType`. - val needInputType = typedCol.expr.exists { - case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true - case _ => false - } + def select(cols: Column*): DataFrame = withOrigin { + withPlan { + val untypedCols = cols.map { + case typedCol: TypedColumn[_, _] => + // Checks if a `TypedColumn` has been inserted with + // specific input type and schema by `withInputType`. 
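joinWith, reindented above to run under withOrigin, differs from the untyped joins in that it returns a Dataset of pairs instead of a flattened DataFrame. A short usage sketch (the case classes, data and local session are illustrative assumptions):

import org.apache.spark.sql.SparkSession

object JoinWithExample {
  case class User(id: Long, name: String)
  case class Order(userId: Long, amount: Double)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("joinWith-demo").getOrCreate()
    import spark.implicits._

    val users  = Seq(User(1, "a"), User(2, "b")).toDS()
    val orders = Seq(Order(1, 10.0), Order(1, 5.0)).toDS()

    // Inner join returning Dataset[(User, Order)], keeping both sides typed.
    val joined = users.joinWith(orders, users("id") === orders("userId"), "inner")
    joined.collect().foreach { case (u, o) => println(s"${u.name} -> ${o.amount}") }

    spark.stop()
  }
}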
+ val needInputType = typedCol.expr.exists { + case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => true + case _ => false + } - if (!needInputType) { - typedCol - } else { - throw QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString) - } + if (!needInputType) { + typedCol + } else { + throw + QueryCompilationErrors.cannotPassTypedColumnInUntypedSelectError(typedCol.toString) + } - case other => other + case other => other + } + Project(untypedCols.map(_.named), logicalPlan) } - Project(untypedCols.map(_.named), logicalPlan) } /** @@ -1572,7 +1589,9 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) : _*) + def select(col: String, cols: String*): DataFrame = withOrigin { + select((col +: cols).map(Column(_)) : _*) + } /** * Selects a set of SQL expressions. This is a variant of `select` that accepts @@ -1588,10 +1607,12 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @scala.annotation.varargs - def selectExpr(exprs: String*): DataFrame = sparkSession.withActive { - select(exprs.map { expr => - Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) - }: _*) + def selectExpr(exprs: String*): DataFrame = withOrigin { + sparkSession.withActive { + select(exprs.map { expr => + Column(sparkSession.sessionState.sqlParser.parseExpression(expr)) + }: _*) + } } /** @@ -1605,7 +1626,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { + def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = withOrigin { implicit val encoder = c1.encoder val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) @@ -1689,8 +1710,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(condition: Column): Dataset[T] = withTypedPlan { - Filter(condition.expr, logicalPlan) + def filter(condition: Column): Dataset[T] = withOrigin { + withTypedPlan { + Filter(condition.expr, logicalPlan) + } } /** @@ -2049,15 +2072,17 @@ class Dataset[T] private[sql]( ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): DataFrame = withPlan { - Unpivot( - Some(ids.map(_.named)), - Some(values.map(v => Seq(v.named))), - None, - variableColumnName, - Seq(valueColumnName), - logicalPlan - ) + valueColumnName: String): DataFrame = withOrigin { + withPlan { + Unpivot( + Some(ids.map(_.named)), + Some(values.map(v => Seq(v.named))), + None, + variableColumnName, + Seq(valueColumnName), + logicalPlan + ) + } } /** @@ -2080,15 +2105,17 @@ class Dataset[T] private[sql]( def unpivot( ids: Array[Column], variableColumnName: String, - valueColumnName: String): DataFrame = withPlan { - Unpivot( - Some(ids.map(_.named)), - None, - None, - variableColumnName, - Seq(valueColumnName), - logicalPlan - ) + valueColumnName: String): DataFrame = withOrigin { + withPlan { + Unpivot( + Some(ids.map(_.named)), + None, + None, + variableColumnName, + Seq(valueColumnName), + logicalPlan + ) + } } /** @@ -2205,8 +2232,10 @@ class Dataset[T] private[sql]( * @since 3.0.0 */ @varargs - def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan { - CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) + def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withOrigin { + withTypedPlan { + CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id) + 
} } /** @@ -2243,8 +2272,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def limit(n: Int): Dataset[T] = withTypedPlan { - Limit(Literal(n), logicalPlan) + def limit(n: Int): Dataset[T] = withOrigin { + withTypedPlan { + Limit(Literal(n), logicalPlan) + } } /** @@ -2253,8 +2284,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 3.4.0 */ - def offset(n: Int): Dataset[T] = withTypedPlan { - Offset(Literal(n), logicalPlan) + def offset(n: Int): Dataset[T] = withOrigin { + withTypedPlan { + Offset(Literal(n), logicalPlan) + } } // This breaks caching, but it's usually ok because it addresses a very specific use case: @@ -2664,20 +2697,20 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = { - val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - - val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) - - val rowFunction = - f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) - val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) - - withPlan { - Generate(generator, unrequiredChildIndex = Nil, outer = false, - qualifier = None, generatorOutput = Nil, logicalPlan) + def explode[A <: Product : TypeTag](input: Column*)(f: Row => IterableOnce[A]): DataFrame = + withOrigin { + val elementSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val convert = CatalystTypeConverters.createToCatalystConverter(elementSchema) + + val rowFunction = + f.andThen(_.map(convert(_).asInstanceOf[InternalRow])) + val generator = UserDefinedGenerator(elementSchema, rowFunction, input.map(_.expr)) + + withPlan { + Generate(generator, unrequiredChildIndex = Nil, outer = false, + qualifier = None, generatorOutput = Nil, logicalPlan) + } } - } /** * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero @@ -2702,7 +2735,7 @@ class Dataset[T] private[sql]( */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => IterableOnce[B]) - : DataFrame = { + : DataFrame = withOrigin { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil // TODO handle the metadata? 
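// The Dataset methods above are only re-wrapped in `withOrigin`; their bodies are unchanged.
// Below is a minimal, self-contained sketch (not the Spark implementation) of that wrapping
// pattern: the outermost API entry point records the user call site once, and nested API calls
// reuse it. Every name here (MiniOrigin, MiniCurrentOrigin, MiniApi, MiniDataset) is hypothetical
// and exists only to illustrate the idea.
import scala.util.DynamicVariable

case class MiniOrigin(callSite: Option[String] = None)

object MiniCurrentOrigin {
  private val current = new DynamicVariable[MiniOrigin](MiniOrigin())
  def get: MiniOrigin = current.value
  def withOrigin[T](o: MiniOrigin)(body: => T): T = current.withValue(o)(body)
}

object MiniApi {
  // Capture the call site only if no enclosing API call has captured one already.
  def withOrigin[T](f: => T): T = {
    if (MiniCurrentOrigin.get.callSite.isDefined) {
      f
    } else {
      // Heuristically skip the getStackTrace/withOrigin frames; the exact offset is illustrative.
      val callSite = Thread.currentThread().getStackTrace.drop(3).headOption.map(_.toString)
      MiniCurrentOrigin.withOrigin(MiniOrigin(callSite))(f)
    }
  }
}

class MiniDataset(rows: Seq[Int]) {
  // An API method wraps its existing body, mirroring the Dataset.scala changes in this patch.
  def limit(n: Int): MiniDataset = MiniApi.withOrigin {
    val site = MiniCurrentOrigin.get.callSite.getOrElse("<unknown>")
    println(s"limit() called from: $site")
    new MiniDataset(rows.take(n))
  }
}
// Example: new MiniDataset(1 to 10).limit(3) prints the stack frame of the calling user code.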
@@ -2859,7 +2892,7 @@ class Dataset[T] private[sql]( * @since 3.4.0 */ @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = { + def withColumnsRenamed(colsMap: Map[String, String]): DataFrame = withOrigin { val resolver = sparkSession.sessionState.analyzer.resolver val output: Seq[NamedExpression] = queryExecution.analyzed.output @@ -3073,9 +3106,11 @@ class Dataset[T] private[sql]( * @group typedrel * @since 2.0.0 */ - def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = groupColsFromDropDuplicates(colNames) - Deduplicate(groupCols, logicalPlan) + def dropDuplicates(colNames: Seq[String]): Dataset[T] = withOrigin { + withTypedPlan { + val groupCols = groupColsFromDropDuplicates(colNames) + Deduplicate(groupCols, logicalPlan) + } } /** @@ -3151,10 +3186,12 @@ class Dataset[T] private[sql]( * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withTypedPlan { - val groupCols = groupColsFromDropDuplicates(colNames) - // UnsupportedOperationChecker will fail the query if this is called with batch Dataset. - DeduplicateWithinWatermark(groupCols, logicalPlan) + def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withOrigin { + withTypedPlan { + val groupCols = groupColsFromDropDuplicates(colNames) + // UnsupportedOperationChecker will fail the query if this is called with batch Dataset. + DeduplicateWithinWatermark(groupCols, logicalPlan) + } } /** @@ -3378,7 +3415,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(func: T => Boolean): Dataset[T] = { + def filter(func: T => Boolean): Dataset[T] = withOrigin { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -3389,7 +3426,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def filter(func: FilterFunction[T]): Dataset[T] = { + def filter(func: FilterFunction[T]): Dataset[T] = withOrigin { withTypedPlan(TypedFilter(func, logicalPlan)) } @@ -3400,8 +3437,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { - MapElements[T, U](func, logicalPlan) + def map[U : Encoder](func: T => U): Dataset[U] = withOrigin { + withTypedPlan { + MapElements[T, U](func, logicalPlan) + } } /** @@ -3411,7 +3450,7 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = withOrigin { implicit val uEnc = encoder withTypedPlan(MapElements[T, U](func, logicalPlan)) } @@ -3574,8 +3613,9 @@ class Dataset[T] private[sql]( * @group action * @since 3.0.0 */ - def tail(n: Int): Array[T] = withAction( - "tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan) + def tail(n: Int): Array[T] = withOrigin { + withAction("tail", withTypedPlan(Tail(Literal(n), logicalPlan)).queryExecution)(collectFromPlan) + } /** * Returns the first `n` rows in the Dataset as a list. 
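// The helper these wrapped methods rely on is added later in this patch, in the package object of
// org.apache.spark.sql (see the sql/package.scala hunk below): it walks the current stack trace,
// skips every frame that belongs to the DataFrame API layer, and keeps the first frame that
// belongs to user code. The standalone sketch below shows only that filtering idea; the regex,
// the class list, and the frame offset are illustrative assumptions, not the exact values used
// by Spark.
import java.util.regex.Pattern

object CallSiteSketch {
  // Hypothetical pattern for "stack frames that belong to the SQL API layer".
  private val apiCodePattern =
    Pattern.compile("""org\.apache\.spark\.sql\.(?:functions|Column|Dataset)(?:$|\..*|\$.*)""")

  private def isApiFrame(ste: StackTraceElement): Boolean =
    apiCodePattern.matcher(ste.getClassName).matches()

  /** Returns the first stack frame below the API layer, i.e. the user call site, if any. */
  def userCallSite(): Option[StackTraceElement] = {
    val stack = Thread.currentThread().getStackTrace
    // Frame 0 is Thread.getStackTrace and frame 1 is this method; scan the frames below them.
    stack.drop(2).dropWhile(isApiFrame).headOption
  }
}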
@@ -3639,8 +3679,10 @@ class Dataset[T] private[sql]( * @group action * @since 1.6.0 */ - def count(): Long = withAction("count", groupBy().count().queryExecution) { plan => - plan.executeCollect().head.getLong(0) + def count(): Long = withOrigin { + withAction("count", groupBy().count().queryExecution) { plan => + plan.executeCollect().head.getLong(0) + } } /** @@ -3649,13 +3691,15 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def repartition(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = true, logicalPlan) + def repartition(numPartitions: Int): Dataset[T] = withOrigin { + withTypedPlan { + Repartition(numPartitions, shuffle = true, logicalPlan) + } } private def repartitionByExpression( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = { + partitionExprs: Seq[Column]): Dataset[T] = withOrigin { // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments. // However, we don't want to complicate the semantics of this API method. // Instead, let's give users a friendly error message, pointing them to the new method. @@ -3700,7 +3744,7 @@ class Dataset[T] private[sql]( private def repartitionByRange( numPartitions: Option[Int], - partitionExprs: Seq[Column]): Dataset[T] = { + partitionExprs: Seq[Column]): Dataset[T] = withOrigin { require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.") val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match { case expr: SortOrder => expr @@ -3772,8 +3816,10 @@ class Dataset[T] private[sql]( * @group typedrel * @since 1.6.0 */ - def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan { - Repartition(numPartitions, shuffle = false, logicalPlan) + def coalesce(numPartitions: Int): Dataset[T] = withOrigin { + withTypedPlan { + Repartition(numPartitions, shuffle = false, logicalPlan) + } } /** @@ -3917,8 +3963,10 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @throws[AnalysisException] - def createTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = false, global = false) + def createTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = false, global = false) + } } @@ -3930,8 +3978,10 @@ class Dataset[T] private[sql]( * @group basic * @since 2.0.0 */ - def createOrReplaceTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = true, global = false) + def createOrReplaceTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = true, global = false) + } } /** @@ -3949,8 +3999,10 @@ class Dataset[T] private[sql]( * @since 2.1.0 */ @throws[AnalysisException] - def createGlobalTempView(viewName: String): Unit = withPlan { - createTempViewCommand(viewName, replace = false, global = true) + def createGlobalTempView(viewName: String): Unit = withOrigin { + withPlan { + createTempViewCommand(viewName, replace = false, global = true) + } } /** @@ -4358,7 +4410,7 @@ class Dataset[T] private[sql]( plan.executeCollect().map(fromRow) } - private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = { + private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = withOrigin { val sortOrder: Seq[SortOrder] = sortExprs.map { col => col.expr match { case expr: SortOrder => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 936339e091d8f..89c7cae175aed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -311,7 +311,7 @@ private[parquet] class ParquetRowConverter( case LongType if isUnsignedIntTypeMatched(32) => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = - updater.setLong(Integer.toUnsignedLong(value)) + this.updater.setLong(Integer.toUnsignedLong(value)) } case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType | _: AnsiIntervalType => @@ -320,13 +320,13 @@ private[parquet] class ParquetRowConverter( case ByteType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = - updater.setByte(value.asInstanceOf[PhysicalByteType#InternalType]) + this.updater.setByte(value.asInstanceOf[PhysicalByteType#InternalType]) } case ShortType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = - updater.setShort(value.asInstanceOf[PhysicalShortType#InternalType]) + this.updater.setShort(value.asInstanceOf[PhysicalShortType#InternalType]) } // For INT32 backed decimals @@ -346,7 +346,7 @@ private[parquet] class ParquetRowConverter( case _: DecimalType if isUnsignedIntTypeMatched(64) => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { - updater.set(Decimal(java.lang.Long.toUnsignedString(value))) + this.updater.set(Decimal(java.lang.Long.toUnsignedString(value))) } } @@ -391,7 +391,7 @@ private[parquet] class ParquetRowConverter( .asInstanceOf[TimestampLogicalTypeAnnotation].getUnit == TimeUnit.MICROS => new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { - updater.setLong(timestampRebaseFunc(value)) + this.updater.setLong(timestampRebaseFunc(value)) } } @@ -404,7 +404,7 @@ private[parquet] class ParquetRowConverter( new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { val micros = DateTimeUtils.millisToMicros(value) - updater.setLong(timestampRebaseFunc(micros)) + this.updater.setLong(timestampRebaseFunc(micros)) } } @@ -417,7 +417,7 @@ private[parquet] class ParquetRowConverter( val gregorianMicros = int96RebaseFunc(julianMicros) val adjTime = convertTz.map(DateTimeUtils.convertTz(gregorianMicros, _, ZoneOffset.UTC)) .getOrElse(gregorianMicros) - updater.setLong(adjTime) + this.updater.setLong(adjTime) } } @@ -434,14 +434,14 @@ private[parquet] class ParquetRowConverter( new ParquetPrimitiveConverter(updater) { override def addLong(value: Long): Unit = { val micros = DateTimeUtils.millisToMicros(value) - updater.setLong(micros) + this.updater.setLong(micros) } } case DateType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { - updater.set(dateRebaseFunc(value)) + this.updater.set(dateRebaseFunc(value)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index 0e3243eac6230..5e0c5ff92fdab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -233,7 +233,7 @@ private[python] case class HybridRowQueue( val buffer = if (page != null) { new InMemoryRowQueue(page, 
numFields) { override def close(): Unit = { - freePage(page) + freePage(this.page) } } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 58f720154df52..771c743f70629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.execution +import org.apache.spark.QueryContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, ExprId, InSet, ListQuery, Literal, PlanExpression, Predicate, SupportQueryContext} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.{LeafLike, SQLQueryContext, UnaryLike} +import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -68,7 +69,7 @@ case class ScalarSubquery( override def nullable: Boolean = true override def toString: String = plan.simpleString(SQLConf.get.maxToStringFields) override def withNewPlan(query: BaseSubqueryExec): ScalarSubquery = copy(plan = query) - def initQueryContext(): Option[SQLQueryContext] = Some(origin.context) + def initQueryContext(): Option[QueryContext] = Some(origin.context) override lazy val canonicalized: Expression = { ScalarSubquery(plan.canonicalized.asInstanceOf[BaseSubqueryExec], ExprId(0)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b5e40fe35cfe1..a42df5bbcc292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -93,11 +93,13 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private def withExpr(expr: Expression): Column = Column(expr) + private def withExpr(expr: => Expression): Column = withOrigin { + Column(expr) + } private def withAggregateFunction( - func: AggregateFunction, - isDistinct: Boolean = false): Column = { + func: => AggregateFunction, + isDistinct: Boolean = false): Column = withOrigin { Column(func.toAggregateExpression(isDistinct)) } @@ -127,16 +129,18 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => - // This is different from `typedlit`. `typedlit` calls `Literal.create` to use - // `ScalaReflection` to get the type of `literal`. However, since we use `Any` in this method, - // `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, we can - // just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. This is - // significantly better when there are many threads calling `lit` concurrently. + def lit(literal: Any): Column = withOrigin { + literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => + // This is different from `typedlit`. `typedlit` calls `Literal.create` to use + // `ScalaReflection` to get the type of `literal`. 
However, since we use `Any` in this + // method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, + // we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. + // This is significantly better when there are many threads calling `lit` concurrently. Column(Literal(literal)) + } } /** @@ -147,7 +151,9 @@ object functions { * @group normal_funcs * @since 2.2.0 */ - def typedLit[T : TypeTag](literal: T): Column = typedlit(literal) + def typedLit[T : TypeTag](literal: T): Column = withOrigin { + typedlit(literal) + } /** * Creates a [[Column]] of literal value. @@ -164,10 +170,12 @@ object functions { * @group normal_funcs * @since 3.2.0 */ - def typedlit[T : TypeTag](literal: T): Column = literal match { - case c: Column => c - case s: Symbol => new ColumnName(s.name) - case _ => Column(Literal.create(literal)) + def typedlit[T : TypeTag](literal: T): Column = withOrigin { + literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => Column(Literal.create(literal)) + } } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -5965,25 +5973,31 @@ object functions { def array_except(col1: Column, col2: Column): Column = Column.fn("array_except", col1, col2) - private def createLambda(f: Column => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val function = f(Column(x)).expr - LambdaFunction(function, Seq(x)) + private def createLambda(f: Column => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val function = f(Column(x)).expr + LambdaFunction(function, Seq(x)) + } } - private def createLambda(f: (Column, Column) => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val function = f(Column(x), Column(y)).expr - LambdaFunction(function, Seq(x, y)) + private def createLambda(f: (Column, Column) => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val function = f(Column(x), Column(y)).expr + LambdaFunction(function, Seq(x, y)) + } } - private def createLambda(f: (Column, Column, Column) => Column) = Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) - val function = f(Column(x), Column(y), Column(z)).expr - LambdaFunction(function, Seq(x, y, z)) + private def createLambda(f: (Column, Column, Column) => Column) = withOrigin { + Column { + val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) + val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) + val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) + val function = f(Column(x), Column(y), Column(z)).expr + LambdaFunction(function, Seq(x, y, z)) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 5543b409d1702..1d496b027ef5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -188,7 +188,7 @@ abstract class BaseSessionStateBuilder(
       new ResolveSQLOnFile(session) +:
       new FallBackFileSourceV2(session) +:
       ResolveEncodersInScalaAgg +:
-      new ResolveSessionCatalog(catalogManager) +:
+      new ResolveSessionCatalog(this.catalogManager) +:
       ResolveWriteToStream +:
       new EvalSubqueriesForTimeTravel +:
       customResolutionRules
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index 1794ac513749f..7f00f6d6317c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark
 
+import java.util.regex.Pattern
+
 import org.apache.spark.annotation.{DeveloperApi, Unstable}
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.execution.SparkStrategy
 
 /**
@@ -73,4 +76,41 @@ package object sql {
    * with rebasing.
    */
   private[sql] val SPARK_LEGACY_INT96_METADATA_KEY = "org.apache.spark.legacyINT96"
+
+  /**
+   * This helper function captures the Spark API call and its call site in the user code from the
+   * current stacktrace.
+   *
+   * As adding `withOrigin` explicitly to every Spark API definition would be a huge change,
+   * `withOrigin` is used only at certain places that all API implementations are guaranteed to
+   * pass through, and the current stacktrace is filtered to the point where the first Spark API
+   * code is invoked from the user code.
+   *
+   * As there might be multiple nested `withOrigin` calls (e.g. any Spark API implementation can
+   * invoke other APIs), only the outermost `withOrigin` is captured because it is closest to the
+   * user code.
+   *
+   * @param f The function that can use the origin.
+   * @return The result of `f`.
+   */
+  private[sql] def withOrigin[T](f: => T): T = {
+    if (CurrentOrigin.get.stackTrace.isDefined) {
+      f
+    } else {
+      val st = Thread.currentThread().getStackTrace
+      var i = 3
+      while (i < st.length && sparkCode(st(i))) i += 1
+      val origin =
+        Origin(stackTrace = Some(Thread.currentThread().getStackTrace.slice(i - 1, i + 1)))
+      CurrentOrigin.withOrigin(origin)(f)
+    }
+  }
+
+  private val sparkCodePattern = Pattern.compile("org\\.apache\\.spark\\.sql\\."
+ + "(?:functions|Column|ColumnName|SQLImplicits|Dataset|DataFrameStatFunctions)" + + "(?:|\\..*|\\$.*)") + + private def sparkCode(ste: StackTraceElement): Boolean = { + sparkCodePattern.matcher(ste.getClassName).matches() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 8c9ad2180faa3..140daced32234 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -458,7 +458,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", - "sqlExpr" -> "\"(a IN (b))\"") + "sqlExpr" -> "\"(a IN (b))\""), + context = ExpectedContext(fragment = "isin", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -525,7 +526,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { parameters = Map( "functionName" -> "`in`", "dataType" -> "[\"INT\", \"ARRAY\"]", - "sqlExpr" -> "\"(a IN (b))\"") + "sqlExpr" -> "\"(a IN (b))\""), + context = ExpectedContext( + fragment = "isInCollection", + callSitePattern = getCurrentClassCallSitePattern) ) } } @@ -1056,7 +1060,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"key\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "withField", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1101,7 +1108,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"a.b\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "withField", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1849,7 +1859,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"key\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1886,7 +1899,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"a.b\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"STRUCT\"") + "requiredType" -> "\"STRUCT\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1952,7 +1968,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.withColumn("a", $"a".dropFields("a", "b", "c")) }, errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", - parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\"") + parameters = Map("sqlExpr" -> "\"update_fields(a, dropfield(), dropfield(), dropfield())\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -2224,7 +2243,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"struct_col".dropFields("a", "b")) }, errorClass = "DATATYPE_MISMATCH.CANNOT_DROP_ALL_FIELDS", - parameters = Map("sqlExpr" -> "\"update_fields(struct_col, dropfield(), dropfield())\"") + parameters = Map("sqlExpr" -> "\"update_fields(struct_col, 
dropfield(), dropfield())\""), + context = ExpectedContext( + fragment = "dropFields", + callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( @@ -2398,7 +2420,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 1).as("a")) }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`")) + parameters = Map("fieldName" -> "`d`", "fields" -> "`a`, `b`, `c`"), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( structLevel1 @@ -2451,7 +2476,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"a".withField("z", $"a.c")).as("a") }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`")) + parameters = Map("fieldName" -> "`c`", "fields" -> "`a`, `b`"), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) } test("nestedDf should generate nested DataFrames") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index c40ecb88257f2..e7c1f0414b619 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -52,7 +52,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "inputSchema" -> "\"ARRAY\"", "dataType" -> "\"ARRAY\"" - ) + ), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) checkError( @@ -395,7 +396,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { .select(from_csv($"csv", $"schema", options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"schema\"") + parameters = Map("inputSchema" -> "\"schema\""), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) checkError( @@ -403,7 +405,8 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { Seq("1").toDF("csv").select(from_csv($"csv", lit(1), options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"1\"") + parameters = Map("inputSchema" -> "\"1\""), + context = ExpectedContext(fragment = "from_csv", getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f34d7cf368072..c8eea985c1065 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -633,7 +633,10 @@ class DataFrameAggregateSuite extends QueryTest "functionName" -> "`collect_set`", "dataType" -> "\"MAP\"", "sqlExpr" -> "\"collect_set(b)\"" - ) + ), + context = ExpectedContext( + fragment = "collect_set", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -706,7 +709,8 @@ class DataFrameAggregateSuite extends QueryTest testData.groupBy(sum($"key")).count() }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "sum(key)") + parameters = Map("sqlExpr" -> "sum(key)"), + context = ExpectedContext(fragment = "sum", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1302,7 +1306,8 @@ class DataFrameAggregateSuite extends QueryTest 
"paramIndex" -> "2", "inputSql" -> "\"a\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"INTEGRAL\"")) + "requiredType" -> "\"INTEGRAL\""), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-34716: Support ANSI SQL intervals by the aggregate function `sum`") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 789196583c600..135ce834bfe5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -171,7 +171,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"k\"", "inputType" -> "\"INT\"" - ) + ), + queryContext = Array( + ExpectedContext( + fragment = "map_from_arrays", + callSitePattern = getCurrentClassCallSitePattern)) ) val df5 = Seq((Seq("a", null), Seq(1, 2))).toDF("k", "v") @@ -758,7 +762,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { test("The given function only supports array input") { val df = Seq(1, 2, 3).toDF("a") - checkErrorMatchPVals( + checkError( exception = intercept[AnalysisException] { df.select(array_sort(col("a"), (x, y) => x - y)) }, @@ -769,7 +773,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + matchPVals = true, + queryContext = Array( + ExpectedContext( + fragment = "array_sort", + callSitePattern = getCurrentClassCallSitePattern)) + ) } test("sort_array/array_sort functions") { @@ -1305,7 +1315,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, map2)\"", "dataType" -> "(\"MAP, INT>\" or \"MAP\")", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext( + fragment = "map_concat", + callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -1333,7 +1347,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map_concat(map1, 12)\"", "dataType" -> "[\"MAP, INT>\", \"INT\"]", - "functionName" -> "`map_concat`") + "functionName" -> "`map_concat`"), + context = + ExpectedContext( + fragment = "map_concat", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1402,7 +1420,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"a\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of pair \"STRUCT\"" - ) + ), + context = + ExpectedContext( + fragment = "map_from_entries", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -1439,7 +1461,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array_contains(a, NULL)\"", "functionName" -> "`array_contains`" - ) + ), + context = + ExpectedContext( + fragment = "array_contains", + callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2348,7 +2374,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"ARRAY\"")) + "rightType" -> "\"ARRAY\""), + context = + ExpectedContext( + fragment = "array_union", + 
callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { @@ -2379,7 +2409,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", - "rightType" -> "\"VOID\"") + "rightType" -> "\"VOID\""), + context = ExpectedContext( + fragment = "array_union", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2410,7 +2442,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "functionName" -> "`array_union`", "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY>\"", - "rightType" -> "\"ARRAY\"") + "rightType" -> "\"ARRAY\""), + queryContext = Array(ExpectedContext( + fragment = "array_union", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2647,7 +2681,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"arr\"", "inputType" -> "\"ARRAY\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + context = ExpectedContext( + fragment = "flatten", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2660,7 +2696,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2673,7 +2711,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"s\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"ARRAY\" of \"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "flatten", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -2782,7 +2822,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"b\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + context = ExpectedContext( + fragment = "array_repeat", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -2795,7 +2837,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"1\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_repeat", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3123,7 +3167,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3151,7 +3197,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3179,7 +3227,9 @@ class DataFrameFunctionsSuite extends QueryTest with 
SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"VOID\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3207,7 +3257,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_except", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3276,7 +3328,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"VOID\"" - ) + ), + context = ExpectedContext( + fragment = "array_intersect", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -3305,7 +3359,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array(ExpectedContext( + fragment = "array_intersect", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3334,7 +3390,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "arrayType" -> "\"ARRAY\"", "leftType" -> "\"VOID\"", "rightType" -> "\"ARRAY\"" - ) + ), + queryContext = Array( + ExpectedContext( + fragment = "array_intersect", + callSitePattern = getCurrentClassCallSitePattern)) ) checkError( exception = intercept[AnalysisException] { @@ -3750,7 +3810,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -3933,7 +3995,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "filter", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -3945,7 +4009,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + ExpectedContext( + fragment = "filter(s, x -> x)", + start = 0, + stop = 16)) checkError( exception = intercept[AnalysisException] { @@ -3957,7 +4025,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "filter", + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = @@ -4112,7 +4183,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - 
"requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "exists", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4124,7 +4197,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "exists(s, x -> x)", + start = 0, + stop = 16) + ) checkError( exception = intercept[AnalysisException] { @@ -4136,7 +4214,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = + ExpectedContext(fragment = "exists", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException](df.selectExpr("exists(a, x -> x)")), @@ -4304,7 +4384,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "forall", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = intercept[AnalysisException] { @@ -4316,7 +4398,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = ExpectedContext( + fragment = "forall(s, x -> x)", + start = 0, + stop = 16)) checkError( exception = intercept[AnalysisException] { @@ -4328,7 +4414,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "2", "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", - "requiredType" -> "\"BOOLEAN\"")) + "requiredType" -> "\"BOOLEAN\""), + context = + ExpectedContext(fragment = "forall", callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException](df.selectExpr("forall(a, x -> x)")), @@ -4343,7 +4431,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkError( exception = intercept[AnalysisException](df.select(forall(col("a"), x => x))), errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`")) + parameters = Map("objectName" -> "`a`", "proposal" -> "`i`, `s`"), + queryContext = Array( + ExpectedContext(fragment = "col", callSitePattern = getCurrentClassCallSitePattern))) } test("aggregate function - array for primitive type not containing null") { @@ -4581,7 +4671,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "aggregate", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit // scalastyle:off line.size.limit @@ -4597,7 +4689,11 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - )) + ), + context = ExpectedContext( + fragment = s"$agg(s, 0, (acc, x) -> x)", + start = 0, + stop = agg.length + 20)) } // scalastyle:on line.size.limit @@ -4613,7 +4709,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "inputSql" -> "\"lambdafunction(namedlambdavariable(), namedlambdavariable(), namedlambdavariable())\"", "inputType" -> "\"STRING\"", "requiredType" -> "\"INT\"" - )) + ), + context = + ExpectedContext(fragment = "aggregate", callSitePattern = getCurrentClassCallSitePattern)) // scalastyle:on line.size.limit Seq("aggregate", "reduce").foreach { agg => @@ -4719,7 +4817,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, mmi, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "functionName" -> "`map_zip_with`", "leftType" -> "\"INT\"", - "rightType" -> "\"MAP\"")) + "rightType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4749,7 +4849,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(i, mis, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "1", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -4779,7 +4881,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> """"map_zip_with\(mis, i, lambdafunction\(concat\(x_\d+, y_\d+, z_\d+\), x_\d+, y_\d+, z_\d+\)\)"""", "paramIndex" -> "2", "inputSql" -> "\"i\"", - "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\"")) + "inputType" -> "\"INT\"", "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext(fragment = "map_zip_with", callSitePattern = getCurrentClassCallSitePattern))) // scalastyle:on line.size.limit checkError( @@ -5235,7 +5339,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"x\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"MAP\"")) + "requiredType" -> "\"MAP\""), + queryContext = Array( + ExpectedContext( + fragment = "transform_values", + callSitePattern = getCurrentClassCallSitePattern))) } testInvalidLambdaFunctions() @@ -5375,7 +5483,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"i\"", "inputType" -> "\"INT\"", - "requiredType" -> "\"ARRAY\"")) + "requiredType" -> "\"ARRAY\""), + queryContext = Array( + ExpectedContext(fragment = "zip_with", callSitePattern = getCurrentClassCallSitePattern))) checkError( exception = @@ -5631,7 +5741,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"map(m, 1)\"", "keyType" -> "\"MAP\"" - ) + ), + context = + ExpectedContext(fragment = "map", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer( df.select(map(map_entries($"m"), lit(1))), @@ -5753,7 +5865,10 @@ class 
DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "requiredType" -> "\"ARRAY\"", "inputSql" -> "\"a\"", "inputType" -> "\"INT\"" - )) + ), + context = ExpectedContext( + fragment = "array_compact", + callSitePattern = getCurrentClassCallSitePattern)) } test("array_append -> Unit Test cases for the function ") { @@ -5772,7 +5887,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "dataType" -> "\"ARRAY\"", "leftType" -> "\"ARRAY\"", "rightType" -> "\"INT\"", - "sqlExpr" -> "\"array_append(a, b)\"") + "sqlExpr" -> "\"array_append(a, b)\""), + context = + ExpectedContext(fragment = "array_append", callSitePattern = getCurrentClassCallSitePattern) ) checkAnswer(df1.selectExpr("array_append(a, 3)"), Seq(Row(Seq(3, 2, 5, 1, 2, 3)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 237915fb63fa8..b3bf9405a99f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -310,7 +310,8 @@ class DataFramePivotSuite extends QueryTest with SharedSparkSession { .agg(sum($"sales.earnings")) }, errorClass = "GROUP_BY_AGGREGATE", - parameters = Map("sqlExpr" -> "min(training)") + parameters = Map("sqlExpr" -> "min(training)"), + context = ExpectedContext(fragment = "min", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 88ef5936264de..c777d2207584d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -484,7 +484,8 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { checkError(ex, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map("objectName" -> "`df1`.`timeStr`", - "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`")) + "proposal" -> "`df3`.`timeStr`, `df1`.`tsStr`"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index ab8aab0713a44..e7c1d2c772c08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -373,7 +373,10 @@ class DataFrameSetOperationsSuite extends QueryTest errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", parameters = Map( "colName" -> "`m`", - "dataType" -> "\"MAP\"") + "dataType" -> "\"MAP\""), + context = ExpectedContext( + fragment = "distinct", + callSitePattern = getCurrentClassCallSitePattern) ) withTempView("v") { df.createOrReplaceTempView("v") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 430e36221025a..20ac2a9e9461d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -146,7 +146,9 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" 
- ) + ), + context = + ExpectedContext(fragment = "freqItems", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { @@ -156,7 +158,9 @@ class DataFrameStatSuite extends QueryTest with SharedSparkSession { parameters = Map( "name" -> "`num`", "referenceNames" -> "[`table1`.`num`, `table2`.`num`]" - ) + ), + context = ExpectedContext( + fragment = "approxQuantile", callSitePattern = getCurrentClassCallSitePattern) ) checkError( exception = intercept[AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index b7450e5648727..b0a0b189cb7f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -354,7 +354,8 @@ class DataFrameSuite extends QueryTest "paramIndex" -> "1", "inputSql"-> "\"csv\"", "inputType" -> "\"STRING\"", - "requiredType" -> "(\"ARRAY\" or \"MAP\")") + "requiredType" -> "(\"ARRAY\" or \"MAP\")"), + context = ExpectedContext(fragment = "explode", getCurrentClassCallSitePattern) ) val df2 = Seq(Array("1", "2"), Array("4"), Array("7", "8", "9")).toDF("csv") @@ -2947,7 +2948,8 @@ class DataFrameSuite extends QueryTest df.groupBy($"d", $"b").as[GroupByKey, Row] }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`")) + parameters = Map("objectName" -> "`d`", "proposal" -> "`a`, `b`, `c`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-40601: flatMapCoGroupsInPandas should fail with different number of keys") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 2a81f7e7c2f34..bb744cfd8ab4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -184,7 +184,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( @@ -198,7 +200,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN -1 FOLLOWING AND UNBOUNDED FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) checkError( @@ -212,7 +216,9 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "sqlExpr" -> (""""\(ORDER BY key ASC NULLS FIRST, value ASC NULLS FIRST RANGE """ + """BETWEEN -1 FOLLOWING AND 1 FOLLOWING\)"""") ), - matchPVals = true + matchPVals = true, + queryContext = + Array(ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern)) ) } @@ -240,7 +246,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE 
BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -255,7 +262,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN -1 FOLLOWING AND UNBOUNDED FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) checkError( @@ -270,7 +278,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "expectedType" -> ("(\"NUMERIC\" or \"INTERVAL DAY TO SECOND\" or \"INTERVAL YEAR " + "TO MONTH\" or \"INTERVAL\")"), "sqlExpr" -> "\"RANGE BETWEEN -1 FOLLOWING AND 1 FOLLOWING\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -462,7 +471,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { "upper" -> "\"2\"", "lowerType" -> "\"INTERVAL\"", "upperType" -> "\"BIGINT\"" - ) + ), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -481,7 +491,8 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"RANGE BETWEEN nonfoldableliteral() FOLLOWING AND 2 FOLLOWING\"", "location" -> "lower", - "expression" -> "\"nonfoldableliteral()\"") + "expression" -> "\"nonfoldableliteral()\""), + context = ExpectedContext(fragment = "over", callSitePattern = getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 84133eb485f0f..6969c4303e01e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -412,7 +412,10 @@ class DataFrameWindowFunctionsSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`invalid`", - "proposal" -> "`value`, `key`")) + "proposal" -> "`value`, `key`"), + context = ExpectedContext( + fragment = "count", + callSitePattern = getCurrentClassCallSitePattern)) } test("numerical aggregate functions on string column") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index bf78e6e11fe99..66105d2ac429f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -606,7 +606,8 @@ class DatasetSuite extends QueryTest } }, errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", - parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups")) + parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy function, flatMapSorted") { @@ -634,7 +635,8 @@ class DatasetSuite extends QueryTest } }, errorClass = "INVALID_USAGE_OF_STAR_OR_REGEX", - parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups")) + parameters = Map("elem" -> "'*'", "prettyName" -> "MapGroups"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy, flatMapSorted desc") { @@ -2290,7 +2292,8 @@ class DatasetSuite extends 
QueryTest sqlState = None, parameters = Map( "objectName" -> s"`${colName.replace(".", "`.`")}`", - "proposal" -> "`field.1`, `field 2`")) + "proposal" -> "`field.1`, `field 2`"), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } } } @@ -2304,7 +2307,8 @@ class DatasetSuite extends QueryTest sqlState = None, parameters = Map( "objectName" -> "`the`.`id`", - "proposal" -> "`the.id`")) + "proposal" -> "`the.id`"), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } test("SPARK-39783: backticks in error message for map candidate key with dots") { @@ -2318,7 +2322,8 @@ class DatasetSuite extends QueryTest sqlState = None, parameters = Map( "objectName" -> "`nonexisting`", - "proposal" -> "`map`, `other.column`")) + "proposal" -> "`map`, `other.column`"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) } test("groupBy.as") { @@ -2659,6 +2664,22 @@ class DatasetSuite extends QueryTest assert(join.count() == 1000000) } + test("SPARK-45022: exact DatasetQueryContext call site") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + val df = Seq(1).toDS() + var callSitePattern: String = null + checkError( + exception = intercept[AnalysisException] { + callSitePattern = getNextLineCallSitePattern() + val c = col("a") + df.select(c) + }, + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + sqlState = "42703", + parameters = Map("objectName" -> "`a`", "proposal" -> "`value`"), + context = ExpectedContext(fragment = "col", callSitePattern = callSitePattern)) + } + } } class DatasetLargeResultCollectingSuite extends QueryTest diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala index 4117ea63bdd8c..49811d8ac61bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala @@ -373,7 +373,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`1`", - "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`")) + "proposal" -> "`id`, `int1`, `str1`, `long1`, `str2`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // unpivoting where value column does not exist val e2 = intercept[AnalysisException] { @@ -389,7 +390,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`does`", - "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`")) + "proposal" -> "`id`, `int1`, `long1`, `str1`, `str2`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // unpivoting without values where potential value columns are of incompatible types val e3 = intercept[AnalysisException] { @@ -506,7 +508,8 @@ class DatasetUnpivotSuite extends QueryTest errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`an`.`id`", - "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`")) + "proposal" -> "`an.id`, `int1`, `long1`, `str.one`, `str.two`"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } test("unpivot with struct fields") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 
9b4ad76881864..2ab651237206a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -293,7 +293,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { "paramIndex" -> "1", "inputSql" -> "\"array()\"", "inputType" -> "\"ARRAY\"", - "requiredType" -> "\"ARRAY\"") + "requiredType" -> "\"ARRAY\""), + context = ExpectedContext( + fragment = "inline", + callSitePattern = getCurrentClassCallSitePattern) ) } @@ -331,7 +334,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(b))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext( + fragment = "array", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct(Symbol("a")), struct(Symbol("b").alias("a"))))), @@ -346,7 +352,10 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"array(struct(a), struct(2))\"", "functionName" -> "`array`", - "dataType" -> "(\"STRUCT\" or \"STRUCT\")")) + "dataType" -> "(\"STRUCT\" or \"STRUCT\")"), + context = ExpectedContext( + fragment = "array", + callSitePattern = getCurrentClassCallSitePattern)) checkAnswer( df.select(inline(array(struct(Symbol("a")), struct(lit(2).alias("a"))))), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 5effa2edf585c..933f362db663f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -195,7 +195,9 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "sqlExpr" -> "\"json_tuple(a, 1)\"", "funcName" -> "`json_tuple`" - ) + ), + context = + ExpectedContext(fragment = "json_tuple", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -648,7 +650,9 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { errorClass = "DATATYPE_MISMATCH.INVALID_JSON_MAP_KEY_TYPE", parameters = Map( "schema" -> "\"MAP, STRING>\"", - "sqlExpr" -> "\"entries\"")) + "sqlExpr" -> "\"entries\""), + context = + ExpectedContext(fragment = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("SPARK-24709: infers schemas of json strings and pass them to from_json") { @@ -958,7 +962,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { .select(from_json($"json", $"schema", options)).collect() }, errorClass = "INVALID_SCHEMA.NON_STRING_LITERAL", - parameters = Map("inputSchema" -> "\"schema\"") + parameters = Map("inputSchema" -> "\"schema\""), + context = ExpectedContext(fragment = "from_json", getCurrentClassCallSitePattern) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala index 2a24f0cc39965..afbe9cdac6366 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ParametersSuite.scala @@ -600,6 +600,10 @@ class ParametersSuite extends QueryTest with SharedSparkSession { array(str_to_map(Column(Literal("a:1,b:2,c:3"))))))) }, errorClass = "INVALID_SQL_ARG", - parameters = Map("name" -> "m")) + parameters = Map("name" -> 
"m"), + context = ExpectedContext( + fragment = "map_from_arrays", + callSitePattern = getCurrentClassCallSitePattern) + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index b5ae9c7f35200..8668d61317409 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.util.TimeZone +import java.util.regex.Pattern import scala.jdk.CollectionConverters._ @@ -229,6 +230,17 @@ abstract class QueryTest extends PlanTest { assert(query.queryExecution.executedPlan.missingInput.isEmpty, s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}") } + + protected def getCurrentClassCallSitePattern: String = { + val cs = Thread.currentThread().getStackTrace()(2) + s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)" + } + + protected def getNextLineCallSitePattern(lines: Int = 1): String = { + val cs = Thread.currentThread().getStackTrace()(2) + Pattern.quote( + s"${cs.getClassName}.${cs.getMethodName}(${cs.getFileName}:${cs.getLineNumber + lines})") + } } object QueryTest extends Assertions { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3612f4a7eda8d..b7201c2d96d77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1920,7 +1920,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark dfNoCols.select($"b.*") }, errorClass = "CANNOT_RESOLVE_STAR_EXPAND", - parameters = Map("targetString" -> "`b`", "columns" -> "")) + parameters = Map("targetString" -> "`b`", "columns" -> ""), + context = ExpectedContext( + fragment = "$", + callSitePattern = getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 179f40742c28f..38a6b9a50272b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -879,7 +879,10 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "funcName" -> s"`$funcName`", "paramName" -> "`format`", - "paramType" -> "\"STRING\"")) + "paramType" -> "\"STRING\""), + context = ExpectedContext( + fragment = funcName, + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { df2.select(func(col("input"), lit("invalid_format"))).collect() @@ -888,7 +891,10 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { parameters = Map( "parameter" -> "`format`", "functionName" -> s"`$funcName`", - "invalidFormat" -> "'invalid_format'")) + "invalidFormat" -> "'invalid_format'"), + context = ExpectedContext( + fragment = funcName, + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { sql(s"select $funcName('a', 'b', 'c')") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 5c12ba3078069..30a5bf709066d 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -696,7 +696,9 @@ class QueryCompilationErrorsSuite Seq("""{"a":1}""").toDF("a").select(from_json($"a", IntegerType)).collect() }, errorClass = "DATATYPE_MISMATCH.INVALID_JSON_SCHEMA", - parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\"")) + parameters = Map("schema" -> "\"INT\"", "sqlExpr" -> "\"from_json(a)\""), + context = + ExpectedContext(fragment = "from_json", callSitePattern = getCurrentClassCallSitePattern)) } test("WRONG_NUM_ARGS.WITHOUT_SUGGESTION: wrong args of CAST(parameter types contains DataType)") { @@ -767,7 +769,8 @@ class QueryCompilationErrorsSuite }, errorClass = "AMBIGUOUS_REFERENCE_TO_FIELDS", sqlState = "42000", - parameters = Map("field" -> "`firstname`", "count" -> "2") + parameters = Map("field" -> "`firstname`", "count" -> "2"), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern) ) } @@ -780,7 +783,9 @@ class QueryCompilationErrorsSuite }, errorClass = "INVALID_EXTRACT_BASE_FIELD_TYPE", sqlState = "42000", - parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\"")) + parameters = Map("base" -> "\"firstname\"", "other" -> "\"STRING\""), + context = ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern) + ) } test("INVALID_EXTRACT_FIELD_TYPE: extract not string literal field") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala index ee28a90aed9af..eafa89e8e007e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.errors import org.apache.spark._ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, CheckOverflowInTableInsert, ExpressionProxy, Literal, SubExprEvaluationRuntime} +import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation +import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.ByteType @@ -53,6 +55,26 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest sqlState = "22012", parameters = Map("config" -> ansiConf), context = ExpectedContext(fragment = "6/0", start = 7, stop = 9)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5) / lit(0)).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext(fragment = "div", callSitePattern = getCurrentClassCallSitePattern)) + + checkError( + exception = intercept[SparkArithmeticException] { + OneRowRelation().select(lit(5).divide(lit(0))).collect() + }, + errorClass = "DIVIDE_BY_ZERO", + sqlState = "22012", + parameters = Map("config" -> ansiConf), + context = ExpectedContext( + fragment = "divide", + callSitePattern = getCurrentClassCallSitePattern)) } test("INTERVAL_DIVIDED_BY_ZERO: interval divided by zero") { @@ -92,6 +114,21 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('66666666666666.666' AS DECIMAL(8, 1))", start = 7, stop = 49)) + + checkError( + exception = 
intercept[SparkArithmeticException] { + OneRowRelation().select(lit("66666666666666.666").cast("DECIMAL(8, 1)")).collect() + }, + errorClass = "NUMERIC_VALUE_OUT_OF_RANGE", + sqlState = "22003", + parameters = Map( + "value" -> "66666666666666.666", + "precision" -> "8", + "scale" -> "1", + "config" -> ansiConf), + context = ExpectedContext( + fragment = "cast", + callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX: get element from array") { @@ -102,6 +139,16 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest errorClass = "INVALID_ARRAY_INDEX", parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), context = ExpectedContext(fragment = "array(1, 2, 3, 4, 5)[8]", start = 7, stop = 29)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(lit(Array(1, 2, 3, 4, 5))(8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = ExpectedContext( + fragment = "apply", + callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_ARRAY_INDEX_IN_ELEMENT_AT: element_at from array") { @@ -115,6 +162,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "element_at(array(1, 2, 3, 4, 5), 8)", start = 7, stop = 41)) + + checkError( + exception = intercept[SparkArrayIndexOutOfBoundsException] { + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 8)).collect() + }, + errorClass = "INVALID_ARRAY_INDEX_IN_ELEMENT_AT", + parameters = Map("indexValue" -> "8", "arraySize" -> "5", "ansiConfig" -> ansiConf), + context = + ExpectedContext(fragment = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("INVALID_INDEX_OF_ZERO: element_at from array by index zero") { @@ -129,6 +185,15 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest start = 7, stop = 41) ) + + checkError( + exception = intercept[SparkRuntimeException]( + OneRowRelation().select(element_at(lit(Array(1, 2, 3, 4, 5)), 0)).collect() + ), + errorClass = "INVALID_INDEX_OF_ZERO", + parameters = Map.empty, + context = + ExpectedContext(fragment = "element_at", callSitePattern = getCurrentClassCallSitePattern)) } test("CAST_INVALID_INPUT: cast string to double") { @@ -146,6 +211,20 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest fragment = "CAST('111111111111xe23' AS DOUBLE)", start = 7, stop = 40)) + + checkError( + exception = intercept[SparkNumberFormatException] { + OneRowRelation().select(lit("111111111111xe23").cast("DOUBLE")).collect() + }, + errorClass = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> "'111111111111xe23'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"DOUBLE\"", + "ansiConfig" -> ansiConf), + context = ExpectedContext( + fragment = "cast", + callSitePattern = getCurrentClassCallSitePattern)) } test("CANNOT_PARSE_TIMESTAMP: parse string to timestamp") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 75e5d4d452e15..a7cab381c7f6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -1206,7 +1206,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = 
ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v5", fragment = "1/0", @@ -1225,7 +1225,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { }, errorClass = "DIVIDE_BY_ZERO", parameters = Map("config" -> "\"spark.sql.ansi.enabled\""), - context = new ExpectedContext( + context = ExpectedContext( objectType = "VIEW", objectName = s"$SESSION_CATALOG_NAME.default.v1", fragment = "1/0", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala index 6f9cc66f24769..2bc747c0abee4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignAssignmentsSuiteBase.scala @@ -191,7 +191,7 @@ abstract class AlignAssignmentsSuiteBase extends AnalysisTest { protected def parseAndResolve(query: String): LogicalPlan = { val analyzer = new CustomAnalyzer(catalogManager) { override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq( - new ResolveSessionCatalog(catalogManager)) + new ResolveSessionCatalog(this.catalogManager)) } val analyzed = analyzer.execute(CatalystSqlParser.parsePlan(query)) analyzer.checkAnalysis(analyzed) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 4eb65305de838..e39cc91d5f048 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -233,7 +233,7 @@ class PlanResolutionSuite extends AnalysisTest { } val analyzer = new Analyzer(catalogManager) { override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Seq( - new ResolveSessionCatalog(catalogManager)) + new ResolveSessionCatalog(this.catalogManager)) } // We don't check analysis here by default, as we expect the plan to be unresolved // such as `CreateTable`. 
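The ExpectedContext(fragment, callSitePattern) assertions added throughout these suites are matched against the call site recorded in DataFrame query contexts, using the two regex helpers introduced in QueryTest.scala earlier in this patch. Below is a minimal, self-contained Scala sketch of how those patterns line up with a JVM stack frame; the names in it (CallSitePatternSketch, org.example.DemoSuite) are illustrative and not part of this patch — only the pattern-building expressions mirror the helpers in the diff.

import java.util.regex.Pattern

object CallSitePatternSketch {
  // Same shape as getCurrentClassCallSitePattern: any method of the calling
  // class, in the same file, at any line number.
  def currentClassPattern(cs: StackTraceElement): String =
    s"${cs.getClassName}\\..*\\(${cs.getFileName}:\\d+\\)"

  // Same shape as getNextLineCallSitePattern: one exact frame, `lines` below
  // the frame captured inside the helper.
  def nextLinePattern(cs: StackTraceElement, lines: Int = 1): String =
    Pattern.quote(
      s"${cs.getClassName}.${cs.getMethodName}(${cs.getFileName}:${cs.getLineNumber + lines})")

  def main(args: Array[String]): Unit = {
    // Hypothetical frame and call-site string; a real one is shaped like
    // "org.apache.spark.sql.DatasetSuite.<method>(DatasetSuite.scala:<line>)".
    val frame = new StackTraceElement("org.example.DemoSuite", "body", "DemoSuite.scala", 41)
    val callSite = "org.example.DemoSuite.body(DemoSuite.scala:42)"

    assert(callSite.matches(currentClassPattern(frame))) // class-level, line-agnostic match
    assert(callSite.matches(nextLinePattern(frame)))     // exact frame, shifted one line down
  }
}

The class-level pattern is what most of the new context arguments use: it tolerates anonymous-function frames and shifting line numbers inside a suite, while the exact next-line pattern is reserved for tests such as SPARK-45022 that pin the precise call site.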
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala index 9f2d202299557..0e4985bac9941 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileMetadataStructSuite.scala @@ -244,7 +244,9 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { df.select("name", METADATA_FILE_NAME).collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), + context = + ExpectedContext(fragment = "select", callSitePattern = getCurrentClassCallSitePattern)) } metadataColumnsTest("SPARK-42683: df metadataColumn - schema conflict", @@ -522,14 +524,20 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession { df.select("name", "_metadata.file_name").collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_name`", "fields" -> "`id`, `university`"), + context = ExpectedContext( + fragment = "select", + callSitePattern = getCurrentClassCallSitePattern)) checkError( exception = intercept[AnalysisException] { df.select("name", "_METADATA.file_NAME").collect() }, errorClass = "FIELD_NOT_FOUND", - parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`")) + parameters = Map("fieldName" -> "`file_NAME`", "fields" -> "`id`, `university`"), + context = ExpectedContext( + fragment = "select", + callSitePattern = getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 111e88d57c784..a84aea2786823 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -2697,7 +2697,9 @@ abstract class CSVSuite readback.filter($"AAA" === 2 && $"bbb" === 3).collect() }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`")) + parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), + context = + ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f0561a30727b3..2f8b0a323dc8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -3105,7 +3105,9 @@ abstract class JsonSuite readback.filter($"AAA" === 0 && $"bbb" === 1).collect() }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`")) + parameters = Map("objectName" -> "`AAA`", "proposal" -> "`BBB`, `aaa`"), + context = + ExpectedContext(fragment = "$", callSitePattern = getCurrentClassCallSitePattern)) // Schema inferring val readback2 = 
spark.read.json(path.getCanonicalPath) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala index 2465dee230de9..d3e9819b9a054 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileMetadataStructRowIndexSuite.scala @@ -133,7 +133,8 @@ class ParquetFileMetadataStructRowIndexSuite extends QueryTest with SharedSparkS parameters = Map( "fieldName" -> "`row_index`", "fields" -> ("`file_path`, `file_name`, `file_size`, " + - "`file_block_start`, `file_block_length`, `file_modification_time`"))) + "`file_block_start`, `file_block_length`, `file_modification_time`")), + context = ExpectedContext(fragment = "select", getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index baffc5088eb24..94535bc84a4c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -1918,7 +1918,9 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { Seq("xyz").toDF().select("value", "default").write.insertInto("t") }, errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", - parameters = Map("objectName" -> "`default`", "proposal" -> "`value`")) + parameters = Map("objectName" -> "`default`", "proposal" -> "`value`"), + context = + ExpectedContext(fragment = "select", callSitePattern = getCurrentClassCallSitePattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 2174e91cb4435..66d37e996a6cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -713,7 +713,9 @@ class StreamSuite extends StreamTest { "columnName" -> "`rn_col`", "windowSpec" -> ("(PARTITION BY COL1 ORDER BY COL2 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING " + - "AND CURRENT ROW)"))) + "AND CURRENT ROW)")), + queryContext = Array( + ExpectedContext(fragment = "withColumn", callSitePattern = getCurrentClassCallSitePattern))) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 2d0bcdff07151..0b5e98d0a3e40 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -123,7 +123,7 @@ class HiveSessionStateBuilder( */ override protected def planner: SparkPlanner = { new SparkPlanner(session, experimentalMethods) with HiveStrategies { - override val sparkSession: SparkSession = session + override val sparkSession: SparkSession = this.session override def extraPlanningStrategies: Seq[Strategy] = super.extraPlanningStrategies ++ customPlanningStrategies ++