diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index 93a64d8eef10a..0b577cd01bd2c 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -59,6 +59,7 @@ def _deserialize_accumulator( class SpecialAccumulatorIds: SQL_UDF_PROFIER = -1 + SQL_UDF_LOGGER = -2 class Accumulator(Generic[T]): diff --git a/python/pyspark/logger/logger.py b/python/pyspark/logger/logger.py index 975441a9cb572..610b897f8c3ed 100644 --- a/python/pyspark/logger/logger.py +++ b/python/pyspark/logger/logger.py @@ -64,7 +64,56 @@ def format(self, record: logging.LogRecord) -> str: return json.dumps(log_entry, ensure_ascii=False) -class PySparkLogger(logging.Logger): +class PySparkLoggerBase(logging.Logger): + def info(self, msg: object, *args: object, **kwargs: object) -> None: + """ + Log 'msg % args' with severity 'INFO' in structured JSON format. + + Parameters + ---------- + msg : str + The log message. + """ + super().info(msg, *args, extra={"kwargs": kwargs}) + + def warning(self, msg: object, *args: object, **kwargs: object) -> None: + """ + Log 'msg % args' with severity 'WARNING' in structured JSON format. + + Parameters + ---------- + msg : str + The log message. + """ + super().warning(msg, *args, extra={"kwargs": kwargs}) + + def error(self, msg: object, *args: object, **kwargs: object) -> None: + """ + Log 'msg % args' with severity 'ERROR' in structured JSON format. + + Parameters + ---------- + msg : str + The log message. + """ + super().error(msg, *args, extra={"kwargs": kwargs}) + + def exception(self, msg: object, *args: object, **kwargs: object) -> None: + """ + Convenience method for logging an ERROR with exception information. + + Parameters + ---------- + msg : str + The log message. + exc_info : bool = True + If True, exception information is added to the logging message. + This includes the exception type, value, and traceback. Default is True. + """ + super().error(msg, *args, exc_info=True, extra={"kwargs": kwargs}) + + +class PySparkLogger(PySparkLoggerBase): """ Custom logging.Logger wrapper for PySpark that logs messages in a structured JSON format. @@ -147,50 +196,3 @@ def getLogger(name: Optional[str] = None) -> "PySparkLogger": logging.setLoggerClass(existing_logger) return cast(PySparkLogger, pyspark_logger) - - def info(self, msg: object, *args: object, **kwargs: object) -> None: - """ - Log 'msg % args' with severity 'INFO' in structured JSON format. - - Parameters - ---------- - msg : str - The log message. - """ - super().info(msg, *args, extra={"kwargs": kwargs}) - - def warning(self, msg: object, *args: object, **kwargs: object) -> None: - """ - Log 'msg % args' with severity 'WARNING' in structured JSON format. - - Parameters - ---------- - msg : str - The log message. - """ - super().warning(msg, *args, extra={"kwargs": kwargs}) - - def error(self, msg: object, *args: object, **kwargs: object) -> None: - """ - Log 'msg % args' with severity 'ERROR' in structured JSON format. - - Parameters - ---------- - msg : str - The log message. - """ - super().error(msg, *args, extra={"kwargs": kwargs}) - - def exception(self, msg: object, *args: object, **kwargs: object) -> None: - """ - Convenience method for logging an ERROR with exception information. - - Parameters - ---------- - msg : str - The log message. - exc_info : bool = True - If True, exception information is added to the logging message. - This includes the exception type, value, and traceback. Default is True. 
- """ - super().error(msg, *args, exc_info=True, extra={"kwargs": kwargs}) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index adba1b42a8bd6..8e8fdc2e5390d 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -95,6 +95,7 @@ PythonDataSource, ) from pyspark.sql.connect.observation import Observation +from pyspark.sql.connect.udf_logger import ConnectUDFLogCollector from pyspark.sql.connect.utils import get_python_ver from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema from pyspark.sql.types import DataType, StructType, TimestampType, _has_type @@ -665,6 +666,7 @@ def __init__( self._server_session_id: Optional[str] = None self._profiler_collector = ConnectProfilerCollector() + self._udf_log_collector = ConnectUDFLogCollector() self._progress_handlers: List[ProgressHandler] = [] @@ -1388,6 +1390,8 @@ def handle_response( (aid, update) = pickleSer.loads(LiteralExpression._to_value(metric)) if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER: self._profiler_collector._update(update) + elif aid == SpecialAccumulatorIds.SQL_UDF_LOGGER: + self._udf_log_collector._update(update) elif observed_metrics.name in observations: observation_result = observations[observed_metrics.name]._result assert observation_result is not None diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index cacb479229bb7..c44ef9aaa015e 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -93,6 +93,7 @@ MapType, StringType, ) +from pyspark.sql.udf_logger import UDFLogs from pyspark.sql.utils import to_str from pyspark.errors import ( PySparkAttributeError, @@ -1079,6 +1080,12 @@ def profile(self) -> Profile: profile.__doc__ = PySparkSession.profile.__doc__ + @property + def udfLogs(self) -> UDFLogs: + return UDFLogs(self, self._client._udf_log_collector) + + udfLogs.__doc__ = PySparkSession.udfLogs.__doc__ + def __reduce__(self) -> Tuple: """ This method is called when the object is pickled. It returns a tuple of the object's diff --git a/python/pyspark/sql/connect/udf_logger.py b/python/pyspark/sql/connect/udf_logger.py new file mode 100644 index 0000000000000..789a27601cdef --- /dev/null +++ b/python/pyspark/sql/connect/udf_logger.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from typing import List, Optional + +from pyspark.sql.udf_logger import Logs, LogsParam, UDFLogCollector + + +class ConnectUDFLogCollector(UDFLogCollector): + def __init__(self): + super().__init__() + self._value = LogsParam.zero({}) + + def collect(self, id: int) -> Optional[List[str]]: + with self._lock: + return self._value.get(id) + + def clear(self, id: Optional[int] = None) -> None: + with self._lock: + if id is not None: + self._value.pop(id, None) + else: + self._value.clear() + + def _update(self, update: Logs) -> None: + with self._lock: + self._value = LogsParam.addInPlace(self._value, update) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b513d8d4111b9..e37efcd29d6b0 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -62,6 +62,7 @@ _from_numpy_type, ) from pyspark.errors.exceptions.captured import install_exception_handler +from pyspark.sql.udf_logger import AccumulatorUDFLogCollector, UDFLogs from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError @@ -660,6 +661,7 @@ def __init__( self._jvm.SparkSession.setActiveSession(self._jsparkSession) self._profiler_collector = AccumulatorProfilerCollector() + self._udf_log_collector = AccumulatorUDFLogCollector() def _repr_html_(self) -> str: return """ @@ -955,6 +957,10 @@ def profile(self) -> Profile: """ return Profile(self._profiler_collector) + @property + def udfLogs(self) -> UDFLogs: + return UDFLogs(self, self._udf_log_collector) + def range( self, start: int, diff --git a/python/pyspark/sql/udf_logger.py b/python/pyspark/sql/udf_logger.py new file mode 100644 index 0000000000000..eba7bdad8a9b6 --- /dev/null +++ b/python/pyspark/sql/udf_logger.py @@ -0,0 +1,175 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from abc import ABC, abstractmethod +from collections import defaultdict +import logging +from threading import RLock +from typing import ClassVar, Dict, List, Optional, cast, TYPE_CHECKING + +from pyspark.accumulators import ( + Accumulator, + AccumulatorParam, + SpecialAccumulatorIds, + _accumulatorRegistry, + _deserialize_accumulator, +) +from pyspark.logger.logger import JSONFormatter, PySparkLoggerBase +from pyspark.sql.functions import from_json +from pyspark.sql.types import ( + ArrayType, + MapType, + StringType, + StructType, + TimestampNTZType, + VariantType, +) + +if TYPE_CHECKING: + from pyspark.sql.dataframe import DataFrame + from pyspark.sql.session import SparkSession + +Logs = Dict[int, List[str]] + + +class _LogsParam(AccumulatorParam[Logs]): + @staticmethod + def zero(value: Logs) -> Logs: + return defaultdict(list) + + @classmethod + def addInPlace(cls, value1: Logs, value2: Logs) -> Logs: + new_value = defaultdict(list) + for k, v in value1.items(): + new_value[k].extend(v) + for k, v in value2.items(): + new_value[k].extend(v) + return new_value + + +LogsParam = _LogsParam() + + +class UDFLogHandler(logging.Handler): + def __init__(self, result_id: int): + super().__init__() + self._accumulator = _deserialize_accumulator( + SpecialAccumulatorIds.SQL_UDF_LOGGER, LogsParam.zero({}), LogsParam + ) + self._result_id = result_id + formatter = JSONFormatter() + formatter.default_msec_format = "%s.%03d" + self._formatter = formatter + + def emit(self, record: logging.LogRecord) -> None: + msg = self._formatter.format(record) + self._accumulator.add({self._result_id: [msg]}) + + +class PySparkUDFLogger(PySparkLoggerBase): + _udf_log_handler: ClassVar[Optional[UDFLogHandler]] = None + + def __init__(self, name: str): + super().__init__(name, level=logging.WARN) + if self._udf_log_handler is not None: + self.addHandler(self._udf_log_handler) + + @staticmethod + def getLogger(name: str = "PySparkUDFLogger") -> "PySparkUDFLogger": + existing_logger = logging.getLoggerClass() + try: + if not isinstance(existing_logger, PySparkUDFLogger): + logging.setLoggerClass(PySparkUDFLogger) + + udf_logger = logging.getLogger(name) + finally: + logging.setLoggerClass(existing_logger) + + return cast(PySparkUDFLogger, udf_logger) + + +class UDFLogCollector(ABC): + LOG_ENTRY_SCHEMA: ClassVar[StructType] = ( + StructType() + .add("ts", TimestampNTZType()) + .add("level", StringType()) + .add("logger", StringType()) + .add("msg", StringType()) + .add("context", MapType(StringType(), VariantType())) + .add( + "exception", + StructType() + .add("class", StringType()) + .add("msg", StringType()) + .add("stacktrace", ArrayType(StringType())), + ) + ) + + def __init__(self): + self._lock = RLock() + + @abstractmethod + def collect(self, id: int) -> Optional[List[str]]: + pass + + @abstractmethod + def clear(self, id: Optional[int] = None) -> None: + pass + + +class AccumulatorUDFLogCollector(UDFLogCollector): + def __init__(self): + super().__init__() + if SpecialAccumulatorIds.SQL_UDF_LOGGER in _accumulatorRegistry: + self._accumulator = _accumulatorRegistry[SpecialAccumulatorIds.SQL_UDF_LOGGER] + else: + self._accumulator = Accumulator( + SpecialAccumulatorIds.SQL_UDF_LOGGER, LogsParam.zero({}), LogsParam + ) + + def collect(self, id: int) -> Optional[List[str]]: + with self._lock: + return self._accumulator.value.get(id) + + def clear(self, id: Optional[int] = None) -> None: + with self._lock: + if id is not None: + self._accumulator.value.pop(id, None) + else: + self._accumulator.value.clear() 
+ + +class UDFLogs: + def __init__(self, sparkSession: "SparkSession", collector: UDFLogCollector): + self._sparkSession = sparkSession + self._collector = collector + + def collect(self, id: int) -> Optional[List[str]]: + return self._collector.collect(id) + + def collectAsDataFrame(self, id: int) -> Optional["DataFrame"]: + logs = self._collector.collect(id) + if logs is not None: + return ( + self._sparkSession.createDataFrame([(row,) for row in logs], "json string") + .select(from_json("json", self._collector.LOG_ENTRY_SCHEMA).alias("json")) + .select("json.*") + ) + else: + return None + + def clear(self, id: Optional[int] = None) -> None: + self._collector.clear(id) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b8263769c28a9..fe8a8780f4448 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -70,6 +70,7 @@ _create_row, _parse_datatype_json_string, ) +from pyspark.sql.udf_logger import PySparkUDFLogger, UDFLogHandler from pyspark.util import fail_on_stopiteration, handle_worker_exception from pyspark import shuffle from pyspark.errors import PySparkRuntimeError, PySparkTypeError @@ -803,16 +804,16 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil else: chained_func = chain(chained_func, f) - if profiler == "perf": - result_id = read_long(infile) + result_id = read_long(infile) + PySparkUDFLogger._udf_log_handler = UDFLogHandler(result_id) + if profiler == "perf": if _supports_profiler(eval_type): profiling_func = wrap_perf_profiler(chained_func, result_id) else: profiling_func = chained_func elif profiler == "memory": - result_id = read_long(infile) if _supports_profiler(eval_type) and has_memory_profiler: profiling_func = wrap_memory_profiler(chained_func, result_id) else: @@ -1909,6 +1910,7 @@ def process(): faulthandler.disable() faulthandler_log_file.close() os.remove(faulthandler_log_path) + PySparkUDFLogger._udf_log_handler = None finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) write_long(shuffle.MemoryBytesSpilled, outfile) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 87ff5a0ec4333..4d7fbd77a1d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -163,9 +163,7 @@ object PythonUDFRunner { chained.funcs.foreach { f => PythonWorkerUtils.writePythonFunction(f, dataOut) } - if (profiler.isDefined) { - dataOut.writeLong(resultId) - } + dataOut.writeLong(resultId) } } @@ -198,9 +196,7 @@ object PythonUDFRunner { chained.funcs.foreach { f => PythonWorkerUtils.writePythonFunction(f, dataOut) } - if (profiler.isDefined) { - dataOut.writeLong(resultId) - } + dataOut.writeLong(resultId) } } }
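Illustrative usage (not part of the patch): the sketch below shows how a UDF might emit structured logs through the new `PySparkUDFLogger`. It assumes an active SparkSession named `spark` and that the worker has installed the `UDFLogHandler` added in `worker.py`; records at WARNING and above are captured because `PySparkUDFLogger` initializes at `logging.WARN`. The UDF body and column are made up for the example.

```python
from pyspark.sql.functions import udf
from pyspark.sql.udf_logger import PySparkUDFLogger


@udf("long")
def add_one(x):
    # Inside the worker, PySparkUDFLogger._udf_log_handler is set before the
    # UDF runs, so records emitted here are formatted as JSON and forwarded
    # through the SQL_UDF_LOGGER accumulator.
    logger = PySparkUDFLogger.getLogger()
    if x % 2 == 0:
        # Extra kwargs are carried in the structured record
        # (the `context` field of LOG_ENTRY_SCHEMA).
        logger.warning("even input seen", value=x)
    return x + 1


spark.range(4).select(add_one("id")).collect()
```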
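On the driver, collected logs are keyed by the UDF's internal result id — the value the runner now always writes (`dataOut.writeLong(resultId)` in `PythonUDFRunner.scala`) and the same key the UDF profiler uses. A rough sketch of retrieval via the new `spark.udfLogs` property, assuming the id is known; the value 2 below is purely illustrative and not defined by this diff.

```python
# Raw JSON strings for one UDF, or None if nothing was logged under that id.
logs = spark.udfLogs.collect(2)  # the id 2 is illustrative
print(logs)

# Parsed into a DataFrame using UDFLogCollector.LOG_ENTRY_SCHEMA
# (ts, level, logger, msg, context, exception).
df = spark.udfLogs.collectAsDataFrame(2)
if df is not None:
    df.show(truncate=False)

# Drop logs for one id, or everything when no id is given.
spark.udfLogs.clear(2)
spark.udfLogs.clear()
```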