From 9368a0f0c1001fb6fd64799a2e744874b6cd27e4 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 8 Aug 2023 11:03:05 +0900 Subject: [PATCH 1/9] [SPARK-44694][PYTHON][CONNECT] Refactor active sessions and expose them as an API ### What changes were proposed in this pull request? This PR proposes to (mostly) refactor all the internal workarounds to get the active session correctly. There are few things to note: - _PySpark without Spark Connect does not already support the hierarchy of active sessions_. With pinned thread mode (enabled by default), PySpark does map each Python thread to JVM thread, but the thread creation happens within gateway server, that does not respect the thread hierarchy. Therefore, this PR follows the exactly same behaviour. - New thread will not have an active thread by default. - Other behaviours are same as PySpark without Connect, see also https://github.com/apache/spark/pull/42367 - Since I am here, I piggiyback few documentation changes. We missed document `SparkSession.readStream`, `SparkSession.streams`, `SparkSession.udtf`, `SparkSession.conf` and `SparkSession.version` in Spark Connect. - The changes here are mostly refactoring that reuses existing unittests while I expose two methods: - `SparkSession.getActiveSession` (only for Spark Connect) - `SparkSession.active` (for both in PySpark) ### Why are the changes needed? For Spark Connect users to be able to play with active and default sessions in Python. ### Does this PR introduce _any_ user-facing change? Yes, it adds new API: - `SparkSession.getActiveSession` (only for Spark Connect) - `SparkSession.active` (for both in PySpark) ### How was this patch tested? Existing unittests should cover all. Closes #42371 from HyukjinKwon/SPARK-44694. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../reference/pyspark.sql/spark_session.rst | 1 + python/pyspark/errors/error_classes.py | 5 + python/pyspark/ml/connect/io_utils.py | 8 +- python/pyspark/ml/connect/tuning.py | 11 +- python/pyspark/ml/torch/distributor.py | 3 +- python/pyspark/ml/util.py | 13 --- python/pyspark/pandas/utils.py | 7 +- python/pyspark/sql/connect/session.py | 107 ++++++++++++------ python/pyspark/sql/connect/udf.py | 25 ++-- python/pyspark/sql/connect/udtf.py | 27 +++-- python/pyspark/sql/session.py | 65 +++++++++-- .../sql/tests/connect/test_connect_basic.py | 4 +- python/pyspark/sql/utils.py | 18 +++ 13 files changed, 197 insertions(+), 97 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/spark_session.rst b/python/docs/source/reference/pyspark.sql/spark_session.rst index c16ca4f162f5c..f25dbab5f6b9b 100644 --- a/python/docs/source/reference/pyspark.sql/spark_session.rst +++ b/python/docs/source/reference/pyspark.sql/spark_session.rst @@ -28,6 +28,7 @@ See also :class:`SparkSession`. .. autosummary:: :toctree: api/ + SparkSession.active SparkSession.builder.appName SparkSession.builder.config SparkSession.builder.enableHiveSupport diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index a534bc6deb41e..24885e94d3255 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -617,6 +617,11 @@ "Argument `` should be a WindowSpec, got ." ] }, + "NO_ACTIVE_OR_DEFAULT_SESSION" : { + "message" : [ + "No active or default Spark session found. Please create a new Spark session before running the code." + ] + }, "NO_ACTIVE_SESSION" : { "message" : [ "No active Spark session found. 
Please create a new Spark session before running the code." diff --git a/python/pyspark/ml/connect/io_utils.py b/python/pyspark/ml/connect/io_utils.py index 9a963086aaf45..a09a244862c58 100644 --- a/python/pyspark/ml/connect/io_utils.py +++ b/python/pyspark/ml/connect/io_utils.py @@ -23,7 +23,7 @@ from urllib.parse import urlparse from typing import Any, Dict, List from pyspark.ml.base import Params -from pyspark.ml.util import _get_active_session +from pyspark.sql import SparkSession from pyspark.sql.utils import is_remote @@ -34,7 +34,7 @@ def _copy_file_from_local_to_fs(local_path: str, dest_path: str) -> None: - session = _get_active_session(is_remote()) + session = SparkSession.active() if is_remote(): session.copyFromLocalToFs(local_path, dest_path) else: @@ -228,7 +228,7 @@ def save(self, path: str, *, overwrite: bool = False) -> None: .. versionadded:: 3.5.0 """ - session = _get_active_session(is_remote()) + session = SparkSession.active() path_exist = True try: session.read.format("binaryFile").load(path).head() @@ -256,7 +256,7 @@ def load(cls, path: str) -> "Params": .. versionadded:: 3.5.0 """ - session = _get_active_session(is_remote()) + session = SparkSession.active() tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_") try: diff --git a/python/pyspark/ml/connect/tuning.py b/python/pyspark/ml/connect/tuning.py index 6d539933e1d69..c22c31e84e8de 100644 --- a/python/pyspark/ml/connect/tuning.py +++ b/python/pyspark/ml/connect/tuning.py @@ -178,11 +178,12 @@ def _parallelFitTasks( def get_single_task(index: int, param_map: Any) -> Callable[[], Tuple[int, float]]: def single_task() -> Tuple[int, float]: - # Active session is thread-local variable, in background thread the active session - # is not set, the following line sets it as the main thread active session. - active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr] - active_session._jsparkSession # type: ignore[union-attr] - ) + if not is_remote(): + # Active session is thread-local variable, in background thread the active session + # is not set, the following line sets it as the main thread active session. 
+ active_session._jvm.SparkSession.setActiveSession( # type: ignore[union-attr] + active_session._jsparkSession # type: ignore[union-attr] + ) model = estimator.fit(train, param_map) metric = evaluator.evaluate( diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py index 2056803d61cf4..a4e79b1dcc10b 100644 --- a/python/pyspark/ml/torch/distributor.py +++ b/python/pyspark/ml/torch/distributor.py @@ -49,7 +49,6 @@ LogStreamingServer, ) from pyspark.ml.dl_util import FunctionPickler -from pyspark.ml.util import _get_active_session def _get_resources(session: SparkSession) -> Dict[str, ResourceInformation]: @@ -165,7 +164,7 @@ def __init__( from pyspark.sql.utils import is_remote self.is_remote = is_remote() - self.spark = _get_active_session(self.is_remote) + self.spark = SparkSession.active() # indicate whether the server side is local mode self.is_spark_local_master = False diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 2c90ff3cb7b69..64676947017d0 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -747,16 +747,3 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return f(*args, **kwargs) return cast(FuncT, wrapped) - - -def _get_active_session(is_remote: bool) -> SparkSession: - if not is_remote: - spark = SparkSession.getActiveSession() - else: - import pyspark.sql.connect.session - - spark = pyspark.sql.connect.session._active_spark_session # type: ignore[assignment] - - if spark is None: - raise RuntimeError("An active SparkSession is required for the distributor.") - return spark diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py index c66b3359e77d1..55b9a57ef6187 100644 --- a/python/pyspark/pandas/utils.py +++ b/python/pyspark/pandas/utils.py @@ -478,12 +478,7 @@ def is_testing() -> bool: def default_session() -> SparkSession: - if not is_remote(): - spark = SparkSession.getActiveSession() - else: - from pyspark.sql.connect.session import _active_spark_session - - spark = _active_spark_session # type: ignore[assignment] + spark = SparkSession.getActiveSession() if spark is None: spark = SparkSession.builder.appName("pandas-on-Spark").getOrCreate() diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 9bba0db05e43f..d75a30c561f93 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -18,6 +18,7 @@ check_dependencies(__name__) +import threading import os import warnings from collections.abc import Sized @@ -36,6 +37,7 @@ overload, Iterable, TYPE_CHECKING, + ClassVar, ) import numpy as np @@ -93,14 +95,13 @@ from pyspark.sql.connect.udtf import UDTFRegistration -# `_active_spark_session` stores the active spark connect session created by -# `SparkSession.builder.getOrCreate`. It is used by ML code. 
-# If sessions are created with `SparkSession.builder.create`, it stores -# The last created session -_active_spark_session = None - - class SparkSession: + # The active SparkSession for the current thread + _active_session: ClassVar[threading.local] = threading.local() + # Reference to the root SparkSession + _default_session: ClassVar[Optional["SparkSession"]] = None + _lock: ClassVar[RLock] = RLock() + class Builder: """Builder for :class:`SparkSession`.""" @@ -176,8 +177,6 @@ def enableHiveSupport(self) -> "SparkSession.Builder": ) def create(self) -> "SparkSession": - global _active_spark_session - has_channel_builder = self._channel_builder is not None has_spark_remote = "spark.remote" in self._options @@ -200,23 +199,26 @@ def create(self) -> "SparkSession": assert spark_remote is not None session = SparkSession(connection=spark_remote) - _active_spark_session = session + SparkSession._set_default_and_active_session(session) return session def getOrCreate(self) -> "SparkSession": - global _active_spark_session - if _active_spark_session is not None: - return _active_spark_session - _active_spark_session = self.create() - return _active_spark_session + with SparkSession._lock: + session = SparkSession.getActiveSession() + if session is None: + session = SparkSession._default_session + if session is None: + session = self.create() + return session _client: SparkConnectClient @classproperty def builder(cls) -> Builder: - """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" return cls.Builder() + builder.__doc__ = PySparkSession.builder.__doc__ + def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] = None): """ Creates a new SparkSession for the Spark Connect interface. @@ -236,6 +238,38 @@ def __init__(self, connection: Union[str, ChannelBuilder], userId: Optional[str] self._client = SparkConnectClient(connection=connection, user_id=userId) self._session_id = self._client._session_id + @classmethod + def _set_default_and_active_session(cls, session: "SparkSession") -> None: + """ + Set the (global) default :class:`SparkSession`, and (thread-local) + active :class:`SparkSession` when they are not set yet. + """ + with cls._lock: + if cls._default_session is None: + cls._default_session = session + if getattr(cls._active_session, "session", None) is None: + cls._active_session.session = session + + @classmethod + def getActiveSession(cls) -> Optional["SparkSession"]: + return getattr(cls._active_session, "session", None) + + getActiveSession.__doc__ = PySparkSession.getActiveSession.__doc__ + + @classmethod + def active(cls) -> "SparkSession": + session = cls.getActiveSession() + if session is None: + session = cls._default_session + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_OR_DEFAULT_SESSION", + message_parameters={}, + ) + return session + + active.__doc__ = PySparkSession.active.__doc__ + def table(self, tableName: str) -> DataFrame: return self.read.table(tableName) @@ -251,6 +285,8 @@ def read(self) -> "DataFrameReader": def readStream(self) -> "DataStreamReader": return DataStreamReader(self) + readStream.__doc__ = PySparkSession.readStream.__doc__ + def _inferSchemaFromList( self, data: Iterable[Any], names: Optional[List[str]] = None ) -> StructType: @@ -601,19 +637,20 @@ def stop(self) -> None: # specifically in Spark Connect the Spark Connect server is designed for # multi-tenancy - the remote client side cannot just stop the server and stop # other remote clients being used from other users. 
- global _active_spark_session - self.client.close() - _active_spark_session = None - - if "SPARK_LOCAL_REMOTE" in os.environ: - # When local mode is in use, follow the regular Spark session's - # behavior by terminating the Spark Connect server, - # meaning that you can stop local mode, and restart the Spark Connect - # client with a different remote address. - active_session = PySparkSession.getActiveSession() - if active_session is not None: - active_session.stop() - with SparkContext._lock: + with SparkSession._lock: + self.client.close() + if self is SparkSession._default_session: + SparkSession._default_session = None + if self is getattr(SparkSession._active_session, "session", None): + SparkSession._active_session.session = None + + if "SPARK_LOCAL_REMOTE" in os.environ: + # When local mode is in use, follow the regular Spark session's + # behavior by terminating the Spark Connect server, + # meaning that you can stop local mode, and restart the Spark Connect + # client with a different remote address. + if PySparkSession._activeSession is not None: + PySparkSession._activeSession.stop() del os.environ["SPARK_LOCAL_REMOTE"] del os.environ["SPARK_CONNECT_MODE_ENABLED"] if "SPARK_REMOTE" in os.environ: @@ -628,20 +665,18 @@ def is_stopped(self) -> bool: """ return self.client.is_closed - @classmethod - def getActiveSession(cls) -> Any: - raise PySparkNotImplementedError( - error_class="NOT_IMPLEMENTED", message_parameters={"feature": "getActiveSession()"} - ) - @property def conf(self) -> RuntimeConf: return RuntimeConf(self.client) + conf.__doc__ = PySparkSession.conf.__doc__ + @property def streams(self) -> "StreamingQueryManager": return StreamingQueryManager(self) + streams.__doc__ = PySparkSession.streams.__doc__ + def __getattr__(self, name: str) -> Any: if name in ["_jsc", "_jconf", "_jvm", "_jsparkSession"]: raise PySparkAttributeError( @@ -675,6 +710,8 @@ def version(self) -> str: assert result is not None return result + version.__doc__ = PySparkSession.version.__doc__ + @property def client(self) -> "SparkConnectClient": return self._client diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 2d7e423d3d571..eb0541b936925 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -37,8 +37,7 @@ from pyspark.sql.connect.types import UnparsedDataType from pyspark.sql.types import DataType, StringType from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration -from pyspark.errors import PySparkTypeError - +from pyspark.errors import PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql.connect._typing import ( @@ -58,14 +57,20 @@ def _create_py_udf( from pyspark.sql.udf import _create_arrow_py_udf if useArrow is None: - from pyspark.sql.connect.session import _active_spark_session - - is_arrow_enabled = ( - False - if _active_spark_session is None - else _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") - == "true" - ) + is_arrow_enabled = False + try: + from pyspark.sql.connect.session import SparkSession + + session = SparkSession.active() + is_arrow_enabled = ( + str(session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled")).lower() + == "true" + ) + except PySparkRuntimeError as e: + if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION": + pass # Just uses the default if no session found. 
+ else: + raise e else: is_arrow_enabled = useArrow diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 5a95075a65537..c8495626292c5 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -68,13 +68,20 @@ def _create_py_udtf( if useArrow is not None: arrow_enabled = useArrow else: - from pyspark.sql.connect.session import _active_spark_session + from pyspark.sql.connect.session import SparkSession arrow_enabled = False - if _active_spark_session is not None: - value = _active_spark_session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled") - if isinstance(value, str) and value.lower() == "true": - arrow_enabled = True + try: + session = SparkSession.active() + arrow_enabled = ( + str(session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")).lower() + == "true" + ) + except PySparkRuntimeError as e: + if e.error_class == "NO_ACTIVE_OR_DEFAULT_SESSION": + pass # Just uses the default if no session found. + else: + raise e # Create a regular Python UDTF and check for invalid handler class. regular_udtf = _create_udtf(cls, returnType, name, PythonEvalType.SQL_TABLE_UDF, deterministic) @@ -160,17 +167,13 @@ def _build_common_inline_user_defined_table_function( ) def __call__(self, *cols: "ColumnOrName") -> "DataFrame": + from pyspark.sql.connect.session import SparkSession from pyspark.sql.connect.dataframe import DataFrame - from pyspark.sql.connect.session import _active_spark_session - if _active_spark_session is None: - raise PySparkRuntimeError( - "An active SparkSession is required for " - "executing a Python user-defined table function." - ) + session = SparkSession.active() plan = self._build_common_inline_user_defined_table_function(*cols) - return DataFrame.withPlan(plan, _active_spark_session) + return DataFrame.withPlan(plan, session) def asNondeterministic(self) -> "UserDefinedTableFunction": self.deterministic = False diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index ede6318782e0a..9141051fdf830 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -64,8 +64,8 @@ _from_numpy_type, ) from pyspark.errors.exceptions.captured import install_exception_handler -from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str -from pyspark.errors import PySparkValueError, PySparkTypeError +from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod +from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql._typing import AtomicValue, RowLike, OptionalPrimitiveType @@ -500,7 +500,7 @@ def getOrCreate(self) -> "SparkSession": ).applyModifiableSettings(session._jsparkSession, self._options) return session - # SparkConnect-specific API + # Spark Connect-specific API def create(self) -> "SparkSession": """Creates a new SparkSession. Can only be used in the context of Spark Connect and will throw an exception otherwise. @@ -510,6 +510,10 @@ def create(self) -> "SparkSession": Returns ------- :class:`SparkSession` + + Notes + ----- + This method will update the default and/or active session if they are not set. 
""" opts = dict(self._options) if "SPARK_REMOTE" in os.environ or "spark.remote" in opts: @@ -546,7 +550,11 @@ def create(self) -> "SparkSession": # to Python 3.9.6 (https://github.com/python/cpython/pull/28838) @classproperty def builder(cls) -> Builder: - """Creates a :class:`Builder` for constructing a :class:`SparkSession`.""" + """Creates a :class:`Builder` for constructing a :class:`SparkSession`. + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + """ return cls.Builder() _instantiatedSession: ClassVar[Optional["SparkSession"]] = None @@ -632,12 +640,16 @@ def newSession(self) -> "SparkSession": return self.__class__(self._sc, self._jsparkSession.newSession()) @classmethod + @try_remote_session_classmethod def getActiveSession(cls) -> Optional["SparkSession"]: """ Returns the active :class:`SparkSession` for the current thread, returned by the builder .. versionadded:: 3.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`SparkSession` @@ -667,6 +679,30 @@ def getActiveSession(cls) -> Optional["SparkSession"]: else: return None + @classmethod + @try_remote_session_classmethod + def active(cls) -> "SparkSession": + """ + Returns the active or default :class:`SparkSession` for the current thread, returned by + the builder. + + .. versionadded:: 3.5.0 + + Returns + ------- + :class:`SparkSession` + Spark session if an active or default session exists for the current thread. + """ + session = cls.getActiveSession() + if session is None: + session = cls._instantiatedSession + if session is None: + raise PySparkRuntimeError( + error_class="NO_ACTIVE_OR_DEFAULT_SESSION", + message_parameters={}, + ) + return session + @property def sparkContext(self) -> SparkContext: """ @@ -698,6 +734,9 @@ def version(self) -> str: .. versionadded:: 2.0.0 + .. versionchanged:: 3.4.0 + Supports Spark Connect. + Returns ------- str @@ -719,6 +758,9 @@ def conf(self) -> RuntimeConfig: .. versionadded:: 2.0.0 + .. versionchanged:: 3.4.0 + Supports Spark Connect. + Returns ------- :class:`pyspark.sql.conf.RuntimeConfig` @@ -726,7 +768,7 @@ def conf(self) -> RuntimeConfig: Examples -------- >>> spark.conf - + Set a runtime configuration for the session @@ -805,6 +847,9 @@ def udtf(self) -> "UDTFRegistration": .. versionadded:: 3.5.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Returns ------- :class:`UDTFRegistration` @@ -1639,6 +1684,9 @@ def readStream(self) -> DataStreamReader: .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. @@ -1650,7 +1698,7 @@ def readStream(self) -> DataStreamReader: Examples -------- >>> spark.readStream - + The example below uses Rate source that generates rows continuously. After that, we operate a modulo by 3, and then write the stream out to the console. @@ -1672,6 +1720,9 @@ def streams(self) -> "StreamingQueryManager": .. versionadded:: 2.0.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is evolving. 
@@ -1683,7 +1734,7 @@ def streams(self) -> "StreamingQueryManager": Examples -------- >>> spark.streams - + Get the list of active streaming queries diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 065f1585a9f06..0687fc9f31331 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -3043,9 +3043,6 @@ def test_unsupported_functions(self): def test_unsupported_session_functions(self): # SPARK-41934: Disable unsupported functions. - with self.assertRaises(NotImplementedError): - RemoteSparkSession.getActiveSession() - with self.assertRaises(NotImplementedError): RemoteSparkSession.builder.enableHiveSupport() @@ -3331,6 +3328,7 @@ def test_error_stack_trace(self): spark.stop() def test_can_create_multiple_sessions_to_different_remotes(self): + self.spark.stop() self.assertIsNotNone(self.spark._client) # Creates a new remote session. other = PySparkSession.builder.remote("sc://other.remote:114/").create() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8b520ed653f8c..d4f56fe822f3e 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import inspect import functools import os from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type @@ -258,6 +259,23 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return cast(FuncT, wrapped) +def try_remote_session_classmethod(f: FuncT) -> FuncT: + """Mark API supported from Spark Connect.""" + + @functools.wraps(f) + def wrapped(*args: Any, **kwargs: Any) -> Any: + + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + from pyspark.sql.connect.session import SparkSession # type: ignore[misc] + + assert inspect.isclass(args[0]) + return getattr(SparkSession, f.__name__)(*args[1:], **kwargs) + else: + return f(*args, **kwargs) + + return cast(FuncT, wrapped) + + def pyspark_column_op( func_name: str, left: "IndexOpsLike", right: Any, fillna: Any = None ) -> Union["SeriesOrIndex", None]: From 630b1777904f15c7ac05c3cd61c0006cd692bc93 Mon Sep 17 00:00:00 2001 From: Siying Dong Date: Tue, 8 Aug 2023 11:11:56 +0900 Subject: [PATCH 2/9] [SPARK-44683][SS] Logging level isn't passed to RocksDB state store provider correctly ### What changes were proposed in this pull request? The logging level is passed into RocksDB in a correct way. ### Why are the changes needed? We pass log4j's log level to RocksDB so that RocksDB debug log can go to log4j. However, we pass in log level after we create the logger. However, the way it is set isn't effective. This has two impacts: (1) setting DEBUG level don't make RocksDB generate DEBUG level logs; (2) setting WARN or ERROR level does prevent INFO level logging, but RocksDB still makes JNI calls to Scala, which is an unnecessary overhead. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually change the log level and observe the log lines in unit tests. Closes #42354 from siying/rocks_log_level. 
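The essence of the fix, as a minimal standalone sketch rather than the actual Spark code (it assumes the stock `org.rocksdb` Java API; `attachLogger` and the `println` forwarding are illustrative only): the level has to be set on the custom `Logger` instance itself, and the logger is attached to the options only after the level is configured.

```scala
import org.rocksdb.{InfoLogLevel, Logger, Options}

// Minimal sketch, not the Spark implementation: once a customized logger is
// installed, the level configured on Options alone is not effective, so the
// level is set on the Logger itself before attaching it.
def attachLogger(dbOptions: Options, dbLogLevel: InfoLogLevel): Logger = {
  val dbLogger = new Logger(dbOptions) {
    override def log(level: InfoLogLevel, msg: String): Unit = {
      // Forward RocksDB's native log lines to the application logger (log4j in Spark).
      println(s"[RocksDB $level] $msg")
    }
  }
  dbLogger.setInfoLogLevel(dbLogLevel)  // effective level for the custom logger
  dbOptions.setInfoLogLevel(dbLogLevel) // kept only so it shows up in RocksDB's config dump
  dbOptions.setLogger(dbLogger)         // attach after the level is configured
  dbLogger
}
```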
Authored-by: Siying Dong Signed-off-by: Jungtaek Lim --- .../apache/spark/sql/execution/streaming/state/RocksDB.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index d4366fe732be4..a2868df941178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -611,8 +611,11 @@ class RocksDB( if (log.isWarnEnabled) dbLogLevel = InfoLogLevel.WARN_LEVEL if (log.isInfoEnabled) dbLogLevel = InfoLogLevel.INFO_LEVEL if (log.isDebugEnabled) dbLogLevel = InfoLogLevel.DEBUG_LEVEL - dbOptions.setLogger(dbLogger) + dbLogger.setInfoLogLevel(dbLogLevel) + // The log level set in dbLogger is effective and the one to dbOptions isn't applied to + // customized logger. We still set it as it might show up in RocksDB config file or logging. dbOptions.setInfoLogLevel(dbLogLevel) + dbOptions.setLogger(dbLogger) logInfo(s"Set RocksDB native logging level to $dbLogLevel") dbLogger } From 7493c5764f9644878babacccd4f688fe13ef84aa Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 8 Aug 2023 04:15:07 +0200 Subject: [PATCH 3/9] [SPARK-43429][CONNECT] Add Default & Active SparkSession for Scala Client ### What changes were proposed in this pull request? This adds the `default` and `active` session variables to `SparkSession`: - `default` session is global value. It is typically the first session created through `getOrCreate`. It can be changed through `set` or `clear`. If the session is closed and it is the `default` session we clear the `default` session. - `active` session is a thread local value. It is typically the first session created in this thread or it inherits is value from its parent thread. It can be changed through `set` or `clear`, please note that these methods operate thread locally, so they won't change the parent or children. If the session is closed and it is the `active` session for the current thread then we clear the active value (only for the current thread!). ### Why are the changes needed? To increase compatibility with the existing SparkSession API in `sql/core`. ### Does this PR introduce _any_ user-facing change? Yes. It adds a couple methods that were missing from the Scala Client. ### How was this patch tested? Added tests to `SparkSessionSuite`. Closes #42367 from hvanhovell/SPARK-43429. 
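A usage sketch of the API added here (connection strings are placeholders; the assertions follow the semantics described above and exercised in the new `SparkSessionSuite` tests):

```scala
import org.apache.spark.sql.SparkSession

// Placeholder remotes, for illustration only.
val s1 = SparkSession.builder().remote("sc://host-a:15002").getOrCreate()
assert(SparkSession.getDefaultSession.contains(s1)) // first session becomes the default
assert(SparkSession.getActiveSession.contains(s1))  // ...and this thread's active session
assert(SparkSession.active == s1)

val s2 = SparkSession.builder().remote("sc://host-b:15002").create()
SparkSession.setActiveSession(s2)                   // thread-local; parent/children unaffected
assert(SparkSession.getDefaultSession.contains(s1)) // default is unchanged
assert(SparkSession.getActiveSession.contains(s2))

s1.close()                                          // closing the default session clears it
assert(SparkSession.getDefaultSession.isEmpty)
assert(SparkSession.getActiveSession.contains(s2))  // the active session (s2) is untouched

SparkSession.clearActiveSession()                   // clears only the current thread's value
assert(SparkSession.getActiveSession.isEmpty)
s2.close()
```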
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/SparkSession.scala | 100 ++++++++++-- .../apache/spark/sql/SparkSessionSuite.scala | 144 ++++++++++++++++-- .../CheckConnectJvmClientCompatibility.scala | 2 - 3 files changed, 225 insertions(+), 21 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 355d7edadc788..7367ed153f7db 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.io.Closeable import java.net.URI import java.util.concurrent.TimeUnit._ -import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.collection.JavaConverters._ import scala.reflect.runtime.universe.TypeTag @@ -730,6 +730,23 @@ object SparkSession extends Logging { override def load(c: Configuration): SparkSession = create(c) }) + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[SparkSession] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[SparkSession] + + /** + * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when + * they are not set yet. + */ + private def setDefaultAndActiveSession(session: SparkSession): Unit = { + defaultSession.compareAndSet(null, session) + if (getActiveSession.isEmpty) { + setActiveSession(session) + } + } + /** * Create a new [[SparkSession]] based on the connect client [[Configuration]]. */ @@ -742,8 +759,17 @@ object SparkSession extends Logging { */ private[sql] def onSessionClose(session: SparkSession): Unit = { sessions.invalidate(session.client.configuration) + defaultSession.compareAndSet(session, null) + if (getActiveSession.contains(session)) { + clearActiveSession() + } } + /** + * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. + * + * @since 3.4.0 + */ def builder(): Builder = new Builder() private[sql] lazy val cleaner = { @@ -799,10 +825,15 @@ object SparkSession extends Logging { * * This will always return a newly created session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def create(): SparkSession = { - tryCreateSessionFromClient().getOrElse(SparkSession.this.create(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(SparkSession.this.create(builder.configuration)) + setDefaultAndActiveSession(session) + session } /** @@ -811,30 +842,79 @@ object SparkSession extends Logging { * If a session exist with the same configuration that is returned instead of creating a new * session. * + * This method will update the default and/or active session if they are not set. + * * @since 3.5.0 */ def getOrCreate(): SparkSession = { - tryCreateSessionFromClient().getOrElse(sessions.get(builder.configuration)) + val session = tryCreateSessionFromClient() + .getOrElse(sessions.get(builder.configuration)) + setDefaultAndActiveSession(session) + session } } - def getActiveSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getActiveSession is not supported") + /** + * Returns the default SparkSession. 
+ * + * @since 3.5.0 + */ + def getDefaultSession: Option[SparkSession] = Option(defaultSession.get()) + + /** + * Sets the default SparkSession. + * + * @since 3.5.0 + */ + def setDefaultSession(session: SparkSession): Unit = { + defaultSession.set(session) } - def getDefaultSession: Option[SparkSession] = { - throw new UnsupportedOperationException("getDefaultSession is not supported") + /** + * Clears the default SparkSession. + * + * @since 3.5.0 + */ + def clearDefaultSession(): Unit = { + defaultSession.set(null) } + /** + * Returns the active SparkSession for the current thread. + * + * @since 3.5.0 + */ + def getActiveSession: Option[SparkSession] = Option(activeThreadSession.get()) + + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * an isolated SparkSession. + * + * @since 3.5.0 + */ def setActiveSession(session: SparkSession): Unit = { - throw new UnsupportedOperationException("setActiveSession is not supported") + activeThreadSession.set(session) } + /** + * Clears the active SparkSession for current thread. + * + * @since 3.5.0 + */ def clearActiveSession(): Unit = { - throw new UnsupportedOperationException("clearActiveSession is not supported") + activeThreadSession.remove() } + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 3.5.0 + */ def active: SparkSession = { - throw new UnsupportedOperationException("active is not supported") + getActiveSession + .orElse(getDefaultSession) + .getOrElse(throw new IllegalStateException("No active or default Spark session found")) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 97fb46bf48af4..f06744399f833 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -16,6 +16,10 @@ */ package org.apache.spark.sql +import java.util.concurrent.{Executors, Phaser} + +import scala.util.control.NonFatal + import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} import org.apache.spark.sql.connect.client.util.ConnectFunSuite @@ -24,6 +28,10 @@ import org.apache.spark.sql.connect.client.util.ConnectFunSuite * Tests for non-dataframe related SparkSession operations. 
*/ class SparkSessionSuite extends ConnectFunSuite { + private val connectionString1: String = "sc://test.it:17845" + private val connectionString2: String = "sc://test.me:14099" + private val connectionString3: String = "sc://doit:16845" + test("default") { val session = SparkSession.builder().getOrCreate() assert(session.client.configuration.host == "localhost") @@ -32,16 +40,15 @@ class SparkSessionSuite extends ConnectFunSuite { } test("remote") { - val session = SparkSession.builder().remote("sc://test.me:14099").getOrCreate() + val session = SparkSession.builder().remote(connectionString2).getOrCreate() assert(session.client.configuration.host == "test.me") assert(session.client.configuration.port == 14099) session.close() } test("getOrCreate") { - val connectionString = "sc://test.it:17865" - val session1 = SparkSession.builder().remote(connectionString).getOrCreate() - val session2 = SparkSession.builder().remote(connectionString).getOrCreate() + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + val session2 = SparkSession.builder().remote(connectionString1).getOrCreate() try { assert(session1 eq session2) } finally { @@ -51,9 +58,8 @@ class SparkSessionSuite extends ConnectFunSuite { } test("create") { - val connectionString = "sc://test.it:17845" - val session1 = SparkSession.builder().remote(connectionString).create() - val session2 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() try { assert(session1 ne session2) assert(session1.client.configuration == session2.client.configuration) @@ -64,8 +70,7 @@ class SparkSessionSuite extends ConnectFunSuite { } test("newSession") { - val connectionString = "sc://doit:16845" - val session1 = SparkSession.builder().remote(connectionString).create() + val session1 = SparkSession.builder().remote(connectionString3).create() val session2 = session1.newSession() try { assert(session1 ne session2) @@ -92,5 +97,126 @@ class SparkSessionSuite extends ConnectFunSuite { assertThrows[RuntimeException] { session.range(10).count() } + session.close() + } + + test("Default/Active session") { + // Make sure we start with a clean slate. + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + intercept[IllegalStateException](SparkSession.active) + + // Create a session + val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + assert(SparkSession.active == session1) + + // Create another session... 
+ val session2 = SparkSession.builder().remote(connectionString2).create() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session1)) + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + // Clear sessions + SparkSession.clearDefaultSession() + assert(SparkSession.getDefaultSession.isEmpty) + SparkSession.clearActiveSession() + assert(SparkSession.getDefaultSession.isEmpty) + + // Flip sessions + SparkSession.setActiveSession(session1) + SparkSession.setDefaultSession(session2) + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.contains(session1)) + + // Close session1 + session1.close() + assert(SparkSession.getDefaultSession.contains(session2)) + assert(SparkSession.getActiveSession.isEmpty) + + // Close session2 + session2.close() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + + test("active session in multiple threads") { + SparkSession.clearDefaultSession() + SparkSession.clearActiveSession() + val session1 = SparkSession.builder().remote(connectionString1).create() + val session2 = SparkSession.builder().remote(connectionString1).create() + SparkSession.setActiveSession(session2) + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + val phaser = new Phaser(2) + val executor = Executors.newFixedThreadPool(2) + def execute(block: Phaser => Unit): java.util.concurrent.Future[Boolean] = { + executor.submit[Boolean] { () => + try { + block(phaser) + true + } catch { + case NonFatal(e) => + phaser.forceTermination() + throw e + } + } + } + + try { + val script1 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + session1.close() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.isEmpty) + } + val script2 = execute { phaser => + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(session2)) + SparkSession.clearActiveSession() + val internalSession = SparkSession.builder().remote(connectionString3).getOrCreate() + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.contains(session1)) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + + phaser.arriveAndAwaitAdvance() + assert(SparkSession.getDefaultSession.isEmpty) + assert(SparkSession.getActiveSession.contains(internalSession)) + internalSession.close() + assert(SparkSession.getActiveSession.isEmpty) + } + assert(script1.get()) + assert(script2.get()) + assert(SparkSession.getActiveSession.contains(session2)) + session2.close() + assert(SparkSession.getActiveSession.isEmpty) + } finally { + executor.shutdown() + 
} } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 6e577e0f21257..2bf9c41fb2cbd 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -207,8 +207,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.RelationalGroupedDataset.apply"), // SparkSession - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.clearDefaultSession"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.setDefaultSession"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sparkContext"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sharedState"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.sessionState"), From aa1261dc129618d27a1bdc743a5fdd54219f7c01 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 7 Aug 2023 19:16:38 -0700 Subject: [PATCH 4/9] [SPARK-44641][SQL] Incorrect result in certain scenarios when SPJ is not triggered ### What changes were proposed in this pull request? This PR makes sure we use unique partition values when calculating the final partitions in `BatchScanExec`, to make sure no duplicated partitions are generated. ### Why are the changes needed? When `spark.sql.sources.v2.bucketing.pushPartValues.enabled` and `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` are enabled, and SPJ is not triggered, currently Spark will generate incorrect/duplicated results. This is because with both configs enabled, Spark will delay the partition grouping until the time it calculates the final partitions used by the input RDD. To calculate the partitions, it uses partition values from the `KeyGroupedPartitioning` to find out the right ordering for the partitions. However, since grouping is not done when the partition values is computed, there could be duplicated partition values. This means the result could contain duplicated partitions too. ### Does this PR introduce _any_ user-facing change? No, this is a bug fix. ### How was this patch tested? Added a new test case for this scenario. Closes #42324 from sunchao/SPARK-44641. Authored-by: Chao Sun Signed-off-by: Chao Sun --- .../plans/physical/partitioning.scala | 9 ++- .../datasources/v2/BatchScanExec.scala | 9 ++- .../KeyGroupedPartitioningSuite.scala | 56 +++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index bd8ba54ddd736..456005768bd42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -313,7 +313,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in * ascending order, after evaluated by the transforms in `expressions`, for each input partition. 
* In addition, its length must be the same as the number of input partitions (and thus is a 1-1 - * mapping), and each row in `partitionValues` must be unique. + * mapping). The `partitionValues` may contain duplicated partition values. * * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValues` is * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows @@ -355,6 +355,13 @@ case class KeyGroupedPartitioning( override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = KeyGroupedShuffleSpec(this, distribution) + + lazy val uniquePartitionValues: Seq[InternalRow] = { + partitionValues + .map(InternalRowComparableWrapper(_, expressions)) + .distinct + .map(_.row) + } } object KeyGroupedPartitioning { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 4b53819739262..eba3c71f871e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -190,10 +190,17 @@ case class BatchScanExec( Seq.fill(numSplits)(Seq.empty)) } } else { + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. val partitionMapping = groupedPartitions.map { case (row, parts) => InternalRowComparableWrapper(row, p.expressions) -> parts }.toMap - finalPartitions = p.partitionValues.map { partValue => + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. 
+ finalPartitions = p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 880c30ba9f98d..8461f528277c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1039,4 +1039,60 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-44641: duplicated records when SPJ is not triggered") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, items_schema, items_partitions) + sql(s""" + INSERT INTO testcat.ns.$items VALUES + (1, 'aa', 40.0, cast('2020-01-01' as timestamp)), + (1, 'aa', 41.0, cast('2020-01-15' as timestamp)), + (2, 'bb', 10.0, cast('2020-01-01' as timestamp)), + (2, 'bb', 10.5, cast('2020-01-01' as timestamp)), + (3, 'cc', 15.5, cast('2020-02-01' as timestamp))""") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"""INSERT INTO testcat.ns.$purchases VALUES + (1, 42.0, cast('2020-01-01' as timestamp)), + (1, 44.0, cast('2020-01-15' as timestamp)), + (1, 45.0, cast('2020-01-15' as timestamp)), + (2, 11.0, cast('2020-01-01' as timestamp)), + (3, 19.5, cast('2020-02-01' as timestamp))""") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClusteredEnabled => + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClusteredEnabled.toString) { + + // join keys are not the same as the partition keys, therefore SPJ is not triggered. + val df = sql( + s""" + SELECT id, name, i.price as purchase_price, p.item_id, p.price as sale_price + FROM testcat.ns.$items i JOIN testcat.ns.$purchases p + ON i.arrive_time = p.time ORDER BY id, purchase_price, p.item_id, sale_price + """) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.nonEmpty, "shuffle should exist when SPJ is not used") + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } } From 6dadd188f3652816c291919a2413f73c13bb1b47 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 8 Aug 2023 11:04:53 +0800 Subject: [PATCH 5/9] [SPARK-44554][INFRA] Make Python linter related checks pass of branch-3.3/3.4 daily testing ### What changes were proposed in this pull request? The daily testing of `branch-3.3/3.4` uses the same yml file as the master now and the upgrade to `MyPy` in https://github.com/apache/spark/pull/41690 resulted in Python linter check failure of `branch-3.3/3.4`, - branch-3.3: https://github.com/apache/spark/actions/runs/5677524469/job/15386025539 - branch-3.4: https://github.com/apache/spark/actions/runs/5678626664/job/15389273919 image So this pr do the following change for workaround: 1. 
Install different Python linter dependencies for `branch-3.3/3.4`, the dependency list comes from the corresponding branch to ensure compatibility with the version 2. Skip `Install dependencies for Python code generation check` and `Python code generation check` for `branch-3.3/3.4` due to they do not use `Buf remote plugins` and `Buf remote generation` is no longer supported. Meanwhile, the protobuf files in the branch generally do not change, so we can skip this check. ### Why are the changes needed? Make Python linter related checks pass of branch-3.3/3.4 daily testing ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Manually checked branch-3.4, the newly added condition should be ok Closes #42167 from LuciferYang/SPARK-44554. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: yangjie01 --- .github/workflows/build_and_test.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index cd68c0904d9a4..b4559dea42bb9 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -657,7 +657,22 @@ jobs: - name: Spark connect jvm client mima check if: inputs.branch != 'branch-3.3' run: ./dev/connect-jvm-client-mima-check + - name: Install Python linter dependencies for branch-3.3 + if: inputs.branch == 'branch-3.3' + run: | + # SPARK-44554: Copy from https://github.com/apache/spark/blob/073d0b60d31bf68ebacdc005f59b928a5902670f/.github/workflows/build_and_test.yml#L501-L508 + # Should delete this section after SPARK 3.3 EOL. + python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==21.12b0' + python3.9 -m pip install 'pandas-stubs==1.2.0.53' + - name: Install Python linter dependencies for branch-3.4 + if: inputs.branch == 'branch-3.4' + run: | + # SPARK-44554: Copy from https://github.com/apache/spark/blob/a05c27e85829fe742c1828507a1fd180cdc84b54/.github/workflows/build_and_test.yml#L571-L578 + # Should delete this section after SPARK 3.4 EOL. + python3.9 -m pip install 'flake8==3.9.0' pydata_sphinx_theme 'mypy==0.920' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' numpydoc 'jinja2<3.0.0' 'black==22.6.0' + python3.9 -m pip install 'pandas-stubs==1.2.0.53' ipython 'grpcio==1.48.1' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' - name: Install Python linter dependencies + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: | # TODO(SPARK-32407): Sphinx 3.1+ does not correctly index nested classes. # See also https://github.com/sphinx-doc/sphinx/issues/7551. 
@@ -668,6 +683,7 @@ jobs: - name: Python linter run: PYTHON_EXECUTABLE=python3.9 ./dev/lint-python - name: Install dependencies for Python code generation check + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: | # See more in "Installation" https://docs.buf.build/installation#tarball curl -LO https://github.com/bufbuild/buf/releases/download/v1.24.0/buf-Linux-x86_64.tar.gz @@ -676,6 +692,7 @@ jobs: rm buf-Linux-x86_64.tar.gz python3.9 -m pip install 'protobuf==3.20.3' 'mypy-protobuf==3.3.0' - name: Python code generation check + if: inputs.branch != 'branch-3.3' && inputs.branch != 'branch-3.4' run: if test -f ./dev/connect-check-protos.py; then PATH=$PATH:$HOME/buf/bin PYTHON_EXECUTABLE=python3.9 ./dev/connect-check-protos.py; fi - name: Install JavaScript linter dependencies run: | From 25053d98186489d9f2061c9b815a5a33f7e309c4 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Tue, 8 Aug 2023 11:06:21 +0800 Subject: [PATCH 6/9] [SPARK-44689][CONNECT] Make the exception handling of function `SparkConnectPlanner#unpackScalarScalaUDF` more universal ### What changes were proposed in this pull request? This PR changes the exception handling in the `unpackScalarScalaUD` function in `SparkConnectPlanner` from determining the exception type based on a fixed nesting level to using Guava `Throwables` to get the root cause and then determining the type of the root cause. This makes it compatible with differences between different Java versions. ### Why are the changes needed? The following failure occurred when testing `UDFClassLoadingE2ESuite` in Java 17 daily test: https://github.com/apache/spark/actions/runs/5766913899/job/15635782831 ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session *** FAILED *** (101 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:57) ... [info] - update class loader after stubbing: same session *** FAILED *** (52 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:73) ... 
``` After analysis, it was found that there are differences in the exception stack generated on the server side between Java 8 and Java 17: - Java 8 ``` java.io.IOException: unexpected exception type at java.io.ObjectStreamClass.throwMiscException(ObjectStreamClass.java:1750) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1280) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2222) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669) at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2431) at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2355) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2213) at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1669) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:503) at java.io.ObjectInputStream.readObject(ObjectInputStream.java:461) at org.apache.spark.util.SparkSerDeUtils.deserialize(SparkSerDeUtils.scala:50) at org.apache.spark.util.SparkSerDeUtils.deserialize$(SparkSerDeUtils.scala:41) at org.apache.spark.util.Utils$.deserialize(Utils.scala:95) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.unpackScalarScalaUDF(SparkConnectPlanner.scala:1516) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.org$apache$spark$sql$connect$planner$SparkConnectPlanner$$unpackUdf(SparkConnectPlanner.scala:1507) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformScalarScalaFunction(SparkConnectPlanner.scala:1544) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterScalarScalaUDF(SparkConnectPlanner.scala:2565) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterUserDefinedFunction(SparkConnectPlanner.scala:2492) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.process(SparkConnectPlanner.scala:2363) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.handleCommand(ExecuteThreadRunner.scala:202) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1(ExecuteThreadRunner.scala:158) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1$adapted(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:184) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:184) at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withContextClassLoader$1(SessionHolder.scala:171) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:179) at org.apache.spark.sql.connect.service.SessionHolder.withContextClassLoader(SessionHolder.scala:170) at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:183) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.executeInternal(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.org$apache$spark$sql$connect$execution$ExecuteThreadRunner$$execute(ExecuteThreadRunner.scala:84) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner$ExecutionThread.run(ExecuteThreadRunner.scala:227) Caused by: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at 
java.lang.Class.getDeclaredMethod(Class.java:2130) at java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:224) at java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:221) at java.security.AccessController.doPrivileged(Native Method) at java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:221) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1274) ... 31 more ``` - Java 17 ``` java.lang.RuntimeException: Exception in SerializedLambda.readResolve at java.base/java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:288) at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77) at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.base/java.lang.reflect.Method.invoke(Method.java:568) at java.base/java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1190) at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2266) at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1733) at java.base/java.io.ObjectInputStream$FieldValues.(ObjectInputStream.java:2606) at java.base/java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:2457) at java.base/java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2257) at java.base/java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1733) at java.base/java.io.ObjectInputStream.readObject(ObjectInputStream.java:509) at java.base/java.io.ObjectInputStream.readObject(ObjectInputStream.java:467) at org.apache.spark.util.SparkSerDeUtils.deserialize(SparkSerDeUtils.scala:50) at org.apache.spark.util.SparkSerDeUtils.deserialize$(SparkSerDeUtils.scala:41) at org.apache.spark.util.Utils$.deserialize(Utils.scala:95) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.unpackScalarScalaUDF(SparkConnectPlanner.scala:1517) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.org$apache$spark$sql$connect$planner$SparkConnectPlanner$$unpackUdf(SparkConnectPlanner.scala:1507) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformScalarScalaFunction(SparkConnectPlanner.scala:1552) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterScalarScalaUDF(SparkConnectPlanner.scala:2573) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.handleRegisterUserDefinedFunction(SparkConnectPlanner.scala:2500) at org.apache.spark.sql.connect.planner.SparkConnectPlanner.process(SparkConnectPlanner.scala:2371) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.handleCommand(ExecuteThreadRunner.scala:202) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1(ExecuteThreadRunner.scala:158) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.$anonfun$executeInternal$1$adapted(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:184) at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900) at 
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:184) at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94) at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withContextClassLoader$1(SessionHolder.scala:171) at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:179) at org.apache.spark.sql.connect.service.SessionHolder.withContextClassLoader(SessionHolder.scala:170) at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:183) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.executeInternal(ExecuteThreadRunner.scala:132) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner.org$apache$spark$sql$connect$execution$ExecuteThreadRunner$$execute(ExecuteThreadRunner.scala:84) at org.apache.spark.sql.connect.execution.ExecuteThreadRunner$ExecutionThread.run(ExecuteThreadRunner.scala:227) Caused by: java.security.PrivilegedActionException: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at java.base/java.security.AccessController.doPrivileged(AccessController.java:573) at java.base/java.lang.invoke.SerializedLambda.readResolve(SerializedLambda.java:269) ... 36 more Caused by: java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf.$deserializeLambda$(java.lang.invoke.SerializedLambda) at java.base/java.lang.Class.getDeclaredMethod(Class.java:2675) at java.base/java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:272) at java.base/java.lang.invoke.SerializedLambda$1.run(SerializedLambda.java:269) at java.base/java.security.AccessController.doPrivileged(AccessController.java:569) ... 37 more ``` While their root exceptions are both `NoSuchMethodException`, the levels of nesting are different. We can add an exception check branch to make it compatible with Java 17, for example: ```scala case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => throw new ClassNotFoundException(... ${e.getCause} ...) case e: RuntimeException if e.getCause != null && e.getCause.getCause.isInstanceOf[NoSuchMethodException] => throw new ClassNotFoundException(... ${e.getCause.getCause} ...) ``` But if future Java versions change the nested levels of exceptions again, this will necessitate another modification of this part of the code. Therefore, this PR has been revised to fetch the root cause of the exception and conduct a type check on the root cause to make it as universal as possible. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass Git Hub Actions - Manually check with Java 17 ``` java -version openjdk version "17.0.8" 2023-07-18 LTS OpenJDK Runtime Environment Zulu17.44+15-CA (build 17.0.8+7-LTS) OpenJDK 64-Bit Server VM Zulu17.44+15-CA (build 17.0.8+7-LTS, mixed mode, sharing) ``` run ``` build/sbt clean "connect-client-jvm/testOnly *UDFClassLoadingE2ESuite" -Phive ``` Before ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session *** FAILED *** (60 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:57) ... 
[info] - update class loader after stubbing: same session *** FAILED *** (15 milliseconds) [info] "Exception in SerializedLambda.readResolve" did not contain "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf" (UDFClassLoadingE2ESuite.scala:73) ... [info] Run completed in 9 seconds, 565 milliseconds. [info] Total number of tests run: 2 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 0, failed 2, canceled 0, ignored 0, pending 0 [info] *** 2 TESTS FAILED *** [error] Failed tests: [error] org.apache.spark.sql.connect.client.UDFClassLoadingE2ESuite [error] (connect-client-jvm / Test / testOnly) sbt.TestsFailedException: Tests unsuccessful ``` After ``` [info] UDFClassLoadingE2ESuite: [info] - update class loader after stubbing: new session (116 milliseconds) [info] - update class loader after stubbing: same session (41 milliseconds) [info] Run completed in 9 seconds, 781 milliseconds. [info] Total number of tests run: 2 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` Closes #42360 from LuciferYang/unpackScalarScalaUDF-exception-java17. Authored-by: yangjie01 Signed-off-by: yangjie01 --- .../connect/planner/SparkConnectPlanner.scala | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 7136476b515f9..f70a17e580a3e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.connect.planner -import java.io.IOException - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Try +import com.google.common.base.Throwables import com.google.common.collect.{Lists, Maps} import com.google.protobuf.{Any => ProtoAny, ByteString} import io.grpc.{Context, Status, StatusRuntimeException} @@ -1518,11 +1517,15 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { logDebug(s"Unpack using class loader: ${Utils.getContextOrSparkClassLoader}") Utils.deserialize[T](fun.getPayload.toByteArray, Utils.getContextOrSparkClassLoader) } catch { - case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => - throw new ClassNotFoundException( - s"Failed to load class correctly due to ${e.getCause}. " + - "Make sure the artifact where the class is defined is installed by calling" + - " session.addArtifact.") + case t: Throwable => + Throwables.getRootCause(t) match { + case nsm: NoSuchMethodException => + throw new ClassNotFoundException( + s"Failed to load class correctly due to $nsm. " + + "Make sure the artifact where the class is defined is installed by calling" + + " session.addArtifact.") + case _ => throw t + } } } From 590b77f76284ad03ad8b3b6d30b23983c66513fc Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Tue, 8 Aug 2023 11:09:58 +0800 Subject: [PATCH 7/9] [SPARK-44005][PYTHON] Improve error messages for regular Python UDTFs that return non-tuple values ### What changes were proposed in this pull request? This PR improves error messages for regular Python UDTFs when the result rows are not one of tuple, list and dict. 
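As a hypothetical end-to-end sketch of the behaviour this PR targets (assuming an active SparkSession and the default, non-Arrow UDTF evaluation path; the class and column names below are illustrative only):

```python
from pyspark.sql.functions import lit, udtf

@udtf(returnType="a: int")
class BareScalarUDTF:
    def eval(self, a: int):
        yield a        # invalid row: a bare int, not a tuple/list/dict

@udtf(returnType="a: int")
class TupleUDTF:
    def eval(self, a: int):
        yield (a,)     # valid row: a one-element tuple

TupleUDTF(lit(1)).show()
# BareScalarUDTF(lit(1)).collect() raises a PythonException that mentions
# UDTF_INVALID_OUTPUT_ROW_TYPE on the regular (non-Arrow) path.
```

(The Arrow-optimized path is more lenient, as noted next.)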
Note this is supported when arrow optimization is enabled. ### Why are the changes needed? To make Python UDTFs more user friendly. ### Does this PR introduce _any_ user-facing change? Yes. ``` class TestUDTF: def eval(self, a: int): yield a ``` Before this PR, this will fail with this error `Unexpected tuple 1 with StructType` After this PR, this will have a more user-friendly error: `[UDTF_INVALID_OUTPUT_ROW_TYPE] The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got 'int'. Please make sure that the output rows are of the correct type.` ### How was this patch tested? Existing UTs. Closes #42353 from allisonwang-db/spark-44005-non-tuple-return-val. Authored-by: allisonwang-db Signed-off-by: Ruifeng Zheng --- python/pyspark/errors/error_classes.py | 5 +++++ python/pyspark/sql/tests/test_udtf.py | 26 +++++++++++--------------- python/pyspark/worker.py | 12 +++++++++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 24885e94d3255..bc32afeb87a9f 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -743,6 +743,11 @@ "User defined table function encountered an error in the '' method: " ] }, + "UDTF_INVALID_OUTPUT_ROW_TYPE" : { + "message" : [ + "The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got ''. Please make sure that the output rows are of the correct type." + ] + }, "UDTF_RETURN_NOT_ITERABLE" : { "message" : [ "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got ''. Please make sure that the UDTF returns one of these types." diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index b2f473996bcb6..300067716e9de 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -163,24 +163,21 @@ def eval(self, a: int, b: int): self.assertEqual(rows, [Row(a=1, b=2), Row(a=2, b=3)]) def test_udtf_eval_returning_non_tuple(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a - func = udtf(TestUDTF, returnType="a: int") - # TODO(SPARK-44005): improve this error message - with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): - func(lit(1)).collect() + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() - def test_udtf_eval_returning_non_generator(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): return (a,) - func = udtf(TestUDTF, returnType="a: int") - # TODO(SPARK-44005): improve this error message - with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): - func(lit(1)).collect() + with self.assertRaisesRegex(PythonException, "UDTF_INVALID_OUTPUT_ROW_TYPE"): + TestUDTF(lit(1)).collect() def test_udtf_with_invalid_return_value(self): @udtf(returnType="x: int") @@ -1852,21 +1849,20 @@ def eval(self): self.spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", old_value) def test_udtf_eval_returning_non_tuple(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a - func = udtf(TestUDTF, returnType="a: int") # When arrow is enabled, it can handle non-tuple return value. 
- self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) - def test_udtf_eval_returning_non_generator(self): + @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): - return (a,) + return [a] - func = udtf(TestUDTF, returnType="a: int") - self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) + assertDataFrameEqual(TestUDTF(lit(1)), [Row(a=1)]) def test_numeric_output_type_casting(self): class TestUDTF: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b32e20e3b0418..6f27400387e72 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -648,9 +648,8 @@ def wrap_udtf(f, return_type): return_type_size = len(return_type) def verify_and_convert_result(result): - # TODO(SPARK-44005): support returning non-tuple values - if result is not None and hasattr(result, "__len__"): - if len(result) != return_type_size: + if result is not None: + if hasattr(result, "__len__") and len(result) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ @@ -658,6 +657,13 @@ def verify_and_convert_result(result): "actual": str(len(result)), }, ) + + if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")): + raise PySparkRuntimeError( + error_class="UDTF_INVALID_OUTPUT_ROW_TYPE", + message_parameters={"type": type(result).__name__}, + ) + return toInternal(result) # Evaluate the function and return a tuple back to the executor. From b4b91212b1d4ce8f47f9e1abeb26b06122c01f13 Mon Sep 17 00:00:00 2001 From: Shuyou Dong Date: Tue, 8 Aug 2023 12:17:53 +0900 Subject: [PATCH 8/9] [SPARK-44703][CORE] Log eventLog rewrite duration when compact old event log files ### What changes were proposed in this pull request? Log the eventLog rewrite duration when compacting old event log files. ### Why are the changes needed? When `spark.eventLog.rolling.enabled` is enabled and the number of eventLog files exceeds the value of `spark.history.fs.eventLog.rolling.maxFilesToRetain`, the HistoryServer compacts the old event log files into one compact file. Currently the rewrite duration is not logged in the rewrite method; this metric is useful for understanding how long compaction takes, so this PR adds a log line to that method. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manual test. Closes #42378 from shuyouZZ/SPARK-44703.
Authored-by: Shuyou Dong Signed-off-by: Jungtaek Lim --- .../apache/spark/deploy/history/EventLogFileCompactor.scala | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala index 8558f765175fc..27040e83533ff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileCompactor.scala @@ -149,6 +149,7 @@ class EventLogFileCompactor( val logWriter = new CompactedEventLogFileWriter(lastIndexEventLogPath, "dummy", None, lastIndexEventLogPath.getParent.toUri, sparkConf, hadoopConf) + val startTime = System.currentTimeMillis() logWriter.start() eventLogFiles.foreach { file => EventFilter.applyFilterToFile(fs, filters, file.getPath, @@ -158,6 +159,8 @@ class EventLogFileCompactor( ) } logWriter.stop() + val duration = System.currentTimeMillis() - startTime + logInfo(s"Finished rewriting eventLog files to ${logWriter.logPath} took $duration ms.") logWriter.logPath } From d2b60ff51fabdb38899e649aa2e700112534d79c Mon Sep 17 00:00:00 2001 From: itholic Date: Tue, 8 Aug 2023 16:16:11 +0900 Subject: [PATCH 9/9] [SPARK-43567][PS] Support `use_na_sentinel` for `factorize` ### What changes were proposed in this pull request? This PR proposes to support `use_na_sentinel` for `factorize`. ### Why are the changes needed? To match the behavior with [pandas 2](https://pandas.pydata.org/docs/dev/whatsnew/v2.0.0.html) ### Does this PR introduce _any_ user-facing change? Yes, the `na_sentinel` is removed in favor of `use_na_sentinel`. ### How was this patch tested? Enabling the existing tests. Closes #42270 from itholic/pandas_use_na_sentinel. Authored-by: itholic Signed-off-by: Hyukjin Kwon --- .../migration_guide/pyspark_upgrade.rst | 1 + python/pyspark/pandas/base.py | 39 +++++++------------ .../connect/series/test_parity_compute.py | 4 ++ .../pandas/tests/indexes/test_category.py | 8 +--- .../pandas/tests/series/test_compute.py | 20 ++++------ 5 files changed, 29 insertions(+), 43 deletions(-) diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index 7a691ee264571..d26f1cbbe0dc4 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -29,6 +29,7 @@ Upgrading from PySpark 3.5 to 4.0 * In Spark 4.0, ``Series.append`` has been removed from pandas API on Spark, use ``ps.concat`` instead. * In Spark 4.0, ``DataFrame.mad`` has been removed from pandas API on Spark. * In Spark 4.0, ``Series.mad`` has been removed from pandas API on Spark. +* In Spark 4.0, ``na_sentinel`` parameter from ``Index.factorize`` and `Series.factorize`` has been removed from pandas API on Spark, use ``use_na_sentinel`` instead. 
Upgrading from PySpark 3.3 to 3.4 diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index 2de260e6e9351..0685af769872a 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -1614,7 +1614,7 @@ def take(self: IndexOpsLike, indices: Sequence[int]) -> IndexOpsLike: return cast(IndexOpsLike, self._psdf.iloc[indices].index) def factorize( - self: IndexOpsLike, sort: bool = True, na_sentinel: Optional[int] = -1 + self: IndexOpsLike, sort: bool = True, use_na_sentinel: bool = True ) -> Tuple[IndexOpsLike, pd.Index]: """ Encode the object as an enumerated type or categorical variable. @@ -1625,11 +1625,11 @@ def factorize( Parameters ---------- sort : bool, default True - na_sentinel : int or None, default -1 - Value to mark "not found". If None, will not drop the NaN - from the uniques of the values. - - .. deprecated:: 3.4.0 + use_na_sentinel : bool, default True + If True, the sentinel -1 will be used for NaN values, effectively assigning them + a distinct category. If False, NaN values will be encoded as non-negative integers, + treating them as unique categories in the encoding process and retaining them in the + set of unique categories in the data. Returns ------- @@ -1658,7 +1658,7 @@ def factorize( >>> uniques Index(['a', 'b', 'c'], dtype='object') - >>> codes, uniques = psser.factorize(na_sentinel=None) + >>> codes, uniques = psser.factorize(use_na_sentinel=False) >>> codes 0 1 1 3 @@ -1669,17 +1669,6 @@ def factorize( >>> uniques Index(['a', 'b', 'c', None], dtype='object') - >>> codes, uniques = psser.factorize(na_sentinel=-2) - >>> codes - 0 1 - 1 -2 - 2 0 - 3 2 - 4 1 - dtype: int32 - >>> uniques - Index(['a', 'b', 'c'], dtype='object') - For Index: >>> psidx = ps.Index(['b', None, 'a', 'c', 'b']) @@ -1691,8 +1680,8 @@ def factorize( """ from pyspark.pandas.series import first_series - assert (na_sentinel is None) or isinstance(na_sentinel, int) assert sort is True + use_na_sentinel = -1 if use_na_sentinel else False # type: ignore[assignment] warnings.warn( "Argument `na_sentinel` will be removed in 4.0.0.", @@ -1716,7 +1705,7 @@ def factorize( scol = map_scol[self.spark.column] codes, uniques = self._with_new_scol( scol.alias(self._internal.data_spark_column_names[0]) - ).factorize(na_sentinel=na_sentinel) + ).factorize(use_na_sentinel=use_na_sentinel) return codes, uniques.astype(self.dtype) uniq_sdf = self._internal.spark_frame.select(self.spark.column).distinct() @@ -1743,13 +1732,13 @@ def factorize( # Constructs `unique_to_code` mapping non-na unique to code unique_to_code = {} - if na_sentinel is not None: - na_sentinel_code = na_sentinel + if use_na_sentinel: + na_sentinel_code = use_na_sentinel code = 0 for unique in uniques_list: if pd.isna(unique): - if na_sentinel is None: - na_sentinel_code = code + if not use_na_sentinel: + na_sentinel_code = code # type: ignore[assignment] else: unique_to_code[unique] = code code += 1 @@ -1767,7 +1756,7 @@ def factorize( codes = self._with_new_scol(new_scol.alias(self._internal.data_spark_column_names[0])) - if na_sentinel is not None: + if use_na_sentinel: # Drops the NaN from the uniques of the values uniques_list = [x for x in uniques_list if not pd.isna(x)] diff --git a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py index 8876fcb139885..31916f12b4e7f 100644 --- a/python/pyspark/pandas/tests/connect/series/test_parity_compute.py +++ 
b/python/pyspark/pandas/tests/connect/series/test_parity_compute.py @@ -24,6 +24,10 @@ class SeriesParityComputeTests(SeriesComputeMixin, PandasOnSparkTestUtils, ReusedConnectTestCase): pass + @unittest.skip("TODO(SPARK-43620): Support `Column` for SparkConnectColumn.__getitem__.") + def test_factorize(self): + super().test_factorize() + if __name__ == "__main__": from pyspark.pandas.tests.connect.series.test_parity_compute import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/indexes/test_category.py b/python/pyspark/pandas/tests/indexes/test_category.py index ffffae828c437..6aa92b7e1e390 100644 --- a/python/pyspark/pandas/tests/indexes/test_category.py +++ b/python/pyspark/pandas/tests/indexes/test_category.py @@ -210,10 +210,6 @@ def test_astype(self): self.assert_eq(pscidx.astype(str), pcidx.astype(str)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43567): Enable CategoricalIndexTests.test_factorize for pandas 2.0.0.", - ) def test_factorize(self): pidx = pd.CategoricalIndex([1, 2, 3, None]) psidx = ps.from_pandas(pidx) @@ -224,8 +220,8 @@ def test_factorize(self): self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) - pcodes, puniques = pidx.factorize(na_sentinel=-2) - kcodes, kuniques = psidx.factorize(na_sentinel=-2) + pcodes, puniques = pidx.factorize(use_na_sentinel=-2) + kcodes, kuniques = psidx.factorize(use_na_sentinel=-2) self.assert_eq(kcodes.tolist(), pcodes.tolist()) self.assert_eq(kuniques, puniques) diff --git a/python/pyspark/pandas/tests/series/test_compute.py b/python/pyspark/pandas/tests/series/test_compute.py index 155649179e6ef..784bf29e1a25b 100644 --- a/python/pyspark/pandas/tests/series/test_compute.py +++ b/python/pyspark/pandas/tests/series/test_compute.py @@ -407,10 +407,6 @@ def test_abs(self): self.assert_eq(abs(psser), abs(pser)) self.assert_eq(np.abs(psser), np.abs(pser)) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43550): Enable SeriesTests.test_factorize for pandas 2.0.0.", - ) def test_factorize(self): pser = pd.Series(["a", "b", "a", "b"]) psser = ps.from_pandas(pser) @@ -492,27 +488,27 @@ def test_factorize(self): pser = pd.Series(["a", "b", "a", np.nan, None]) psser = ps.from_pandas(pser) - pcodes, puniques = pser.factorize(sort=True, na_sentinel=-2) - kcodes, kuniques = psser.factorize(na_sentinel=-2) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=-2) + kcodes, kuniques = psser.factorize(use_na_sentinel=-2) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) - pcodes, puniques = pser.factorize(sort=True, na_sentinel=2) - kcodes, kuniques = psser.factorize(na_sentinel=2) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=2) + kcodes, kuniques = psser.factorize(use_na_sentinel=2) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques) if not pd_below_1_1_2: - pcodes, puniques = pser.factorize(sort=True, na_sentinel=None) - kcodes, kuniques = psser.factorize(na_sentinel=None) + pcodes, puniques = pser.factorize(sort=True, use_na_sentinel=None) + kcodes, kuniques = psser.factorize(use_na_sentinel=None) self.assert_eq(pcodes.tolist(), kcodes.to_list()) # puniques is Index(['a', 'b', nan], dtype='object') self.assert_eq(ps.Index(["a", "b", None]), kuniques) psser = ps.Series([1, 2, np.nan, 4, 5]) # Arrow takes np.nan as null psser.loc[3] = np.nan # Spark takes np.nan as NaN - kcodes, kuniques = psser.factorize(na_sentinel=None) - 
pcodes, puniques = psser._to_pandas().factorize(sort=True, na_sentinel=None) + kcodes, kuniques = psser.factorize(use_na_sentinel=None) + pcodes, puniques = psser._to_pandas().factorize(sort=True, use_na_sentinel=None) self.assert_eq(pcodes.tolist(), kcodes.to_list()) self.assert_eq(puniques, kuniques)
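For reference, a minimal, hypothetical pandas-on-Spark sketch of the `use_na_sentinel` behaviour this last patch enables (assuming an active Spark session; the data and the expected values in the comments are illustrative, mirroring the doctest updated in `base.py`):

```python
import pyspark.pandas as ps

psser = ps.Series(["a", None, "a", "b"])

# Default (use_na_sentinel=True): missing values are encoded as -1 and
# dropped from the returned uniques.
codes, uniques = psser.factorize()
# codes   -> 0, -1, 0, 1
# uniques -> Index(['a', 'b'], dtype='object')

# use_na_sentinel=False: missing values get their own non-negative code and
# are kept in uniques.
codes, uniques = psser.factorize(use_na_sentinel=False)
# codes   -> 0, 2, 0, 1
# uniques -> Index(['a', 'b', None], dtype='object')
```

Passing `use_na_sentinel=False` matches pandas 2.x, where the old `na_sentinel` parameter is no longer accepted.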