From e1a525e77cfc9e8d751bd279f008c1c4663e6b77 Mon Sep 17 00:00:00 2001
From: Yihong He
Date: Wed, 5 Jul 2023 17:44:06 +0900
Subject: [PATCH] [SPARK-42554][CONNECT] Implement GRPC exceptions interception
 for conversion

### What changes were proposed in this pull request?

This PR adds the CustomSparkConnectBlockingStub and GrpcExceptionConverter utilities to the Spark Connect Scala Client so that the client can intercept gRPC exceptions and convert them into Spark-related exceptions. As a starting point, the change converts every gRPC StatusRuntimeException into a SparkException (an illustrative usage sketch follows the patch below).

### Why are the changes needed?

Intercepting gRPC exceptions in the Spark Connect Scala Client and converting them into Spark-related exceptions makes the client's error behavior more compatible with existing Spark behavior.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

Closes #41743 from heyihong/SPARK-42554.

Authored-by: Yihong He
Signed-off-by: Hyukjin Kwon
---
 .../CustomSparkConnectBlockingStub.scala      | 51 +++++++++++++++++
 .../client/GrpcExceptionConverter.scala       | 57 +++++++++++++++++++
 .../connect/client/SparkConnectClient.scala   |  2 +-
 .../org/apache/spark/sql/CatalogSuite.scala   | 10 ++--
 .../apache/spark/sql/ClientE2ETestSuite.scala |  9 ++-
 .../spark/sql/DataFrameNaFunctionSuite.scala  |  3 +-
 .../apache/spark/sql/DataFrameStatSuite.scala |  4 +-
 .../KeyValueGroupedDatasetE2ETestSuite.scala  |  5 +-
 .../spark/sql/SparkSessionE2ESuite.scala      |  5 +-
 .../client/SparkConnectClientSuite.scala      |  5 +-
 10 files changed, 130 insertions(+), 21 deletions(-)
 create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
 create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
new file mode 100644
index 0000000000000..ab2f63059e9a9
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import io.grpc.ManagedChannel
+
+import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ConfigRequest, ConfigResponse, ExecutePlanRequest, ExecutePlanResponse, InterruptRequest, InterruptResponse}
+import org.apache.spark.connect.proto
+
+private[client] class CustomSparkConnectBlockingStub(channel: ManagedChannel) {
+
+  private val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+
+  def executePlan(request: ExecutePlanRequest): java.util.Iterator[ExecutePlanResponse] = {
+    GrpcExceptionConverter.convert {
+      GrpcExceptionConverter.convertIterator[ExecutePlanResponse](stub.executePlan(request))
+    }
+  }
+
+  def analyzePlan(request: AnalyzePlanRequest): AnalyzePlanResponse = {
+    GrpcExceptionConverter.convert {
+      stub.analyzePlan(request)
+    }
+  }
+
+  def config(request: ConfigRequest): ConfigResponse = {
+    GrpcExceptionConverter.convert {
+      stub.config(request)
+    }
+  }
+
+  def interrupt(request: InterruptRequest): InterruptResponse = {
+    GrpcExceptionConverter.convert {
+      stub.interrupt(request)
+    }
+  }
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
new file mode 100644
index 0000000000000..1a42ec821d84f
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import io.grpc.StatusRuntimeException
+import io.grpc.protobuf.StatusProto
+
+import org.apache.spark.{SparkException, SparkThrowable}
+
+private[client] object GrpcExceptionConverter {
+  def convert[T](f: => T): T = {
+    try {
+      f
+    } catch {
+      case e: StatusRuntimeException =>
+        throw toSparkThrowable(e)
+    }
+  }
+
+  def convertIterator[T](iter: java.util.Iterator[T]): java.util.Iterator[T] = {
+    new java.util.Iterator[T] {
+      override def hasNext: Boolean = {
+        convert {
+          iter.hasNext
+        }
+      }
+
+      override def next(): T = {
+        convert {
+          iter.next()
+        }
+      }
+    }
+  }
+
+  private def toSparkThrowable(ex: StatusRuntimeException): SparkThrowable with Throwable = {
+    val status = StatusProto.fromThrowable(ex)
+    // TODO: Add finer grained error conversion
+    new SparkException(status.getMessage, ex.getCause)
+  }
+}
+
+
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 0c3b1cae091e3..65d0aea19bc61 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -40,7 +40,7 @@ private[sql] class SparkConnectClient(
 
   private val userContext: UserContext = configuration.userContext
 
-  private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel)
+  private[this] val stub = new CustomSparkConnectBlockingStub(channel)
 
   private[client] def userAgent: String = configuration.userAgent
 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
index 04b3f4e639a06..00a6bcc9b5c45 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql
 
 import java.io.{File, FilenameFilter}
 
-import io.grpc.StatusRuntimeException
 import org.apache.commons.io.FileUtils
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.connect.client.util.RemoteSparkSession
 import org.apache.spark.sql.types.{DoubleType, LongType, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -46,7 +46,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper {
     assert(databasesWithPattern.length == 0)
     val database = spark.catalog.getDatabase(db)
     assert(database.name == db)
-    val message = intercept[StatusRuntimeException] {
+    val message = intercept[SparkException] {
       spark.catalog.getDatabase("notExists")
     }.getMessage
     assert(message.contains("SCHEMA_NOT_FOUND"))
@@ -66,7 +66,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper {
     val catalogs = spark.catalog.listCatalogs().collect()
     assert(catalogs.length == 1)
     assert(catalogs.map(_.name) sameElements Array("spark_catalog"))
-    val message = intercept[StatusRuntimeException] {
+    val message = intercept[SparkException] {
       spark.catalog.setCurrentCatalog("notExists")
     }.getMessage
     assert(message.contains("plugin class not found"))
@@ -141,7 +141,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper {
         assert(spark.catalog.listTables().collect().map(_.name).toSet == Set(parquetTableName))
       }
     }
-    val message = intercept[StatusRuntimeException] {
+    val message = intercept[SparkException] {
      spark.catalog.getTable(parquetTableName)
    }.getMessage
    assert(message.contains("TABLE_OR_VIEW_NOT_FOUND"))
@@ -207,7 +207,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper {
     assert(spark.catalog.getFunction(absFunctionName).name == absFunctionName)
     val notExistsFunction = "notExists"
     assert(!spark.catalog.functionExists(notExistsFunction))
-    val message = intercept[StatusRuntimeException] {
+    val message = intercept[SparkException] {
       spark.catalog.getFunction(notExistsFunction)
     }.getMessage
     assert(message.contains("UNRESOLVED_ROUTINE"))
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 0ababaa0af1a6..fd3dba4946522 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -23,7 +23,6 @@ import java.util.Properties
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
-import io.grpc.StatusRuntimeException
 import org.apache.commons.io.FileUtils
 import org.apache.commons.io.output.TeeOutputStream
 import org.apache.commons.lang3.{JavaVersion, SystemUtils}
@@ -65,7 +64,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
     assume(IntegrationTestUtils.isSparkHiveJarAvailable)
     withTable("test_martin") {
       // Fails, because table does not exist.
-      assertThrows[StatusRuntimeException] {
+      assertThrows[SparkException] {
         spark.sql("select * from test_martin").collect()
       }
       // Execute eager, DML
@@ -153,7 +152,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
             StructField("job", StringType) :: Nil))
         .csv(testDataPath.toString)
       // Failed because the path cannot be provided both via option and load method (csv).
-      assertThrows[StatusRuntimeException] {
+      assertThrows[SparkException] {
         df.collect()
       }
     }
@@ -237,7 +236,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
       assert(result.length == 10)
     } finally {
       // clean up
-      assertThrows[StatusRuntimeException] {
+      assertThrows[SparkException] {
         spark.read.jdbc(url = s"$url;drop=true", table, new Properties()).collect()
       }
     }
@@ -354,7 +353,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM
     val df = spark.range(10)
     val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath
     // Failed because the path cannot be provided both via option and save method.
-    assertThrows[StatusRuntimeException] {
+    assertThrows[SparkException] {
       df.write.option("path", outputFolderPath.toString).save(outputFolderPath.toString)
     }
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala
index 1f6ea879248dc..c44b515bdedf4 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConverters._
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.connect.client.util.QueryTest
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{StringType, StructType}
@@ -278,7 +279,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper {
 
   test("drop with col(*)") {
     val df = createDF()
-    val ex = intercept[RuntimeException] {
+    val ex = intercept[SparkException] {
       df.na.drop("any", Seq("*")).collect()
     }
     assert(ex.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION"))
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index aea31005f3bd6..12683e54d989f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql
 
 import java.util.Random
 
-import io.grpc.StatusRuntimeException
 import org.scalatest.matchers.must.Matchers._
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.connect.client.util.RemoteSparkSession
 
 class DataFrameStatSuite extends RemoteSparkSession {
@@ -87,7 +87,7 @@ class DataFrameStatSuite extends RemoteSparkSession {
 
     val results = df.stat.cov("singles", "doubles")
     assert(math.abs(results - 55.0 / 3) < 1e-12)
-    intercept[StatusRuntimeException] {
+    intercept[SparkException] {
       df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
     }
     val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 79bedabd559b6..7ced0a9cb6755 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql
 import java.sql.Timestamp
 import java.util.Arrays
 
-import io.grpc.StatusRuntimeException
-
+import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append
 import org.apache.spark.sql.connect.client.util.QueryTest
 import org.apache.spark.sql.functions._
@@ -179,7 +178,7 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper {
     assert(values == Arrays.asList[String]("0", "8,6,4,2,0", "1", "9,7,5,3,1"))
 
     // Star is not allowed as group sort column
-    val message = intercept[StatusRuntimeException] {
+    val message = intercept[SparkException] {
       grouped
         .flatMapSortedGroups(col("*")) { (g, iter) =>
           Iterator(String.valueOf(g), iter.mkString(","))
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
index fb295c00edc02..044f6a48cc8f8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -22,6 +22,7 @@ import scala.util.{Failure, Success}
 
 import org.scalatest.concurrent.Eventually._
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.connect.client.util.RemoteSparkSession
 import org.apache.spark.util.ThreadUtils
 
@@ -85,11 +86,11 @@ class SparkSessionE2ESuite extends RemoteSparkSession {
       }
       finished
     }
-    val e1 = intercept[io.grpc.StatusRuntimeException] {
+    val e1 = intercept[SparkException] {
       spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
     }
     assert(e1.getMessage.contains("cancelled"), s"Unexpected exception: $e1")
-    val e2 = intercept[io.grpc.StatusRuntimeException] {
+    val e2 = intercept[SparkException] {
       spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect()
     }
     assert(e2.getMessage.contains("cancelled"), s"Unexpected exception: $e2")
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index e9a10e27273ad..2267d803e1d7a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -21,11 +21,12 @@ import java.util.concurrent.TimeUnit
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
-import io.grpc.{Server, StatusRuntimeException}
+import io.grpc.Server
 import io.grpc.netty.NettyServerBuilder
 import io.grpc.stub.StreamObserver
 import org.scalatest.BeforeAndAfterEach
 
+import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
 import org.apache.spark.sql.SparkSession
@@ -104,7 +105,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
     val request = AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()
 
     // Failed the ssl handshake as the dummy server does not have any server credentials installed.
-    assertThrows[StatusRuntimeException] {
+    assertThrows[SparkException] {
       client.analyze(request)
     }
   }
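
For illustration only, not part of the patch: a minimal sketch of the user-facing effect, assuming a Spark Connect Scala client `SparkSession` (for example one obtained via `SparkSession.builder().remote("sc://localhost:15002")`; the endpoint and table name below are hypothetical). With the converter in place, a failing query surfaces an `org.apache.spark.SparkException` built from the gRPC status instead of a raw `io.grpc.StatusRuntimeException`.

```scala
import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

object GrpcErrorConversionSketch {
  // `spark` is assumed to be a Spark Connect client session backed by a running server.
  def reportFailure(spark: SparkSession): Unit = {
    try {
      // The table does not exist, so the server reports an error over gRPC.
      spark.sql("SELECT * FROM table_that_does_not_exist").collect()
    } catch {
      // CustomSparkConnectBlockingStub routes the call through GrpcExceptionConverter,
      // so the failure arrives as a SparkException carrying the server-side status
      // message (e.g. TABLE_OR_VIEW_NOT_FOUND) rather than a StatusRuntimeException.
      case e: SparkException => println(s"Query failed: ${e.getMessage}")
    }
  }
}
```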