[SPARK-50858][PYTHON] Add configuration to hide Python UDF stack trace
wengh committed Jan 17, 2025
1 parent 11d758f commit 174fdea
Showing 11 changed files with 65 additions and 5 deletions.
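
For context (this snippet is not part of the commit): a sketch of the new flag in use from PySpark, assuming an active `spark` session. With the conf enabled, a failing Python UDF surfaces only the exception message instead of the full worker traceback.

from pyspark.sql.functions import udf

spark.conf.set("spark.sql.execution.pyspark.udf.hideTraceback.enabled", True)

@udf("int")
def boom(x):
    # Hypothetical UDF used only to trigger an error.
    raise ValueError("boom")

try:
    spark.range(1).select(boom("id")).collect()
except Exception as e:
    # With the flag on, the Python side of the error reads just
    # "ValueError: boom" -- no worker stack trace, no exception chain.
    print(e)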
@@ -122,6 +122,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false

// All the Python functions should have the same exec, version and envvars.
@@ -199,6 +200,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (reuseWorker) {
  envVars.put("SPARK_REUSE_WORKER", "1")
}
if (hideTraceback) {
  envVars.put("SPARK_HIDE_TRACEBACK", "1")
}
if (simplifiedTraceback) {
  envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
}
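
The JVM runner only exports the flag; the Python worker reads it back from its environment. A minimal sketch of the worker-side handoff (the real check lives in `handle_worker_exception` in python/pyspark/util.py, shown further down):

import os

# Any non-empty value enables the behavior; the JVM launcher writes "1".
hide_traceback = bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))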
31 changes: 31 additions & 0 deletions python/pyspark/tests/test_util.py
@@ -16,6 +16,7 @@
#
import os
import unittest
from unittest.mock import patch

from py4j.protocol import Py4JJavaError

@@ -125,6 +126,36 @@ def test_parse_memory(self):
            _parse_memory("2gs")


class HandleWorkerExceptionTests(unittest.TestCase):
    exception_bytes = b"ValueError: test_message"
    traceback_bytes = b"Traceback (most recent call last):"

    def run_handle_worker_exception(self):
        import io
        from pyspark.util import handle_worker_exception

        try:
            raise ValueError("test_message")
        except Exception as e:
            with io.BytesIO() as stream:
                handle_worker_exception(e, stream)
                return stream.getvalue()

    @patch.dict(
        os.environ, {"SPARK_SIMPLIFIED_TRACEBACK": "", "SPARK_HIDE_TRACEBACK": ""}
    )
    def test_full(self):
        result = self.run_handle_worker_exception()
        self.assertIn(self.exception_bytes, result)
        self.assertIn(self.traceback_bytes, result)

    @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
    def test_hide_traceback(self):
        result = self.run_handle_worker_exception()
        self.assertIn(self.exception_bytes, result)
        self.assertNotIn(self.traceback_bytes, result)


if __name__ == "__main__":
    from pyspark.tests.test_util import *  # noqa: F401

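A note on the patching strategy above: the tests disable a flag by setting it to the empty string rather than deleting the key, because the worker checks `os.environ.get(name, False)` and an empty string is falsy in Python:

import os

os.environ["SPARK_HIDE_TRACEBACK"] = ""
assert not os.environ.get("SPARK_HIDE_TRACEBACK", False)  # "" is falsy: traceback kept

os.environ["SPARK_HIDE_TRACEBACK"] = "1"
assert os.environ.get("SPARK_HIDE_TRACEBACK", False)  # "1" is truthy: traceback hidden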
13 changes: 8 additions & 5 deletions python/pyspark/util.py
@@ -468,16 +468,19 @@ def handle_worker_exception(e: BaseException, outfile: IO) -> None:
    and exception traceback info to outfile. JVM could then read from the outfile and perform
    exception handling there.
    """

    def format_exception():
        if os.environ.get("SPARK_HIDE_TRACEBACK", False):
            return "".join(traceback.format_exception_only(type(e), e))
        if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
            tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
            if tb is not None:
                e.__cause__ = None
                return "".join(traceback.format_exception(type(e), e, tb))
        return traceback.format_exc()

    try:
        exc_info = format_exception()
        write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
        write_with_length(exc_info.encode("utf-8"), outfile)
    except IOError:
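
To see what the branches of `format_exception` produce, here is a standalone illustration using only the standard library:

import traceback

try:
    raise ValueError("test_message")
except ValueError as e:
    # SPARK_HIDE_TRACEBACK branch: only the final "ValueError: test_message"
    # line, with no frames.
    print("".join(traceback.format_exception_only(type(e), e)))

    # Default branch: the full traceback, as test_full expects.
    print(traceback.format_exc())
    # -> Traceback (most recent call last): ... ValueError: test_message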
@@ -3459,6 +3459,16 @@ object SQLConf {
    .checkValues(Set("legacy", "row", "dict"))
    .createWithDefaultString("legacy")

val PYSPARK_HIDE_TRACEBACK =
  buildConf("spark.sql.execution.pyspark.udf.hideTraceback.enabled")
    .doc(
      "When true, only show the message of the exception from Python UDFs, " +
      "hiding stack trace and exception type."
    )
    .version("4.0.0")
    .booleanConf
    .createWithDefault(false)

val PYSPARK_SIMPLIFIED_TRACEBACK =
  buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled")
    .doc(
@@ -6254,6 +6264,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)

def pysparkHideTraceback: Boolean = getConf(PYSPARK_HIDE_TRACEBACK)

def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)

def pandasGroupedMapAssignColumnsByName: Boolean =
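One interaction worth noting: `format_exception` checks SPARK_HIDE_TRACEBACK before SPARK_SIMPLIFIED_TRACEBACK, so if both confs are enabled, hiding wins. A sketch, assuming an active `spark` session:

spark.conf.set("spark.sql.execution.pyspark.udf.hideTraceback.enabled", True)
spark.conf.set("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled", True)
# UDF errors now surface as a single "ValueError: ..." line; the simplified
# traceback setting has no further effect while hiding is enabled.
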
@@ -84,6 +84,7 @@ class ApplyInPandasWithStatePythonRunner(
override protected lazy val timeZoneId: String = _timeZoneId
override val errorOnDuplicatedFieldNames: Boolean = true

override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback

override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
@@ -50,6 +50,7 @@ abstract class BaseArrowPythonRunner(

override val errorOnDuplicatedFieldNames: Boolean = true

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

// Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
@@ -59,6 +59,7 @@ class ArrowPythonUDTFRunner(

override val errorOnDuplicatedFieldNames: Boolean = true

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
@@ -60,6 +60,7 @@ class CoGroupedArrowPythonRunner(

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

protected def newWriter(
@@ -100,6 +100,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
}
}
@@ -50,6 +50,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory

@@ -66,6 +67,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
if (reuseWorker) {
  envVars.put("SPARK_REUSE_WORKER", "1")
}
if (hideTraceback) {
  envVars.put("SPARK_HIDE_TRACEBACK", "1")
}
if (simplifiedTraceback) {
  envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
}
@@ -42,6 +42,7 @@ abstract class BasePythonUDFRunner(
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
  funcs.head._1.funcs.head.pythonExec)

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
