Commit 416c36d

ueshin committed on 2024/09/11 15:23:04
1 parent e037953

Showing 9 changed files with 290 additions and 57 deletions.

diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 93a64d8eef1..0b577cd01bd 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -59,6 +59,7 @@ def _deserialize_accumulator(

 class SpecialAccumulatorIds:
     SQL_UDF_PROFIER = -1
+    SQL_UDF_LOGGER = -2

 class Accumulator(Generic[T]):
diff --git a/python/pyspark/logger/logger.py b/python/pyspark/logger/logger.py
index 975441a9cb5..610b897f8c3 100644
--- a/python/pyspark/logger/logger.py
+++ b/python/pyspark/logger/logger.py
@@ -64,7 +64,56 @@ class JSONFormatter(logging.Formatter):
         return json.dumps(log_entry, ensure_ascii=False)

-class PySparkLogger(logging.Logger):
+class PySparkLoggerBase(logging.Logger):
+    def info(self, msg: object, *args: object, **kwargs: object) -> None:
+        """
+        Log 'msg % args' with severity 'INFO' in structured JSON format.
+
+        Parameters
+        ----------
+        msg : str
+            The log message.
+        """
+        super().info(msg, *args, extra={"kwargs": kwargs})
+
+    def warning(self, msg: object, *args: object, **kwargs: object) -> None:
+        """
+        Log 'msg % args' with severity 'WARNING' in structured JSON format.
+
+        Parameters
+        ----------
+        msg : str
+            The log message.
+        """
+        super().warning(msg, *args, extra={"kwargs": kwargs})
+
+    def error(self, msg: object, *args: object, **kwargs: object) -> None:
+        """
+        Log 'msg % args' with severity 'ERROR' in structured JSON format.
+
+        Parameters
+        ----------
+        msg : str
+            The log message.
+        """
+        super().error(msg, *args, extra={"kwargs": kwargs})
+
+    def exception(self, msg: object, *args: object, **kwargs: object) -> None:
+        """
+        Convenience method for logging an ERROR with exception information.
+
+        Parameters
+        ----------
+        msg : str
+            The log message.
+        exc_info : bool = True
+            If True, exception information is added to the logging message.
+            This includes the exception type, value, and traceback. Default is True.
+        """
+        super().error(msg, *args, exc_info=True, extra={"kwargs": kwargs})
+
+
+class PySparkLogger(PySparkLoggerBase):
     """
     Custom logging.Logger wrapper for PySpark that logs messages in a structured JSON format.

@@ -147,50 +196,3 @@ class PySparkLogger(logging.Logger):
         logging.setLoggerClass(existing_logger)

         return cast(PySparkLogger, pyspark_logger)
-
-    def info(self, msg: object, *args: object, **kwargs: object) -> None:
-        """
-        Log 'msg % args' with severity 'INFO' in structured JSON format.
-
-        Parameters
-        ----------
-        msg : str
-            The log message.
-        """
-        super().info(msg, *args, extra={"kwargs": kwargs})
-
-    def warning(self, msg: object, *args: object, **kwargs: object) -> None:
-        """
-        Log 'msg % args' with severity 'WARNING' in structured JSON format.
-
-        Parameters
-        ----------
-        msg : str
-            The log message.
-        """
-        super().warning(msg, *args, extra={"kwargs": kwargs})
-
-    def error(self, msg: object, *args: object, **kwargs: object) -> None:
-        """
-        Log 'msg % args' with severity 'ERROR' in structured JSON format.
-
-        Parameters
-        ----------
-        msg : str
-            The log message.
-        """
-        super().error(msg, *args, extra={"kwargs": kwargs})
-
-    def exception(self, msg: object, *args: object, **kwargs: object) -> None:
-        """
-        Convenience method for logging an ERROR with exception information.
-
-        Parameters
-        ----------
-        msg : str
-            The log message.
-        exc_info : bool = True
-            If True, exception information is added to the logging message.
-            This includes the exception type, value, and traceback. Default is True.
-        """
-        super().error(msg, *args, exc_info=True, extra={"kwargs": kwargs})
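For context, a minimal usage sketch of the structured logger refactored above. It assumes only what this diff shows: PySparkLogger (now deriving from PySparkLoggerBase) exposes info/warning/error/exception and forwards extra keyword arguments via extra={"kwargs": ...}; the handler/JSONFormatter wiring inside PySparkLogger's constructor is collapsed in this diff and is assumed to emit JSON-formatted records. The logger name "MyAppLogger" and the keyword fields are illustrative.

from pyspark.logger.logger import PySparkLogger

# getLogger temporarily swaps the logger class so that logging.getLogger()
# returns a PySparkLogger instance (see the getLogger body retained above).
logger = PySparkLogger.getLogger("MyAppLogger")

# Keyword arguments are not interpolated into the message; PySparkLoggerBase
# forwards them as extra={"kwargs": {...}} so they land as structured context
# in the JSON record produced by JSONFormatter.
logger.warning("query took longer than expected", query_id="q-123", elapsed_sec=42)
logger.info("pipeline finished", rows_written=1000)
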
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index adba1b42a8b..8e8fdc2e539 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -95,6 +95,7 @@ from pyspark.sql.connect.plan import (
     PythonDataSource,
 )
 from pyspark.sql.connect.observation import Observation
+from pyspark.sql.connect.udf_logger import ConnectUDFLogCollector
 from pyspark.sql.connect.utils import get_python_ver
 from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema
 from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
@@ -665,6 +666,7 @@ class SparkConnectClient(object):
         self._server_session_id: Optional[str] = None

         self._profiler_collector = ConnectProfilerCollector()
+        self._udf_log_collector = ConnectUDFLogCollector()

         self._progress_handlers: List[ProgressHandler] = []

@@ -1388,6 +1390,8 @@ class SparkConnectClient(object):
                             (aid, update) = pickleSer.loads(LiteralExpression._to_value(metric))
                             if aid == SpecialAccumulatorIds.SQL_UDF_PROFIER:
                                 self._profiler_collector._update(update)
+                            elif aid == SpecialAccumulatorIds.SQL_UDF_LOGGER:
+                                self._udf_log_collector._update(update)
                     elif observed_metrics.name in observations:
                         observation_result = observations[observed_metrics.name]._result
                         assert observation_result is not None
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index cacb479229b..c44ef9aaa01 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -93,6 +93,7 @@ from pyspark.sql.types import (
     MapType,
     StringType,
 )
+from pyspark.sql.udf_logger import UDFLogs
 from pyspark.sql.utils import to_str
 from pyspark.errors import (
     PySparkAttributeError,
@@ -1079,6 +1080,12 @@ class SparkSession:

     profile.__doc__ = PySparkSession.profile.__doc__

+    @property
+    def udfLogs(self) -> UDFLogs:
+        return UDFLogs(self, self._client._udf_log_collector)
+
+    udfLogs.__doc__ = PySparkSession.udfLogs.__doc__
+
     def __reduce__(self) -> Tuple:
         """
         This method is called when the object is pickled. It returns a tuple of the object's
diff --git a/python/pyspark/sql/connect/udf_logger.py b/python/pyspark/sql/connect/udf_logger.py
new file mode 100644
index 00000000000..789a27601cd
--- /dev/null
+++ b/python/pyspark/sql/connect/udf_logger.py
@@ -0,0 +1,40 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import List, Optional
+
+from pyspark.sql.udf_logger import Logs, LogsParam, UDFLogCollector
+
+
+class ConnectUDFLogCollector(UDFLogCollector):
+    def __init__(self):
+        super().__init__()
+        self._value = LogsParam.zero({})
+
+    def collect(self, id: int) -> Optional[List[str]]:
+        with self._lock:
+            return self._value.get(id)
+
+    def clear(self, id: Optional[int] = None) -> None:
+        with self._lock:
+            if id is not None:
+                self._value.pop(id, None)
+            else:
+                self._value.clear()
+
+    def _update(self, update: Logs) -> None:
+        with self._lock:
+            self._value = LogsParam.addInPlace(self._value, update)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index b513d8d4111..e37efcd29d6 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -62,6 +62,7 @@ from pyspark.sql.types import (
     _from_numpy_type,
 )
 from pyspark.errors.exceptions.captured import install_exception_handler
+from pyspark.sql.udf_logger import AccumulatorUDFLogCollector, UDFLogs
 from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod
 from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError

@@ -660,6 +661,7 @@ class SparkSession(SparkConversionMixin):
             self._jvm.SparkSession.setActiveSession(self._jsparkSession)

         self._profiler_collector = AccumulatorProfilerCollector()
+        self._udf_log_collector = AccumulatorUDFLogCollector()

     def _repr_html_(self) -> str:
         return """
@@ -955,6 +957,10 @@ class SparkSession(SparkConversionMixin):
         """
         return Profile(self._profiler_collector)

+    @property
+    def udfLogs(self) -> UDFLogs:
+        return UDFLogs(self, self._udf_log_collector)
+
     def range(
         self,
         start: int,
diff --git a/python/pyspark/sql/udf_logger.py b/python/pyspark/sql/udf_logger.py
new file mode 100644
index 00000000000..eba7bdad8a9
--- /dev/null
+++ b/python/pyspark/sql/udf_logger.py
@@ -0,0 +1,175 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from abc import ABC, abstractmethod
+from collections import defaultdict
+import logging
+from threading import RLock
+from typing import ClassVar, Dict, List, Optional, cast, TYPE_CHECKING
+
+from pyspark.accumulators import (
+    Accumulator,
+    AccumulatorParam,
+    SpecialAccumulatorIds,
+    _accumulatorRegistry,
+    _deserialize_accumulator,
+)
+from pyspark.logger.logger import JSONFormatter, PySparkLoggerBase
+from pyspark.sql.functions import from_json
+from pyspark.sql.types import (
+    ArrayType,
+    MapType,
+    StringType,
+    StructType,
+    TimestampNTZType,
+    VariantType,
+)
+
+if TYPE_CHECKING:
+    from pyspark.sql.dataframe import DataFrame
+    from pyspark.sql.session import SparkSession
+
+Logs = Dict[int, List[str]]
+
+
+class _LogsParam(AccumulatorParam[Logs]):
+    @staticmethod
+    def zero(value: Logs) -> Logs:
+        return defaultdict(list)
+
+    @classmethod
+    def addInPlace(cls, value1: Logs, value2: Logs) -> Logs:
+        new_value = defaultdict(list)
+        for k, v in value1.items():
+            new_value[k].extend(v)
+        for k, v in value2.items():
+            new_value[k].extend(v)
+        return new_value
+
+
+LogsParam = _LogsParam()
+
+
+class UDFLogHandler(logging.Handler):
+    def __init__(self, result_id: int):
+        super().__init__()
+        self._accumulator = _deserialize_accumulator(
+            SpecialAccumulatorIds.SQL_UDF_LOGGER, LogsParam.zero({}), LogsParam
+        )
+        self._result_id = result_id
+        formatter = JSONFormatter()
+        formatter.default_msec_format = "%s.%03d"
+        self._formatter = formatter
+
+    def emit(self, record: logging.LogRecord) -> None:
+        msg = self._formatter.format(record)
+        self._accumulator.add({self._result_id: [msg]})
+
+
+class PySparkUDFLogger(PySparkLoggerBase):
+    _udf_log_handler: ClassVar[Optional[UDFLogHandler]] = None
+
+    def __init__(self, name: str):
+        super().__init__(name, level=logging.WARN)
+        if self._udf_log_handler is not None:
+            self.addHandler(self._udf_log_handler)
+
+    @staticmethod
+    def getLogger(name: str = "PySparkUDFLogger") -> "PySparkUDFLogger":
+        existing_logger = logging.getLoggerClass()
+        try:
+            if not isinstance(existing_logger, PySparkUDFLogger):
+                logging.setLoggerClass(PySparkUDFLogger)
+
+            udf_logger = logging.getLogger(name)
+        finally:
+            logging.setLoggerClass(existing_logger)
+
+        return cast(PySparkUDFLogger, udf_logger)
+
+
+class UDFLogCollector(ABC):
+    LOG_ENTRY_SCHEMA: ClassVar[StructType] = (
+        StructType()
+        .add("ts", TimestampNTZType())
+        .add("level", StringType())
+        .add("logger", StringType())
+        .add("msg", StringType())
+        .add("context", MapType(StringType(), VariantType()))
+        .add(
+            "exception",
+            StructType()
+            .add("class", StringType())
+            .add("msg", StringType())
+            .add("stacktrace", ArrayType(StringType())),
+        )
+    )
+
+    def __init__(self):
+        self._lock = RLock()
+
+    @abstractmethod
+    def collect(self, id: int) -> Optional[List[str]]:
+        pass
+
+    @abstractmethod
+    def clear(self, id: Optional[int] = None) -> None:
+        pass
+
+
+class AccumulatorUDFLogCollector(UDFLogCollector):
+    def __init__(self):
+        super().__init__()
+        if SpecialAccumulatorIds.SQL_UDF_LOGGER in _accumulatorRegistry:
+            self._accumulator = _accumulatorRegistry[SpecialAccumulatorIds.SQL_UDF_LOGGER]
+        else:
+            self._accumulator = Accumulator(
+                SpecialAccumulatorIds.SQL_UDF_LOGGER, LogsParam.zero({}), LogsParam
+            )
+
+    def collect(self, id: int) -> Optional[List[str]]:
+        with self._lock:
+            return self._accumulator.value.get(id)
+
+    def clear(self, id: Optional[int] = None) -> None:
+        with self._lock:
+            if id is not None:
+                self._accumulator.value.pop(id, None)
+            else:
+                self._accumulator.value.clear()
+
+
+class UDFLogs:
+    def __init__(self, sparkSession: "SparkSession", collector: UDFLogCollector):
+        self._sparkSession = sparkSession
+        self._collector = collector
+
+    def collect(self, id: int) -> Optional[List[str]]:
+        return self._collector.collect(id)
+
+    def collectAsDataFrame(self, id: int) -> Optional["DataFrame"]:
+        logs = self._collector.collect(id)
+        if logs is not None:
+            return (
+                self._sparkSession.createDataFrame([(row,) for row in logs], "json string")
+                .select(from_json("json", self._collector.LOG_ENTRY_SCHEMA).alias("json"))
+                .select("json.*")
+            )
+        else:
+            return None
+
+    def clear(self, id: Optional[int] = None) -> None:
+        self._collector.clear(id)
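For context, a hypothetical end-to-end sketch of the UDF log collection API this commit introduces, under stated assumptions: a local SparkSession, the new spark.udfLogs property and pyspark.sql.udf_logger module added above, and a placeholder result id -- the diff does not show how callers obtain the id that keys collected logs, so result_id = 0 below is purely illustrative (collect simply returns None if nothing was recorded under that id). The UDF name shout and the sample data are likewise illustrative.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()


@udf(returnType=StringType())
def shout(s):
    # On the executor, worker.py installs UDFLogHandler on PySparkUDFLogger
    # (see the worker.py hunk below), so WARNING-and-above records are sent
    # back to the driver through the SQL_UDF_LOGGER accumulator.
    from pyspark.sql.udf_logger import PySparkUDFLogger

    logger = PySparkUDFLogger.getLogger()
    logger.warning("processing value", value=s)
    return s.upper() if s is not None else None


spark.range(3).selectExpr("CAST(id AS STRING) AS s").select(shout("s")).collect()

# Driver side: fetch collected log lines for a given result id, either as raw
# JSON strings or parsed with LOG_ENTRY_SCHEMA into a DataFrame.
result_id = 0  # placeholder: this commit does not show how the id is obtained
raw_logs = spark.udfLogs.collect(result_id)
log_df = spark.udfLogs.collectAsDataFrame(result_id)
spark.udfLogs.clear()
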
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b8263769c28..fe8a8780f44 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -70,6 +70,7 @@ from pyspark.sql.types import (
     _create_row,
     _parse_datatype_json_string,
 )
+from pyspark.sql.udf_logger import PySparkUDFLogger, UDFLogHandler
 from pyspark.util import fail_on_stopiteration, handle_worker_exception
 from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -803,16 +804,16 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
         else:
             chained_func = chain(chained_func, f)

-    if profiler == "perf":
-        result_id = read_long(infile)
+    result_id = read_long(infile)
+    PySparkUDFLogger._udf_log_handler = UDFLogHandler(result_id)

+    if profiler == "perf":
         if _supports_profiler(eval_type):
             profiling_func = wrap_perf_profiler(chained_func, result_id)
         else:
             profiling_func = chained_func

     elif profiler == "memory":
-        result_id = read_long(infile)
         if _supports_profiler(eval_type) and has_memory_profiler:
             profiling_func = wrap_memory_profiler(chained_func, result_id)
         else:
@@ -1909,6 +1910,7 @@ def main(infile, outfile):
             faulthandler.disable()
             faulthandler_log_file.close()
             os.remove(faulthandler_log_path)
+        PySparkUDFLogger._udf_log_handler = None
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
     write_long(shuffle.MemoryBytesSpilled, outfile)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 87ff5a0ec43..4d7fbd77a1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -163,9 +163,7 @@ object PythonUDFRunner {
       chained.funcs.foreach { f =>
         PythonWorkerUtils.writePythonFunction(f, dataOut)
       }
-      if (profiler.isDefined) {
-        dataOut.writeLong(resultId)
-      }
+      dataOut.writeLong(resultId)
     }
   }

@@ -198,9 +196,7 @@ object PythonUDFRunner {
       chained.funcs.foreach { f =>
         PythonWorkerUtils.writePythonFunction(f, dataOut)
       }
-      if (profiler.isDefined) {
-        dataOut.writeLong(resultId)
-      }
+      dataOut.writeLong(resultId)
     }
   }
 }