diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala new file mode 100644 index 0000000000000..ab2f63059e9a9 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client + +import io.grpc.ManagedChannel + +import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ConfigRequest, ConfigResponse, ExecutePlanRequest, ExecutePlanResponse, InterruptRequest, InterruptResponse} +import org.apache.spark.connect.proto + +private[client] class CustomSparkConnectBlockingStub(channel: ManagedChannel) { + + private val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel) + + def executePlan(request: ExecutePlanRequest): java.util.Iterator[ExecutePlanResponse] = { + GrpcExceptionConverter.convert { + GrpcExceptionConverter.convertIterator[ExecutePlanResponse](stub.executePlan(request)) + } + } + + def analyzePlan(request: AnalyzePlanRequest): AnalyzePlanResponse = { + GrpcExceptionConverter.convert { + stub.analyzePlan(request) + } + } + + def config(request: ConfigRequest): ConfigResponse = { + GrpcExceptionConverter.convert { + stub.config(request) + } + } + + def interrupt(request: InterruptRequest): InterruptResponse = { + GrpcExceptionConverter.convert { + stub.interrupt(request) + } + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala new file mode 100644 index 0000000000000..1a42ec821d84f --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.client + +import io.grpc.StatusRuntimeException +import io.grpc.protobuf.StatusProto + +import org.apache.spark.{SparkException, SparkThrowable} + +private[client] object GrpcExceptionConverter { + def convert[T](f: => T): T = { + try { + f + } catch { + case e: StatusRuntimeException => + throw toSparkThrowable(e) + } + } + + def convertIterator[T](iter: java.util.Iterator[T]): java.util.Iterator[T] = { + new java.util.Iterator[T] { + override def hasNext: Boolean = { + convert { + iter.hasNext + } + } + + override def next(): T = { + convert { + iter.next() + } + } + } + } + + private def toSparkThrowable(ex: StatusRuntimeException): SparkThrowable with Throwable = { + val status = StatusProto.fromThrowable(ex) + // TODO: Add finer grained error conversion + new SparkException(status.getMessage, ex.getCause) + } +} + + diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index 0c3b1cae091e3..65d0aea19bc61 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -40,7 +40,7 @@ private[sql] class SparkConnectClient( private val userContext: UserContext = configuration.userContext - private[this] val stub = proto.SparkConnectServiceGrpc.newBlockingStub(channel) + private[this] val stub = new CustomSparkConnectBlockingStub(channel) private[client] def userAgent: String = configuration.userAgent diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index 04b3f4e639a06..00a6bcc9b5c45 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import java.io.{File, FilenameFilter} -import io.grpc.StatusRuntimeException import org.apache.commons.io.FileUtils +import org.apache.spark.SparkException import org.apache.spark.sql.connect.client.util.RemoteSparkSession import org.apache.spark.sql.types.{DoubleType, LongType, StructType} import org.apache.spark.storage.StorageLevel @@ -46,7 +46,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(databasesWithPattern.length == 0) val database = spark.catalog.getDatabase(db) assert(database.name == db) - val message = intercept[StatusRuntimeException] { + val message = intercept[SparkException] { spark.catalog.getDatabase("notExists") }.getMessage assert(message.contains("SCHEMA_NOT_FOUND")) @@ -66,7 +66,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { val catalogs = spark.catalog.listCatalogs().collect() assert(catalogs.length == 1) assert(catalogs.map(_.name) sameElements Array("spark_catalog")) - val message = intercept[StatusRuntimeException] { + val message = intercept[SparkException] { spark.catalog.setCurrentCatalog("notExists") }.getMessage assert(message.contains("plugin class not found")) @@ -141,7 +141,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { 
assert(spark.catalog.listTables().collect().map(_.name).toSet == Set(parquetTableName)) } } - val message = intercept[StatusRuntimeException] { + val message = intercept[SparkException] { spark.catalog.getTable(parquetTableName) }.getMessage assert(message.contains("TABLE_OR_VIEW_NOT_FOUND")) @@ -207,7 +207,7 @@ class CatalogSuite extends RemoteSparkSession with SQLHelper { assert(spark.catalog.getFunction(absFunctionName).name == absFunctionName) val notExistsFunction = "notExists" assert(!spark.catalog.functionExists(notExistsFunction)) - val message = intercept[StatusRuntimeException] { + val message = intercept[SparkException] { spark.catalog.getFunction(notExistsFunction) }.getMessage assert(message.contains("UNRESOLVED_ROUTINE")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 0ababaa0af1a6..fd3dba4946522 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -23,7 +23,6 @@ import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable -import io.grpc.StatusRuntimeException import org.apache.commons.io.FileUtils import org.apache.commons.io.output.TeeOutputStream import org.apache.commons.lang3.{JavaVersion, SystemUtils} @@ -65,7 +64,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assume(IntegrationTestUtils.isSparkHiveJarAvailable) withTable("test_martin") { // Fails, because table does not exist. - assertThrows[StatusRuntimeException] { + assertThrows[SparkException] { spark.sql("select * from test_martin").collect() } // Execute eager, DML @@ -153,7 +152,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM StructField("job", StringType) :: Nil)) .csv(testDataPath.toString) // Failed because the path cannot be provided both via option and load method (csv). - assertThrows[StatusRuntimeException] { + assertThrows[SparkException] { df.collect() } } @@ -237,7 +236,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM assert(result.length == 10) } finally { // clean up - assertThrows[StatusRuntimeException] { + assertThrows[SparkException] { spark.read.jdbc(url = s"$url;drop=true", table, new Properties()).collect() } } @@ -354,7 +353,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM val df = spark.range(10) val outputFolderPath = Files.createTempDirectory("output").toAbsolutePath // Failed because the path cannot be provided both via option and save method. 
- assertThrows[StatusRuntimeException] { + assertThrows[SparkException] { df.write.option("path", outputFolderPath.toString).save(outputFolderPath.toString) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala index 1f6ea879248dc..c44b515bdedf4 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import org.apache.spark.SparkException import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StringType, StructType} @@ -278,7 +279,7 @@ class DataFrameNaFunctionSuite extends QueryTest with SQLHelper { test("drop with col(*)") { val df = createDF() - val ex = intercept[RuntimeException] { + val ex = intercept[SparkException] { df.na.drop("any", Seq("*")).collect() } assert(ex.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index aea31005f3bd6..12683e54d989f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import java.util.Random -import io.grpc.StatusRuntimeException import org.scalatest.matchers.must.Matchers._ +import org.apache.spark.SparkException import org.apache.spark.sql.connect.client.util.RemoteSparkSession class DataFrameStatSuite extends RemoteSparkSession { @@ -87,7 +87,7 @@ class DataFrameStatSuite extends RemoteSparkSession { val results = df.stat.cov("singles", "doubles") assert(math.abs(results - 55.0 / 3) < 1e-12) - intercept[StatusRuntimeException] { + intercept[SparkException] { df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes } val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b") diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala index 79bedabd559b6..7ced0a9cb6755 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql import java.sql.Timestamp import java.util.Arrays -import io.grpc.StatusRuntimeException - +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append import org.apache.spark.sql.connect.client.util.QueryTest import org.apache.spark.sql.functions._ @@ -179,7 +178,7 @@ class KeyValueGroupedDatasetE2ETestSuite extends QueryTest with SQLHelper { assert(values == Arrays.asList[String]("0", "8,6,4,2,0", "1", "9,7,5,3,1")) // Star is not allowed as group sort column - val message = intercept[StatusRuntimeException] { + val message = intercept[SparkException] { grouped 
.flatMapSortedGroups(col("*")) { (g, iter) => Iterator(String.valueOf(g), iter.mkString(",")) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala index fb295c00edc02..044f6a48cc8f8 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala @@ -22,6 +22,7 @@ import scala.util.{Failure, Success} import org.scalatest.concurrent.Eventually._ +import org.apache.spark.SparkException import org.apache.spark.sql.connect.client.util.RemoteSparkSession import org.apache.spark.util.ThreadUtils @@ -85,11 +86,11 @@ class SparkSessionE2ESuite extends RemoteSparkSession { } finished } - val e1 = intercept[io.grpc.StatusRuntimeException] { + val e1 = intercept[SparkException] { spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect() } assert(e1.getMessage.contains("cancelled"), s"Unexpected exception: $e1") - val e2 = intercept[io.grpc.StatusRuntimeException] { + val e2 = intercept[SparkException] { spark.range(10).map(n => { Thread.sleep(30.seconds.toMillis); n }).collect() } assert(e2.getMessage.contains("cancelled"), s"Unexpected exception: $e2") diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index e9a10e27273ad..2267d803e1d7a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -21,11 +21,12 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable -import io.grpc.{Server, StatusRuntimeException} +import io.grpc.Server import io.grpc.netty.NettyServerBuilder import io.grpc.stub.StreamObserver import org.scalatest.BeforeAndAfterEach +import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.sql.SparkSession @@ -104,7 +105,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { val request = AnalyzePlanRequest.newBuilder().setSessionId("abc123").build() // Failed the ssl handshake as the dummy server does not have any server credentials installed. - assertThrows[StatusRuntimeException] { + assertThrows[SparkException] { client.analyze(request) } }
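
Note on the test changes above: every RPC in the new CustomSparkConnectBlockingStub is wrapped in GrpcExceptionConverter.convert, so gRPC failures now surface to callers as org.apache.spark.SparkException rather than io.grpc.StatusRuntimeException, which is why the intercept/assertThrows expectations switch exception types. Below is a minimal, self-contained sketch of that catch-and-rethrow pattern, not the patch itself: the object name ErrorConversionExample and the simulated failure are hypothetical, and where the patch extracts the message via io.grpc.protobuf.StatusProto.fromThrowable, this sketch reads Status.getDescription to avoid the extra dependency.

import io.grpc.{Status, StatusRuntimeException}

import org.apache.spark.SparkException

object ErrorConversionExample {
  // Run `f`, converting any gRPC status failure into a SparkException,
  // mirroring the shape of GrpcExceptionConverter.convert in this patch.
  def convert[T](f: => T): T = {
    try {
      f
    } catch {
      case e: StatusRuntimeException =>
        throw new SparkException(e.getStatus.getDescription, e.getCause)
    }
  }

  def main(args: Array[String]): Unit = {
    try {
      convert {
        // Simulate a server-side error as it would be raised by the blocking stub.
        throw Status.INTERNAL
          .withDescription("TABLE_OR_VIEW_NOT_FOUND")
          .asRuntimeException()
      }
    } catch {
      // Client code (and the tests above) now matches on SparkException.
      case e: SparkException => println(s"caught: ${e.getMessage}")
    }
  }
}

The same wrapping is applied lazily to streaming results via GrpcExceptionConverter.convertIterator, so exceptions thrown while iterating ExecutePlanResponse batches are converted at the point of hasNext/next rather than only at call time.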