diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index f7daca8542d6a..dfd6008ac09a1 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -139,7 +139,7 @@ object Connect {
         "With any value greater than 0, the last sent response will always be buffered.")
       .version("3.5.0")
       .bytesConf(ByteUnit.BYTE)
-      .createWithDefaultString("1m")
+      .createWithDefaultString("10m")
 
   val CONNECT_EXTENSIONS_RELATION_CLASSES =
     buildStaticConf("spark.connect.extensions.relation.classes")
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 488858d33ea12..eddd1c6be72b1 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.test.SharedSparkSession
  * Base class and utilities for a test suite that starts and tests the real SparkConnectService
  * with a real SparkConnectClient, communicating over RPC, but both in-process.
  */
-class SparkConnectServerTest extends SharedSparkSession {
+trait SparkConnectServerTest extends SharedSparkSession {
 
   // Server port
   val serverPort: Int =
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
index 169b15582b698..0e29a07b719af 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -22,7 +22,7 @@ import io.grpc.StatusRuntimeException
 import org.scalatest.concurrent.Eventually
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.sql.connect.SparkConnectServerTest
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.service.SparkConnectService
@@ -32,7 +32,7 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
   // Tests assume that this query will result in at least a couple ExecutePlanResponses on the
   // stream. If this is no longer the case because of changes in how much is returned in a single
   // ExecutePlanResponse, it may need to be adjusted.
-  val MEDIUM_RESULTS_QUERY = "select * from range(1000000)"
+  val MEDIUM_RESULTS_QUERY = "select * from range(10000000)"
 
   test("reattach after initial RPC ends") {
     withClient { client =>
@@ -138,13 +138,12 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
        val reattachIter = stub.reattachExecute(
          buildReattachExecuteRequest(operationId, Some(response.getResponseId)))
        assert(reattachIter.hasNext)
-        reattachIter.next()
-
-        // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error
-        iter.next()
-        // iterator changed because it had to reconnect
-        assert(reattachableIter.innerIterator ne initialInnerIter)
      }
+
+      // Nevertheless, the original iterator will handle the INVALID_CURSOR.DISCONNECTED error
+      iter.next()
+      // iterator changed because it had to reconnect
+      assert(reattachableIter.innerIterator ne initialInnerIter)
    }
  }
 
@@ -246,19 +245,26 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
      val iter = stub.executePlan(
        buildExecutePlanRequest(buildPlan(MEDIUM_RESULTS_QUERY), operationId = operationId))
      var lastSeenResponse: String = null
+      val serverRetryBuffer = SparkEnv.get.conf
+        .get(Connect.CONNECT_EXECUTE_REATTACHABLE_OBSERVER_RETRY_BUFFER_SIZE)
+        .toLong
 
      iter.hasNext // open iterator
      val execution = getExecutionHolder
 
      // after consuming enough from the iterator, server should automatically start releasing
      var lastSeenIndex = 0
-      while (iter.hasNext && execution.responseObserver.releasedUntilIndex == 0) {
+      var totalSizeSeen = 0
+      while (iter.hasNext && totalSizeSeen <= 1.1 * serverRetryBuffer) {
        val r = iter.next()
        lastSeenResponse = r.getResponseId()
+        totalSizeSeen += r.getSerializedSize
        lastSeenIndex += 1
      }
      assert(iter.hasNext)
-      assert(execution.responseObserver.releasedUntilIndex > 0)
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(execution.responseObserver.releasedUntilIndex > 0)
+      }
 
      // Reattach from the beginning is not available.
      val reattach = stub.reattachExecute(buildReattachExecuteRequest(operationId, None))