diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 28950e5b41d49..64e78dbccb2f9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -122,6 +122,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
   private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
   protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
+  protected val hideTraceback: Boolean = false
   protected val simplifiedTraceback: Boolean = false

   // All the Python functions should have the same exec, version and envvars.
@@ -199,6 +200,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       if (reuseWorker) {
         envVars.put("SPARK_REUSE_WORKER", "1")
       }
+      if (hideTraceback) {
+        envVars.put("SPARK_HIDE_TRACEBACK", "1")
+      }
       if (simplifiedTraceback) {
         envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
       }
diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py
index ad0b106d229aa..960f6c60f3ace 100644
--- a/python/pyspark/tests/test_util.py
+++ b/python/pyspark/tests/test_util.py
@@ -16,6 +16,7 @@
 #
 import os
 import unittest
+from unittest.mock import patch

 from py4j.protocol import Py4JJavaError

@@ -125,6 +126,44 @@ def test_parse_memory(self):
             _parse_memory("2gs")


+class HandleWorkerExceptionTests(unittest.TestCase):
+    exception_bytes = b"ValueError: test_message"
+    traceback_bytes = b"Traceback (most recent call last):"
+
+    def run_handle_worker_exception(self, hide_traceback=None):
+        import io
+        from pyspark.util import handle_worker_exception
+
+        try:
+            raise ValueError("test_message")
+        except Exception as e:
+            with io.BytesIO() as stream:
+                handle_worker_exception(e, stream, hide_traceback)
+                return stream.getvalue()
+
+    @patch.dict(os.environ, {"SPARK_SIMPLIFIED_TRACEBACK": "", "SPARK_HIDE_TRACEBACK": ""})
+    def test_env_full(self):
+        result = self.run_handle_worker_exception()
+        self.assertIn(self.exception_bytes, result)
+        self.assertIn(self.traceback_bytes, result)
+
+    @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
+    def test_env_hide_traceback(self):
+        result = self.run_handle_worker_exception()
+        self.assertIn(self.exception_bytes, result)
+        self.assertNotIn(self.traceback_bytes, result)
+
+    def test_full(self):
+        result = self.run_handle_worker_exception(False)
+        self.assertIn(self.exception_bytes, result)
+        self.assertIn(self.traceback_bytes, result)
+
+    def test_hide_traceback(self):
+        result = self.run_handle_worker_exception(True)
+        self.assertIn(self.exception_bytes, result)
+        self.assertNotIn(self.traceback_bytes, result)
+
+
 if __name__ == "__main__":
     from pyspark.tests.test_util import *  # noqa: F401
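For reference, the behavior these tests pin down can be reproduced standalone. A minimal sketch, assuming a PySpark build that already contains the `handle_worker_exception` change in the next file:

```python
import io

from pyspark.util import handle_worker_exception

try:
    raise ValueError("test_message")
except ValueError as e:
    buf = io.BytesIO()
    # Passing hide_traceback=True should emit only the framing bytes plus
    # "ValueError: test_message"; the "Traceback (most recent call last):"
    # block is suppressed. Passing None defers to SPARK_HIDE_TRACEBACK.
    handle_worker_exception(e, buf, True)
    print(buf.getvalue())
```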
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 3e9a68ccfe2e5..f51706858182c 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -462,22 +462,40 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
     return f  # type: ignore[return-value]


-def handle_worker_exception(e: BaseException, outfile: IO) -> None:
+def handle_worker_exception(
+    e: BaseException, outfile: IO, hide_traceback: Optional[bool] = None
+) -> None:
     """
     Handles exception for Python worker which writes SpecialLengths.PYTHON_EXCEPTION_THROWN (-2)
     and exception traceback info to outfile. JVM could then read from the outfile and perform
     exception handling there.
+
+    Parameters
+    ----------
+    e : BaseException
+        Exception handled
+    outfile : IO
+        IO object to write the exception info
+    hide_traceback : bool, optional
+        Whether to hide the traceback in the output.
+        By default, hides the traceback if environment variable SPARK_HIDE_TRACEBACK is set.
     """
-    try:
-        exc_info = None
+
+    if hide_traceback is None:
+        hide_traceback = bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))
+
+    def format_exception() -> str:
+        if hide_traceback:
+            return "".join(traceback.format_exception_only(type(e), e))
         if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
             tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
             if tb is not None:
                 e.__cause__ = None
-                exc_info = "".join(traceback.format_exception(type(e), e, tb))
-        if exc_info is None:
-            exc_info = traceback.format_exc()
+                return "".join(traceback.format_exception(type(e), e, tb))
+        return traceback.format_exc()
+
+    try:
+        exc_info = format_exception()
         write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(exc_info.encode("utf-8"), outfile)
     except IOError:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 4907e7ee6276e..7b560002edeba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3475,6 +3475,15 @@ object SQLConf {
       .checkValues(Set("legacy", "row", "dict"))
       .createWithDefaultString("legacy")

+  val PYSPARK_HIDE_TRACEBACK =
+    buildConf("spark.sql.execution.pyspark.udf.hideTraceback.enabled")
+      .doc(
+        "When true, only show the message of the exception from Python UDFs, " +
+          "hiding the stack trace. If this is enabled, simplifiedTraceback has no effect.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PYSPARK_SIMPLIFIED_TRACEBACK =
     buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled")
       .doc(
@@ -6286,6 +6295,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

   def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)

+  def pysparkHideTraceback: Boolean = getConf(PYSPARK_HIDE_TRACEBACK)
+
   def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)

   def pandasGroupedMapAssignColumnsByName: Boolean =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index d704638b85e8a..f598430df0eea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -84,6 +84,7 @@ class ApplyInPandasWithStatePythonRunner(
   override protected lazy val timeZoneId: String = _timeZoneId
   override val errorOnDuplicatedFieldNames: Boolean = true

+  override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback

   override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
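One subtlety in the new default: `bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))` treats any non-empty value as true (including "0"), and only an unset or empty variable as false, which is why `test_env_full` patches the variables to "" rather than deleting them. A quick illustration of that truthiness rule:

```python
import os

# Mirrors the default resolution in handle_worker_exception above.
def resolve_hide_traceback() -> bool:
    return bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))

for value in (None, "", "1", "0"):
    if value is None:
        os.environ.pop("SPARK_HIDE_TRACEBACK", None)
    else:
        os.environ["SPARK_HIDE_TRACEBACK"] = value
    print(repr(value), "->", resolve_hide_traceback())
# None -> False, '' -> False, '1' -> True, '0' -> True
```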
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 579b496046852..1bddd81fbfe20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -50,6 +50,7 @@ abstract class BaseArrowPythonRunner(

   override val errorOnDuplicatedFieldNames: Boolean = true

+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

   // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 99a9e706c6620..f42c4b6106cbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -59,6 +59,7 @@ class ArrowPythonUDTFRunner(

   override val errorOnDuplicatedFieldNames: Boolean = true

+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

   override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index c5e86d010938d..59e8970b9c9b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -60,6 +60,7 @@ class CoGroupedArrowPythonRunner(

   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

   protected def newWriter(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
index ed7ff6a753487..4655f96425fd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala
@@ -100,6 +100,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)

     override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

+    override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
     override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
   }
 }
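With the SQL-side runners above wired to the new conf, enabling it per session looks roughly as follows; a hedged sketch (the exact text of the surfaced error depends on the JVM-side error wrapping):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.pyspark.udf.hideTraceback.enabled", True)

@udf("int")
def fail(x):
    raise ValueError("boom")

try:
    spark.range(1).select(fail("id")).collect()
except Exception as e:
    # With the flag on, the message should contain "ValueError: boom"
    # without the Python worker's "Traceback (most recent call last):".
    print(e)
```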
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 8cc2e1de7a4c3..1974c393c472c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -52,6 +52,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
     val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
     val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
+    val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
     val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
     val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory

@@ -68,6 +69,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
     if (reuseWorker) {
       envVars.put("SPARK_REUSE_WORKER", "1")
     }
+    if (hideTraceback) {
+      envVars.put("SPARK_HIDE_TRACEBACK", "1")
+    }
     if (simplifiedTraceback) {
       envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 167e1fd8b0f01..a322dfa10df5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -42,6 +42,7 @@ abstract class BasePythonUDFRunner(
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
       funcs.head._1.funcs.head.pythonExec)

+  override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

   override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
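The config doc's claim that simplifiedTraceback has no effect when hiding is enabled follows from the ordering inside `format_exception`: the `hide_traceback` branch returns first. A distilled, hypothetical restatement of that precedence (not the actual worker code):

```python
import traceback

def choose_format(e: BaseException, hide_traceback: bool, simplified: bool) -> str:
    # hide_traceback is checked before the simplified-traceback branch,
    # so when both flags are set the simplified path never runs.
    if hide_traceback:
        return "".join(traceback.format_exception_only(type(e), e))
    # (simplified branch elided here; falls through to the full traceback)
    return "".join(traceback.format_exception(type(e), e, e.__traceback__))

try:
    raise ValueError("boom")
except ValueError as err:
    print(choose_format(err, hide_traceback=True, simplified=True))
    # -> "ValueError: boom" only
```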