From 68d8e65c7829a4a41f8c159c9b30c34cd623da47 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 3 Aug 2023 08:11:02 +0900 Subject: [PATCH] [SPARK-44424][CONNECT][PYTHON] Python client for reattaching to existing execute in Spark Connect ### What changes were proposed in this pull request? This PR proposes to implement the Python client side for https://github.com/apache/spark/pull/42228. Basically this PR applies the same changes of `ExecutePlanResponseReattachableIterator`, and `SparkConnectClient` to PySpark as the symmetry. ### Why are the changes needed? To enable the same feature in https://github.com/apache/spark/pull/42228 ### Does this PR introduce _any_ user-facing change? Yes, see https://github.com/apache/spark/pull/42228. ### How was this patch tested? Existing unittests because it enables the feature by default. Also, manual E2E tests. Closes #42235 from HyukjinKwon/SPARK-44599. Lead-authored-by: Hyukjin Kwon Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/core.py | 207 +++++++++------ python/pyspark/sql/connect/client/reattach.py | 237 ++++++++++++++++++ python/pyspark/sql/connect/session.py | 2 +- python/pyspark/sql/session.py | 2 + .../sql/tests/connect/client/test_client.py | 16 +- python/pyspark/testing/connectutils.py | 4 + 6 files changed, 386 insertions(+), 82 deletions(-) create mode 100644 python/pyspark/sql/connect/client/reattach.py diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 0288bbc65087e..d9def40ebe886 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -20,6 +20,7 @@ "getLogLevel", ] +from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator from pyspark.sql.connect.utils import check_dependencies check_dependencies(__name__) @@ -50,6 +51,7 @@ Generator, Type, TYPE_CHECKING, + Sequence, ) import pandas as pd @@ -558,8 +560,6 @@ def fromProto(cls, pb: pb2.ConfigResponse) -> "ConfigResult": class SparkConnectClient(object): """ Conceptually the remote spark session that communicates with the server - - .. versionadded:: 3.4.0 """ @classmethod @@ -572,24 +572,40 @@ def retry_exception(cls, e: Exception) -> bool: def __init__( self, connection: Union[str, ChannelBuilder], - userId: Optional[str] = None, - channelOptions: Optional[List[Tuple[str, Any]]] = None, - retryPolicy: Optional[Dict[str, Any]] = None, + user_id: Optional[str] = None, + channel_options: Optional[List[Tuple[str, Any]]] = None, + retry_policy: Optional[Dict[str, Any]] = None, + use_reattachable_execute: bool = True, ): """ Creates a new SparkSession for the Spark Connect interface. Parameters ---------- - connection: Union[str,ChannelBuilder] + connection : str or :class:`ChannelBuilder` Connection string that is used to extract the connection parameters and configure the GRPC connection. Or instance of ChannelBuilder that creates GRPC connection. Defaults to `sc://localhost`. - userId : Optional[str] + user_id : str, optional Optional unique user ID that is used to differentiate multiple users and isolate their Spark Sessions. If the `user_id` is not set, will default to the $USER environment. Defining the user ID as part of the connection string takes precedence. + channel_options: list of tuple, optional + Additional options that can be passed to the GRPC channel construction. + retry_policy: dict of str and any, optional + Additional configuration for retrying. 
There are four configurations as below + * ``max_retries`` + Maximum number of tries default 15 + * ``backoff_multiplier`` + Backoff multiplier for the policy. Default: 4(ms) + * ``initial_backoff`` + Backoff to wait before the first retry. Default: 50(ms) + * ``max_backoff`` + Maximum backoff controls the maximum amount of time to wait before retrying + a failed request. Default: 60000(ms). + use_reattachable_execute: bool + Enable reattachable execution. """ self.thread_local = threading.local() @@ -597,7 +613,7 @@ def __init__( self._builder = ( connection if isinstance(connection, ChannelBuilder) - else ChannelBuilder(connection, channelOptions) + else ChannelBuilder(connection, channel_options) ) self._user_id = None self._retry_policy = { @@ -606,8 +622,8 @@ def __init__( "initial_backoff": 50, "max_backoff": 60000, } - if retryPolicy: - self._retry_policy.update(retryPolicy) + if retry_policy: + self._retry_policy.update(retry_policy) # Generate a unique session ID for this client. This UUID must be unique to allow # concurrent Spark sessions of the same user. If the channel is closed, creating @@ -615,8 +631,8 @@ def __init__( self._session_id = str(uuid.uuid4()) if self._builder.userId is not None: self._user_id = self._builder.userId - elif userId is not None: - self._user_id = userId + elif user_id is not None: + self._user_id = user_id else: self._user_id = os.getenv("USER", None) @@ -624,8 +640,17 @@ def __init__( self._closed = False self._stub = grpc_lib.SparkConnectServiceStub(self._channel) self._artifact_manager = ArtifactManager(self._user_id, self._session_id, self._channel) + self._use_reattachable_execute = use_reattachable_execute # Configure logging for the SparkConnect client. + def disable_reattachable_execute(self) -> "SparkConnectClient": + self._use_reattachable_execute = False + return self + + def enable_reattachable_execute(self) -> "SparkConnectClient": + self._use_reattachable_execute = True + return self + def register_udf( self, function: Any, @@ -741,7 +766,7 @@ def _resources(self) -> Dict[str, ResourceInformation]: return resources def _build_observed_metrics( - self, metrics: List["pb2.ExecutePlanResponse.ObservedMetrics"] + self, metrics: Sequence["pb2.ExecutePlanResponse.ObservedMetrics"] ) -> Iterator[PlanObservedMetrics]: return (PlanObservedMetrics(x.name, [v for v in x.values]) for x in metrics) @@ -1065,17 +1090,29 @@ def _execute(self, req: pb2.ExecutePlanRequest) -> None: """ logger.info("Execute") + + def handle_response(b: pb2.ExecutePlanResponse) -> None: + if b.session_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request: " + f"{b.session_id} != {self._session_id}" + ) + try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): - with attempt: - for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): - if b.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request: " - f"{b.session_id} != {self._session_id}" - ) + if self._use_reattachable_execute: + # Don't use retryHandler - own retry handling is inside. 
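+                # The reattachable iterator transparently reissues ReattachExecute when the
+                # response stream is interrupted and asynchronously sends ReleaseExecute for
+                # responses that were already consumed, so the outer Retrying wrapper is not
+                # needed on this path.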
+ generator = ExecutePlanResponseReattachableIterator( + req, self._stub, self._retry_policy, self._builder.metadata() + ) + for b in generator: + handle_response(b) + else: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): + handle_response(b) except Exception as error: self._handle_error(error) @@ -1092,58 +1129,77 @@ def _execute_and_fetch_as_iterator( ]: logger.info("ExecuteAndFetchAsIterator") + def handle_response( + b: pb2.ExecutePlanResponse, + ) -> Iterator[ + Union[ + "pa.RecordBatch", + StructType, + PlanMetrics, + PlanObservedMetrics, + Dict[str, Any], + ] + ]: + if b.session_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request: " + f"{b.session_id} != {self._session_id}" + ) + if b.HasField("metrics"): + logger.debug("Received metric batch.") + yield from self._build_metrics(b.metrics) + if b.observed_metrics: + logger.debug("Received observed metric batch.") + yield from self._build_observed_metrics(b.observed_metrics) + if b.HasField("schema"): + logger.debug("Received the schema.") + dt = types.proto_schema_to_pyspark_data_type(b.schema) + assert isinstance(dt, StructType) + yield dt + if b.HasField("sql_command_result"): + logger.debug("Received the SQL command result.") + yield {"sql_command_result": b.sql_command_result.relation} + if b.HasField("write_stream_operation_start_result"): + field = "write_stream_operation_start_result" + yield {field: b.write_stream_operation_start_result} + if b.HasField("streaming_query_command_result"): + yield {"streaming_query_command_result": b.streaming_query_command_result} + if b.HasField("streaming_query_manager_command_result"): + cmd_result = b.streaming_query_manager_command_result + yield {"streaming_query_manager_command_result": cmd_result} + if b.HasField("get_resources_command_result"): + resources = {} + for key, resource in b.get_resources_command_result.resources.items(): + name = resource.name + addresses = [address for address in resource.addresses] + resources[key] = ResourceInformation(name, addresses) + yield {"get_resources_command_result": resources} + if b.HasField("arrow_batch"): + logger.debug( + f"Received arrow batch rows={b.arrow_batch.row_count} " + f"size={len(b.arrow_batch.data)}" + ) + + with pa.ipc.open_stream(b.arrow_batch.data) as reader: + for batch in reader: + assert isinstance(batch, pa.RecordBatch) + yield batch + try: - for attempt in Retrying( - can_retry=SparkConnectClient.retry_exception, **self._retry_policy - ): - with attempt: - for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): - if b.session_id != self._session_id: - raise SparkConnectException( - "Received incorrect session identifier for request: " - f"{b.session_id} != {self._session_id}" - ) - if b.HasField("metrics"): - logger.debug("Received metric batch.") - yield from self._build_metrics(b.metrics) - if b.observed_metrics: - logger.debug("Received observed metric batch.") - yield from self._build_observed_metrics(b.observed_metrics) - if b.HasField("schema"): - logger.debug("Received the schema.") - dt = types.proto_schema_to_pyspark_data_type(b.schema) - assert isinstance(dt, StructType) - yield dt - if b.HasField("sql_command_result"): - logger.debug("Received the SQL command result.") - yield {"sql_command_result": b.sql_command_result.relation} - if b.HasField("write_stream_operation_start_result"): - field = 
"write_stream_operation_start_result" - yield {field: b.write_stream_operation_start_result} - if b.HasField("streaming_query_command_result"): - yield { - "streaming_query_command_result": b.streaming_query_command_result - } - if b.HasField("streaming_query_manager_command_result"): - cmd_result = b.streaming_query_manager_command_result - yield {"streaming_query_manager_command_result": cmd_result} - if b.HasField("get_resources_command_result"): - resources = {} - for key, resource in b.get_resources_command_result.resources.items(): - name = resource.name - addresses = [address for address in resource.addresses] - resources[key] = ResourceInformation(name, addresses) - yield {"get_resources_command_result": resources} - if b.HasField("arrow_batch"): - logger.debug( - f"Received arrow batch rows={b.arrow_batch.row_count} " - f"size={len(b.arrow_batch.data)}" - ) - - with pa.ipc.open_stream(b.arrow_batch.data) as reader: - for batch in reader: - assert isinstance(batch, pa.RecordBatch) - yield batch + if self._use_reattachable_execute: + # Don't use retryHandler - own retry handling is inside. + generator = ExecutePlanResponseReattachableIterator( + req, self._stub, self._retry_policy, self._builder.metadata() + ) + for b in generator: + yield from handle_response(b) + else: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): + yield from handle_response(b) except Exception as error: self._handle_error(error) @@ -1502,6 +1558,9 @@ def __exit__( self._retry_state.set_done() return None + def is_first_try(self) -> bool: + return self._retry_state._count == 0 + class Retrying: """ diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py new file mode 100644 index 0000000000000..4d4cce0ca4413 --- /dev/null +++ b/python/pyspark/sql/connect/client/reattach.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from pyspark.sql.connect.utils import check_dependencies + +check_dependencies(__name__) + +import warnings +import uuid +from collections.abc import Generator +from typing import Optional, Dict, Any, Iterator, Iterable, Tuple +from multiprocessing.pool import ThreadPool +import os + +import pyspark.sql.connect.proto as pb2 +import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib + + +class ExecutePlanResponseReattachableIterator(Generator): + """ + Retryable iterator of ExecutePlanResponses to an ExecutePlan call. 
+ + It can handle situations when: + - the ExecutePlanResponse stream was broken by retryable network error (governed by + retryPolicy) + - the ExecutePlanResponse was gracefully ended by the server without a ResultComplete + message; this tells the client that there is more, and it should reattach to continue. + + Initial iterator is the result of an ExecutePlan on the request, but it can be reattached with + ReattachExecute request. ReattachExecute request is provided the responseId of last returned + ExecutePlanResponse on the iterator to return a new iterator from server that continues after + that. + + In reattachable execute the server does buffer some responses in case the client needs to + backtrack. To let server release this buffer sooner, this iterator asynchronously sends + ReleaseExecute RPCs that instruct the server to release responses that it already processed. + + Note: If the initial ExecutePlan did not even reach the server and execution didn't start, + the ReattachExecute can still fail with INVALID_HANDLE.OPERATION_NOT_FOUND, failing the whole + operation. + """ + + _release_thread_pool = ThreadPool(os.cpu_count() if os.cpu_count() else 8) + + def __init__( + self, + request: pb2.ExecutePlanRequest, + stub: grpc_lib.SparkConnectServiceStub, + retry_policy: Dict[str, Any], + metadata: Iterable[Tuple[str, str]], + ): + self._request = request + self._retry_policy = retry_policy + if request.operation_id: + self._operation_id = request.operation_id + else: + # Add operation id, if not present. + # with operationId set by the client, the client can use it to try to reattach on error + # even before getting the first response. If the operation in fact didn't even reach the + # server, that will end with INVALID_HANDLE.OPERATION_NOT_FOUND error. + self._operation_id = str(uuid.uuid4()) + + self._stub = stub + request.request_options.append( + pb2.ExecutePlanRequest.RequestOption( + reattach_options=pb2.ReattachOptions(reattachable=True) + ) + ) + request.operation_id = self._operation_id + self._initial_request = request + + # ResponseId of the last response returned by next() + self._last_returned_response_id: Optional[str] = None + + # True after ResponseComplete message was seen in the stream. + # Server will always send this message at the end of the stream, if the underlying iterator + # finishes without producing one, another iterator needs to be reattached. + self._result_complete = False + + # Initial iterator comes from ExecutePlan request. + # Note: This is not retried, because no error would ever be thrown here, and GRPC will only + # throw error on first self._has_next(). + self._iterator: Iterator[pb2.ExecutePlanResponse] = iter( + self._stub.ExecutePlan(self._initial_request, metadata=metadata) + ) + + # Current item from this iterator. 
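+        # None means that nothing has been fetched yet, or that the previously fetched
+        # response has already been handed out by send().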
+ self._current: Optional[pb2.ExecutePlanResponse] = None + + def send(self, value: Any) -> pb2.ExecutePlanResponse: + # will trigger reattach in case the stream completed without result_complete + if not self._has_next(): + raise StopIteration() + + ret = self._current + assert ret is not None + + self._last_returned_response_id = ret.response_id + if ret.HasField("result_complete"): + self._result_complete = True + self._release_execute(None) # release all + else: + self._release_execute(self._last_returned_response_id) + self._current = None + return ret + + def _has_next(self) -> bool: + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.client.core import Retrying + + if self._result_complete: + # After response complete response + return False + else: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + # on first try, we use the existing iterator. + if not attempt.is_first_try(): + # on retry, the iterator is borked, so we need a new one + self._iterator = iter( + self._stub.ReattachExecute(self._create_reattach_execute_request()) + ) + + if self._current is None: + try: + self._current = next(self._iterator) + except StopIteration: + pass + + has_next = self._current is not None + + # Graceful reattach: + # If iterator ended, but there was no ResponseComplete, it means that + # there is more, and we need to reattach. While ResponseComplete didn't + # arrive, we keep reattaching. + if not self._result_complete and not has_next: + while not has_next: + self._iterator = iter( + self._stub.ReattachExecute(self._create_reattach_execute_request()) + ) + # shouldn't change + assert not self._result_complete + try: + self._current = next(self._iterator) + except StopIteration: + pass + has_next = self._current is not None + return has_next + return False + + def _release_execute(self, until_response_id: Optional[str]) -> None: + """ + Inform the server to release the execution. + + This will send an asynchronous RPC which will not block this iterator, the iterator can + continue to be consumed. + + Release with untilResponseId informs the server that the iterator has been consumed until + and including response with that responseId, and these responses can be freed. + + Release with None means that the responses have been completely consumed and informs the + server that the completed execution can be completely freed. 
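+
+        The call is best effort: if the ReleaseExecute RPC fails, the failure is reported as
+        a warning and the main iterator continues unaffected.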
+ """ + from pyspark.sql.connect.client.core import SparkConnectClient + from pyspark.sql.connect.client.core import Retrying + + request = self._create_release_execute_request(until_response_id) + + def target() -> None: + try: + for attempt in Retrying( + can_retry=SparkConnectClient.retry_exception, **self._retry_policy + ): + with attempt: + self._stub.ReleaseExecute(request) + except Exception as e: + warnings.warn(f"ReleaseExecute failed with exception: {e}.") + + ExecutePlanResponseReattachableIterator._release_thread_pool.apply_async(target) + + def _create_reattach_execute_request(self) -> pb2.ReattachExecuteRequest: + reattach = pb2.ReattachExecuteRequest( + session_id=self._initial_request.session_id, + user_context=self._initial_request.user_context, + operation_id=self._initial_request.operation_id, + ) + + if self._initial_request.client_type: + reattach.client_type = self._initial_request.client_type + + if self._last_returned_response_id: + reattach.last_response_id = self._last_returned_response_id + + return reattach + + def _create_release_execute_request( + self, until_response_id: Optional[str] + ) -> pb2.ReleaseExecuteRequest: + release = pb2.ReleaseExecuteRequest( + session_id=self._initial_request.session_id, + user_context=self._initial_request.user_context, + operation_id=self._initial_request.operation_id, + ) + + if self._initial_request.client_type: + release.client_type = self._initial_request.client_type + + if not until_response_id: + release.release_all.CopyFrom(pb2.ReleaseExecuteRequest.ReleaseAll()) + else: + release.release_until.response_id = until_response_id + + return release + + def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any: + super().throw(type, value, traceback) + + def close(self) -> None: + return super().close() + + def __del__(self) -> None: + return self.close() diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 8cd39ba7a7918..9bba0db05e43f 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -233,7 +233,7 @@ def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] the $USER environment. Defining the user ID as part of the connection string takes precedence. """ - self._client = SparkConnectClient(connection=connection, userId=userId) + self._client = SparkConnectClient(connection=connection, user_id=userId) self._session_id = self._client._session_id def table(self, tableName: str) -> DataFrame: diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 834b0307238ae..ede6318782e0a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -1788,6 +1788,8 @@ def client(self) -> "SparkConnectClient": Notes ----- + This API is unstable, and a developer API. It returns non-API instance + :class:`SparkConnectClient`. This is an API dedicated to Spark Connect client only. With regular Spark Session, it throws an exception. 
""" diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 5c39d4502f540..9276b88e153b8 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -30,7 +30,7 @@ @unittest.skipIf(not should_test_connect, connect_requirement_message) class SparkConnectClientTestCase(unittest.TestCase): def test_user_agent_passthrough(self): - client = SparkConnectClient("sc://foo/;user_agent=bar") + client = SparkConnectClient("sc://foo/;user_agent=bar", use_reattachable_execute=False) mock = MockService(client._session_id) client._stub = mock @@ -41,7 +41,7 @@ def test_user_agent_passthrough(self): self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ python/[^ ]+$") def test_user_agent_default(self): - client = SparkConnectClient("sc://foo/") + client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) mock = MockService(client._session_id) client._stub = mock @@ -54,11 +54,11 @@ def test_user_agent_default(self): ) def test_properties(self): - client = SparkConnectClient("sc://foo/;token=bar") + client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) self.assertEqual(client.token, "bar") self.assertEqual(client.host, "foo") - client = SparkConnectClient("sc://foo/") + client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) self.assertIsNone(client.token) def test_channel_builder(self): @@ -67,12 +67,14 @@ class CustomChannelBuilder(ChannelBuilder): def userId(self) -> Optional[str]: return "abc" - client = SparkConnectClient(CustomChannelBuilder("sc://foo/")) + client = SparkConnectClient( + CustomChannelBuilder("sc://foo/"), use_reattachable_execute=False + ) self.assertEqual(client._user_id, "abc") def test_interrupt_all(self): - client = SparkConnectClient("sc://foo/;token=bar") + client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) client._stub = mock @@ -80,7 +82,7 @@ def test_interrupt_all(self): self.assertIsNotNone(mock.req, "Interrupt API was not called when expected") def test_is_closed(self): - client = SparkConnectClient("sc://foo/;token=bar") + client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) self.assertFalse(client.is_closed) client.close() diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index b6145d0a00618..ba81c7836728e 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -170,6 +170,10 @@ def conf(cls): # Disable JVM stack trace in Spark Connect tests to prevent the # HTTP header size from exceeding the maximum allowed size. conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "false") + # Make the server terminate reattachable streams every 1 second and 123 bytes, + # to make the tests exercise reattach. + conf.set("spark.connect.execute.reattachable.senderMaxStreamDuration", "1s") + conf.set("spark.connect.execute.reattachable.senderMaxStreamSize", "123") return conf @classmethod