[SPARK-50858][PYTHON] Add configuration to hide Python UDF stack trace #49535

Open · wants to merge 7 commits into base: master
Changes from 6 commits
@@ -122,6 +122,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
protected val hideTraceback: Boolean = false
Contributor: Can we use conf.get(PYSPARK_HIDE_TRACEBACK) here so that we don't need to override every subclass?

Author (@wengh, Jan 21, 2025): The config is defined in org.apache.spark.sql.internal.SQLConf, which seems to be inaccessible from here. For reference, PYSPARK_SIMPLIFIED_TRACEBACK is also defined in SQLConf, so BasePythonRunner subclasses have to override it.

Is there an advantage to putting it in SQLConf rather than, e.g., org.apache.spark.internal.config.Python?

Member: A conf in SQLConf is a session-scoped conf that can also be set at runtime, whereas any conf in the core module or in StaticSQLConf is cluster-wide and can't be changed while the cluster is running.
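As an illustration of that distinction, here is a minimal sketch using the standard PySpark session API; the conf key is the one added by this PR, everything else is assumed boilerplate:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# SQLConf entries are session-scoped and can be toggled at runtime:
spark.conf.set("spark.sql.execution.pyspark.udf.hideTraceback.enabled", "true")

# Confs defined in the core module (or StaticSQLConf) are cluster-wide and must
# be set before launch, e.g. via spark-submit --conf; they cannot be changed
# while the application is running.
```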

protected val simplifiedTraceback: Boolean = false

// All the Python functions should have the same exec, version and envvars.
@@ -199,6 +200,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
    if (reuseWorker) {
      envVars.put("SPARK_REUSE_WORKER", "1")
    }
    if (hideTraceback) {
      envVars.put("SPARK_HIDE_TRACEBACK", "1")
    }
    if (simplifiedTraceback) {
      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
    }
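For context, the Python worker reads these flags back from its environment; the lookups below mirror the ones added to pyspark/util.py later in this PR (a sketch, not part of the diff):

```python
import os

# Set by the JVM (BasePythonRunner / PythonPlannerRunner) before the worker starts:
hide_traceback = bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))
simplified_traceback = bool(os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False))
```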
39 changes: 39 additions & 0 deletions python/pyspark/tests/test_util.py
@@ -16,6 +16,7 @@
#
import os
import unittest
from unittest.mock import patch

from py4j.protocol import Py4JJavaError

@@ -125,6 +126,44 @@ def test_parse_memory(self):
_parse_memory("2gs")


class HandleWorkerExceptionTests(unittest.TestCase):
    exception_bytes = b"ValueError: test_message"
    traceback_bytes = b"Traceback (most recent call last):"

    def run_handle_worker_exception(self, hide_traceback=None):
        import io
        from pyspark.util import handle_worker_exception

        try:
            raise ValueError("test_message")
        except Exception as e:
            with io.BytesIO() as stream:
                handle_worker_exception(e, stream, hide_traceback)
                return stream.getvalue()

    @patch.dict(os.environ, {"SPARK_SIMPLIFIED_TRACEBACK": "", "SPARK_HIDE_TRACEBACK": ""})
    def test_env_full(self):
        result = self.run_handle_worker_exception()
        self.assertIn(self.exception_bytes, result)
        self.assertIn(self.traceback_bytes, result)

    @patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
    def test_env_hide_traceback(self):
        result = self.run_handle_worker_exception()
        self.assertIn(self.exception_bytes, result)
        self.assertNotIn(self.traceback_bytes, result)

    def test_full(self):
        result = self.run_handle_worker_exception(False)
        self.assertIn(self.exception_bytes, result)
        self.assertIn(self.traceback_bytes, result)

    def test_hide_traceback(self):
        result = self.run_handle_worker_exception(True)
        self.assertIn(self.exception_bytes, result)
        self.assertNotIn(self.traceback_bytes, result)


if __name__ == "__main__":
from pyspark.tests.test_util import * # noqa: F401

30 changes: 24 additions & 6 deletions python/pyspark/util.py
@@ -462,22 +462,40 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
return f # type: ignore[return-value]


def handle_worker_exception(e: BaseException, outfile: IO) -> None:
def handle_worker_exception(
e: BaseException, outfile: IO, hide_traceback: Optional[bool] = None
Member: Do we need to pass hide_traceback?

Author (@wengh, Jan 23, 2025): It's optional and uses the value of SPARK_HIDE_TRACEBACK by default (see docstring).

Member: I don't see any place that passes this parameter except the tests. I thought test_env_full and test_env_hide_traceback would be enough without taking it here. If we do take it, we need more tests, e.g. that passing the parameter explicitly overrides the env var.

Author: We can add more tests for the override behaviour.
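A sketch of what such an override test could look like, reusing the run_handle_worker_exception helper and assertions from the HandleWorkerExceptionTests class added above; the test name is hypothetical:

```python
@patch.dict(os.environ, {"SPARK_HIDE_TRACEBACK": "1"})
def test_param_overrides_env(self):
    # An explicit hide_traceback=False should win over the env var,
    # so the full traceback is expected in the output.
    result = self.run_handle_worker_exception(False)
    self.assertIn(self.exception_bytes, result)
    self.assertIn(self.traceback_bytes, result)
```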

) -> None:
"""
Handles exception for Python worker which writes SpecialLengths.PYTHON_EXCEPTION_THROWN (-2)
and exception traceback info to outfile. JVM could then read from the outfile and perform
exception handling there.

Parameters
----------
e : BaseException
Exception handled
outfile : IO
IO object to write the exception info
hide_traceback : bool, optional
Whether to hide the traceback in the output.
By default, hides the traceback if environment variable SPARK_HIDE_TRACEBACK is set.
Member (on lines +473 to +481): Thanks for adding the parameters here!
"""
-    try:
-        exc_info = None
+    if hide_traceback is None:
+        hide_traceback = bool(os.environ.get("SPARK_HIDE_TRACEBACK", False))
+
+    def format_exception() -> str:
+        if hide_traceback:
+            return "".join(traceback.format_exception_only(type(e), e))
         if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
             tb = try_simplify_traceback(sys.exc_info()[-1])  # type: ignore[arg-type]
             if tb is not None:
                 e.__cause__ = None
-                exc_info = "".join(traceback.format_exception(type(e), e, tb))
-        if exc_info is None:
-            exc_info = traceback.format_exc()
+                return "".join(traceback.format_exception(type(e), e, tb))
+        return traceback.format_exc()
+
+    try:
+        exc_info = format_exception()
         write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
         write_with_length(exc_info.encode("utf-8"), outfile)
     except IOError:
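For intuition about what the two branches of format_exception emit, here is a small self-contained illustration using only the standard traceback module (not part of this PR):

```python
import traceback

try:
    raise ValueError("test_message")
except ValueError as e:
    # What the worker writes when SPARK_HIDE_TRACEBACK is set:
    # just "ValueError: test_message\n", with no stack frames.
    print("".join(traceback.format_exception_only(type(e), e)))

    # What it writes otherwise: the full (or simplified) traceback,
    # starting with "Traceback (most recent call last):".
    print(traceback.format_exc())
```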
@@ -3475,6 +3475,15 @@ object SQLConf {
.checkValues(Set("legacy", "row", "dict"))
.createWithDefaultString("legacy")

val PYSPARK_HIDE_TRACEBACK =
buildConf("spark.sql.execution.pyspark.udf.hideTraceback.enabled")
.doc(
"When true, only show the message of the exception from Python UDFs, " +
"hiding the stack trace.")
Member: We may want to describe a bit more about the relationship between this and simplifiedTraceback, here or in the doc of simplifiedTraceback. It seems that if this is enabled, simplifiedTraceback will be ignored?

Author: Yeah, simplifiedTraceback is not applicable if hideTraceback is set, unless the caller passes hide_traceback=False to override the config. I'll update the description to reflect this.
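A hedged illustration of that precedence, using the two conf keys defined in this file; the described outcome is an assumption following from the format_exception logic in pyspark/util.py above:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# With both enabled, hideTraceback takes precedence: a failing Python UDF is
# reported with only the exception message (e.g. "ValueError: ...") and no
# stack trace, simplified or otherwise.
spark.conf.set("spark.sql.execution.pyspark.udf.hideTraceback.enabled", "true")
spark.conf.set("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled", "true")
```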

.version("4.0.0")
.booleanConf
.createWithDefault(false)
Member: Another way is to create this conf as an int and show the max depth of the stack trace, but I don't feel strongly.

Author (@wengh, Jan 22, 2025): Is there a use case where we only want to show the last k frames of the stack? I'm under the impression that we want to show the full stack trace for most exceptions, and completely hide the stack trace for specific library exceptions where the message is sufficient to identify the reason.
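For reference, the int-valued alternative would roughly amount to keeping only the last k stack entries when formatting; a hypothetical sketch using only the standard traceback module (the helper name is invented for illustration):

```python
import traceback

def format_last_frames(e: BaseException, max_frames: int) -> str:
    """Format e, keeping only the last max_frames stack entries."""
    te = traceback.TracebackException.from_exception(e)
    te.stack = traceback.StackSummary.from_list(list(te.stack)[-max_frames:])
    return "".join(te.format())

try:
    raise ValueError("test_message")
except ValueError as err:
    print(format_last_frames(err, max_frames=2))
```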


val PYSPARK_SIMPLIFIED_TRACEBACK =
buildConf("spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled")
.doc(
@@ -6286,6 +6295,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def pandasStructHandlingMode: String = getConf(PANDAS_STRUCT_HANDLING_MODE)

def pysparkHideTraceback: Boolean = getConf(PYSPARK_HIDE_TRACEBACK)

def pysparkSimplifiedTraceback: Boolean = getConf(PYSPARK_SIMPLIFIED_TRACEBACK)

def pandasGroupedMapAssignColumnsByName: Boolean =
@@ -84,6 +84,7 @@ class ApplyInPandasWithStatePythonRunner(
override protected lazy val timeZoneId: String = _timeZoneId
override val errorOnDuplicatedFieldNames: Boolean = true

override val hideTraceback: Boolean = sqlConf.pysparkHideTraceback
override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback

override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
@@ -50,6 +50,7 @@ abstract class BaseArrowPythonRunner(

override val errorOnDuplicatedFieldNames: Boolean = true

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

// Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
@@ -59,6 +59,7 @@ class ArrowPythonUDTFRunner(

override val errorOnDuplicatedFieldNames: Boolean = true

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
@@ -60,6 +60,7 @@ class CoGroupedArrowPythonRunner(

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

protected def newWriter(
@@ -100,6 +100,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
}
}
@@ -52,6 +52,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) {
val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled
val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory

@@ -68,6 +69,9 @@
    if (reuseWorker) {
      envVars.put("SPARK_REUSE_WORKER", "1")
    }
    if (hideTraceback) {
      envVars.put("SPARK_HIDE_TRACEBACK", "1")
    }
    if (simplifiedTraceback) {
      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
    }
@@ -42,6 +42,7 @@ abstract class BasePythonUDFRunner(
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head._1.funcs.head.pythonExec)

override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback

override val faultHandlerEnabled: Boolean = SQLConf.get.pythonUDFWorkerFaulthandlerEnabled