From 174fdea12ce62f3fb4da9d0f9d96dc376320d633 Mon Sep 17 00:00:00 2001
From: Haoyu Weng
Date: Thu, 16 Jan 2025 15:19:14 -0800
Subject: [PATCH] [SPARK-50858][PYTHON] Add configuration to hide Python UDF stack trace

---
 .../spark/api/python/PythonRunner.scala       |  4 +++
 python/pyspark/tests/test_util.py             | 31 +++++++++++++++++++
 python/pyspark/util.py                        | 13 +++++---
 .../apache/spark/sql/internal/SQLConf.scala   | 12 +++++++
 .../ApplyInPandasWithStatePythonRunner.scala  |  1 +
 .../execution/python/ArrowPythonRunner.scala  |  1 +
 .../python/ArrowPythonUDTFRunner.scala        |  1 +
 .../python/CoGroupedArrowPythonRunner.scala   |  1 +
 .../python/PythonForeachWriter.scala          |  1 +
 .../python/PythonPlannerRunner.scala          |  4 +++
 .../execution/python/PythonUDFRunner.scala    |  1 +
 11 files changed, 65 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index e3d10574419b3..5901cbe23e435 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -122,6 +122,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
   private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
   protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
+  protected val hideTraceback: Boolean = false
   protected val simplifiedTraceback: Boolean = false
 
   // All the Python functions should have the same exec, version and envvars.
@@ -199,6 +200,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     if (reuseWorker) {
       envVars.put("SPARK_REUSE_WORKER", "1")
     }
+    if (hideTraceback) {
+      envVars.put("SPARK_HIDE_TRACEBACK", "1")
+    }
     if (simplifiedTraceback) {
       envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
     }
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index ad0b106d229aa..1c1a0c4abe7ae 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -16,6 +16,7 @@
 #
 import os
 import unittest
+from unittest.mock import patch
 
 from py4j.protocol import Py4JJavaError
 
@@ -125,6 +126,36 @@ def test_parse_memory(self):
             _parse_memory("2gs")
 
 
+class HandleWorkerExceptionTests(unittest.TestCase):
+    exception_bytes = b"ValueError: test_message"
+    traceback_bytes = b"Traceback (most recent call last):"
+
+    def run_handle_worker_exception(self):
+        import io
+        from pyspark.util import handle_worker_exception
+
+        try:
+            raise ValueError("test_message")
+        except Exception as e:
+            with io.BytesIO() as stream:
+                handle_worker_exception(e, stream)
+                return stream.getvalue()
+
+    @patch.dict(
+        os.environ, {"SPARK_SIMPLIFIED_TRACEBACK": "", "SPARK_HIDE_TRACEBACK": ""}
+    )
+    def test_full(self):
+        result = self.run_handle_worker_exception()
+        self.assertIn(self.exception_bytes, result)
+        self.assertIn(self.traceback_bytes, result)
+
+    @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
+    def test_hide_traceback(self):
+        result = self.run_handle_worker_exception()
+        self.assertIn(self.exception_bytes, result)
+        self.assertNotIn(self.traceback_bytes, result)
+
+
 if __name__ == "__main__":
     from pyspark.tests.test_util import *  # noqa: F401
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 3e9a68ccfe2e5..30827371e6fc2 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -468,16 +468,19 @@ def handle_worker_exception(e: BaseException, outfile: IO) -> None:
     and exception traceback info to outfile. JVM could then read from the outfile and perform
     exception handling there.
     """
-    try:
-        exc_info = None
+
+    def format_exception():
+        if os.environ.get("SPARK_HIDE_TRACEBACK", False):
+            return "".join(traceback.format_exception_only(type(e), e))
         if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
             tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
             if tb is not None:
                 e.__cause__ = None
-            exc_info = "".join(traceback.format_exception(type(e), e, tb))
-        if exc_info is None:
-            exc_info = traceback.format_exc()
+                return "".join(traceback.format_exception(type(e), e, tb))
+        return traceback.format_exc()
+    try:
+        exc_info = format_exception()
 
         write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(exc_info.encode("utf-8"), outfile)
     except IOError:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d4298087b2108..31650757fd826 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3459,6 +3459,16 @@
       .checkValues(Set("legacy", "row", "dict"))
       .createWithDefaultString("legacy")
 
+  val PYSPARK_HIDE_TRACEBACK =
+    buildConf("spark.sql.execution.pyspark.udf.hideTraceback.enabled")
+      .doc(
+        "When true, only show the message of the exception from Python UDFs, " +
+          "hiding stack trace and exception type."
+      )
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PYSPARK_SIMPLIFIED_TRACEBACK =
     buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled")
       .doc(
@@ -6254,6 +6264,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)
 
+  def pysparkHideTraceback: Boolean = getConf(PYSPARK_HIDE_TRACEBACK)
+
   def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)
 
   def pandasGroupedMapAssignColumnsByName: Boolean =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index d704638b85e8a..f598430df0eea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -84,6 +84,7 @@ class ApplyInPandasWithStatePythonRunner(
   override protected lazy val timeZoneId: String = _timeZoneId
   override val errorOnDuplicatedFieldNames: Boolean = true
 
+  override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback
 
   override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 579b496046852..1bddd81fbfe20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -50,6 +50,7 @@ abstract class BaseArrowPythonRunner(
 
   override val errorOnDuplicatedFieldNames: Boolean = true
 
+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
   // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 99a9e706c6620..f42c4b6106cbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -59,6 +59,7 @@ class ArrowPythonUDTFRunner(
 
   override val errorOnDuplicatedFieldNames: Boolean = true
 
+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
   override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index c5e86d010938d..59e8970b9c9b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -60,6 +60,7 @@ class CoGroupedArrowPythonRunner(
 
   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
 
+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
   protected def newWriter(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
index ed7ff6a753487..4655f96425fd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
@@ -100,6 +100,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)
 
     override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
 
+    override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
     override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 63c98b3002aac..aeda917857b50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -50,6 +50,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
     val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
     val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
     val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
     val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
 
@@ -66,6 +67,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
     if (reuseWorker) {
       envVars.put("SPARK_REUSE_WORKER", "1")
     }
+    if (hideTraceback) {
+      envVars.put("SPARK_HIDE_TRACEBACK", "1")
+    }
     if (simplifiedTraceback) {
       envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 167e1fd8b0f01..a322dfa10df5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -42,6 +42,7 @@ abstract class BasePythonUDFRunner(
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
       funcs.head._1.funcs.head.pythonExec)
 
+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
 
   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
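
For illustration, a minimal end-to-end sketch of the new flag (not part of the patch). It assumes a Spark build that includes this change; the config key and the SPARK_HIDE_TRACEBACK behavior come from the diff above, while the session setup and the failing UDF are hypothetical.

# demo_hide_traceback.py -- illustrative sketch only, not part of the patch
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = (
    SparkSession.builder.appName("hide-traceback-demo")
    # New flag added in SQLConf above; it defaults to false.
    .config("spark.sql.execution.pyspark.udf.hideTraceback.enabled", "true")
    .getOrCreate()
)


@udf("int")
def fail(x):
    # Same failure shape as the unit test above.
    raise ValueError("test_message")


try:
    spark.range(1).select(fail("id")).collect()
except Exception as e:
    # With the flag enabled, the Python portion of the error carries only
    # "ValueError: test_message"; the "Traceback (most recent call last):"
    # block from the worker is suppressed.
    print(e)
finally:
    spark.stop()

The plumbing is uniform across entry points: each SQL runner reads the flag from SQLConf into its hideTraceback override, BasePythonRunner and PythonPlannerRunner export SPARK_HIDE_TRACEBACK=1 to the worker environment, and handle_worker_exception in python/pyspark/util.py then formats the error with traceback.format_exception_only.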