diff --git a/dev/requirements.txt b/dev/requirements.txt
index 88883a963950e..c9d07fae000be 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -62,6 +62,7 @@ grpc-stubs==1.24.11
 
 # Debug for Spark and Spark Connect
 graphviz==0.20.3
+flameprof==0.4
 
 # TorchDistributor dependencies
 torch
diff --git a/docs/img/pyspark-udf-profile.png b/docs/img/pyspark-udf-profile.png
new file mode 100644
index 0000000000000..5b8ab3f3bd8b3
Binary files /dev/null and b/docs/img/pyspark-udf-profile.png differ
diff --git a/python/docs/source/development/debugging.rst b/python/docs/source/development/debugging.rst
index b0b2c4837ded4..9510fe0abde1e 100644
--- a/python/docs/source/development/debugging.rst
+++ b/python/docs/source/development/debugging.rst
@@ -263,6 +263,16 @@ The UDF IDs can be seen in the query plan, for example, ``add1(...)#2L`` in ``Ar
     +- ArrowEvalPython [add1(id#0L)#2L], [pythonUDF0#11L], 200
        +- *(1) Range (0, 10, step=1, splits=16)
 
+We can render the result with an arbitrary renderer function as shown below.
+
+.. code-block:: python
+
+    def do_render(codemap):
+        # Your custom rendering logic
+        ...
+
+    spark.profile.render(id=2, type="memory", renderer=do_render)
+
 We can clear the result memory profile as shown below.
 
 .. code-block:: python
@@ -358,6 +368,25 @@ The UDF IDs can be seen in the query plan, for example, ``add1(...)#2L`` in ``Ar
     +- ArrowEvalPython [add1(id#0L)#2L], [pythonUDF0#11L], 200
        +- *(1) Range (0, 10, step=1, splits=16)
 
+We can render the result with a preregistered renderer as shown below.
+
+.. code-block:: python
+
+    >>> spark.profile.render(id=2, type="perf")  # renderer="flameprof" by default
+
+.. image:: ../../../../docs/img/pyspark-udf-profile.png
+    :alt: PySpark UDF profile
+
+Or with an arbitrary renderer function as shown below.
+
+.. code-block:: python
+
+    >>> def do_render(stats):
+    ...     # Your custom rendering logic
+    ...     ...
+    ...
+    >>> spark.profile.render(id=2, type="perf", renderer=do_render)
+
 We can clear the result performance profile as shown below.
 
 .. code-block:: python
diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst
index 6cc68cd46b117..549656bea103e 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -225,6 +225,10 @@ Package   Supported version Note
 `pyarrow` >=10.0.0          Required for Spark SQL
 ========= ================= ======================
 
+Additional libraries that enhance functionality but are not included in the installation packages:
+
+- **flameprof**: Provides the default renderer for UDF performance profiling.
+
 Pandas API on Spark
 ^^^^^^^^^^^^^^^^^^^
 
diff --git a/python/mypy.ini b/python/mypy.ini
index bc6e239555073..c7cf8df114147 100644
--- a/python/mypy.ini
+++ b/python/mypy.ini
@@ -174,6 +174,9 @@ ignore_missing_imports = True
 [mypy-memory_profiler.*]
 ignore_missing_imports = True
 
+[mypy-flameprof.*]
+ignore_missing_imports = True
+
 ; Ignore errors for proto generated code
 [mypy-pyspark.sql.connect.proto.*, pyspark.sql.connect.proto]
 ignore_errors = True
diff --git a/python/pyspark/sql/profiler.py b/python/pyspark/sql/profiler.py
index 711e39de4723b..9eaf1f264a5b2 100644
--- a/python/pyspark/sql/profiler.py
+++ b/python/pyspark/sql/profiler.py
@@ -15,10 +15,11 @@
 # limitations under the License.
 #
 from abc import ABC, abstractmethod
+from io import StringIO
 import os
 import pstats
 from threading import RLock
-from typing import Dict, Optional, TYPE_CHECKING
+from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union, TYPE_CHECKING, overload
 
 from pyspark.accumulators import (
     Accumulator,
@@ -360,6 +361,84 @@ def dump(self, path: str, id: Optional[int] = None, *, type: Optional[str] = Non
             },
         )
 
+    @overload
+    def render(self, id: int, *, type: Optional[str] = None, renderer: Optional[str] = None) -> Any:
+        ...
+
+    @overload
+    def render(
+        self, id: int, *, type: Optional[Literal["perf"]], renderer: Callable[[pstats.Stats], Any]
+    ) -> Any:
+        ...
+
+    @overload
+    def render(
+        self, id: int, *, type: Literal["memory"], renderer: Callable[[CodeMapDict], Any]
+    ) -> Any:
+        ...
+
+    def render(
+        self,
+        id: int,
+        *,
+        type: Optional[str] = None,
+        renderer: Optional[
+            Union[str, Callable[[pstats.Stats], Any], Callable[[CodeMapDict], Any]]
+        ] = None,
+    ) -> Any:
+        """
+        Render the profile results.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        id : int
+            The UDF ID whose profiling results should be rendered.
+        type : str, optional
+            The profiler type to render results for, which can be either "perf" or "memory".
+        renderer : str or callable, optional
+            The renderer to use. If not specified, the default renderer is "flameprof"
+            for the "perf" profiler; it returns an :class:`IPython.display.HTML` object
+            to draw the figure in an IPython environment, and the SVG source string otherwise.
+            For the "memory" profiler, no default renderer is provided.
+
+            If a callable is provided, it should take a `pstats.Stats` object for the "perf"
+            profiler, or a `CodeMapDict` for the "memory" profiler, and return the rendered
+            result.
+        """
+        result: Optional[Union[pstats.Stats, CodeMapDict]]
+        if type is None:
+            type = "perf"
+        if type == "perf":
+            result = self.profiler_collector._perf_profile_results.get(id)
+        elif type == "memory":
+            result = self.profiler_collector._memory_profile_results.get(id)
+        else:
+            raise PySparkValueError(
+                error_class="VALUE_NOT_ALLOWED",
+                message_parameters={
+                    "arg_name": "type",
+                    "allowed_values": str(["perf", "memory"]),
+                },
+            )
+
+        render: Optional[Union[Callable[[pstats.Stats], Any], Callable[[CodeMapDict], Any]]] = None
+        if renderer is None or isinstance(renderer, str):
+            render = _renderers.get((type, renderer))
+        elif callable(renderer):
+            render = renderer
+        if render is None:
+            raise PySparkValueError(
+                error_class="VALUE_NOT_ALLOWED",
+                message_parameters={
+                    "arg_name": "(type, renderer)",
+                    "allowed_values": str(list(_renderers.keys())),
+                },
+            )
+
+        if result is not None:
+            return render(result)  # type: ignore[arg-type]
+
     def clear(self, id: Optional[int] = None, *, type: Optional[str] = None) -> None:
         """
         Clear the profile results.
@@ -388,3 +467,39 @@ def clear(self, id: Optional[int] = None, *, type: Optional[str] = None) -> None
                     "allowed_values": str(["perf", "memory"]),
                 },
             )
+
+
+def _render_flameprof(stats: pstats.Stats) -> Any:
+    try:
+        from flameprof import render
+    except ImportError:
+        raise PySparkValueError(
+            error_class="PACKAGE_NOT_INSTALLED",
+            message_parameters={"package_name": "flameprof", "minimum_version": "0.4"},
+        )
+
+    buf = StringIO()
+    render(stats.stats, buf)  # type: ignore[attr-defined]
+    svg = buf.getvalue()
+
+    try:
+        import IPython
+
+        ipython = IPython.get_ipython()
+    except ImportError:
+        ipython = None
+
+    if ipython:
+        from IPython.display import HTML
+
+        return HTML(svg)
+    else:
+        return svg
+
+
+_renderers: Dict[
+    Tuple[str, Optional[str]], Union[Callable[[pstats.Stats], Any], Callable[[CodeMapDict], Any]]
+] = {
+    ("perf", None): _render_flameprof,
+    ("perf", "flameprof"): _render_flameprof,
+}
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
index a1789a50896db..274364b181441 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf_profiler.py
@@ -18,7 +18,11 @@
 import os
 import unittest
 
-from pyspark.sql.tests.test_udf_profiler import UDFProfiler2TestsMixin, _do_computation
+from pyspark.sql.tests.test_udf_profiler import (
+    UDFProfiler2TestsMixin,
+    _do_computation,
+    has_flameprof,
+)
 from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
@@ -61,6 +65,9 @@ def action(df):
                 io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
             )
 
+            if has_flameprof:
+                self.assertIn("svg", self.spark.profile.render(id))
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_udf_profiler import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index a66503bc02138..bb8c0765153c9 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -26,6 +26,7 @@
 from typing import Iterator, cast
 
 from pyspark import SparkConf
+from pyspark.errors import PySparkValueError
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import col, pandas_udf, udf
 from pyspark.sql.window import Window
@@ -38,6 +39,13 @@
     pyarrow_requirement_message,
 )
 
+try:
+    import flameprof  # noqa: F401
+
+    has_flameprof = True
+except ImportError:
+    has_flameprof = False
+
 
 def _do_computation(spark, *, action=lambda df: df.collect(), use_arrow=False):
     @udf("long", useArrow=use_arrow)
@@ -200,6 +208,9 @@ def test_perf_profiler_udf(self):
             )
             self.assertTrue(f"udf_{id}_perf.pstats" in os.listdir(d))
+            if has_flameprof:
+                self.assertIn("svg", self.spark.profile.render(id))
+
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
@@ -219,6 +230,9 @@ def test_perf_profiler_udf_with_arrow(self):
                 io.getvalue(), f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
             )
 
+            if has_flameprof:
+                self.assertIn("svg", self.spark.profile.render(id))
+
     def test_perf_profiler_udf_multiple_actions(self):
         def action(df):
             df.collect()
@@ -238,6 +252,9 @@ def action(df):
                 io.getvalue(), f"20.*{os.path.basename(inspect.getfile(_do_computation))}"
             )
 
+            if has_flameprof:
+                self.assertIn("svg", self.spark.profile.render(id))
+
     def test_perf_profiler_udf_registered(self):
         @udf("long")
         def add1(x):
@@ -259,6 +276,9 @@ def add1(x):
f"10.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), @@ -289,6 +309,9 @@ def add2(x): io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), @@ -322,6 +345,9 @@ def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]: io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), @@ -369,6 +395,9 @@ def mean_udf(v: pd.Series) -> float: io.getvalue(), f"5.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), @@ -398,6 +427,9 @@ def min_udf(v: pd.Series) -> float: io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), @@ -426,6 +458,9 @@ def normalize(pdf): io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), @@ -461,6 +496,9 @@ def asof_join(left, right): io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), @@ -492,6 +530,9 @@ def normalize(table): io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + @unittest.skipIf( not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), @@ -521,6 +562,69 @@ def summarize(left, right): io.getvalue(), f"2.*{os.path.basename(inspect.getfile(_do_computation))}" ) + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + + def test_perf_profiler_render(self): + with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}): + _do_computation(self.spark) + self.assertEqual(3, len(self.profile_results), str(list(self.profile_results))) + + id = list(self.profile_results.keys())[0] + + if has_flameprof: + self.assertIn("svg", self.spark.profile.render(id)) + self.assertIn("svg", self.spark.profile.render(id, type="perf")) + self.assertIn("svg", self.spark.profile.render(id, renderer="flameprof")) + + with self.assertRaises(PySparkValueError) as pe: + self.spark.profile.render(id, type="unknown") + + self.check_error( + exception=pe.exception, + error_class="VALUE_NOT_ALLOWED", + message_parameters={ + "arg_name": "type", + "allowed_values": "['perf', 'memory']", + }, + ) + + with 
+        with self.assertRaises(PySparkValueError) as pe:
+            self.spark.profile.render(id, type="memory")
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="VALUE_NOT_ALLOWED",
+            message_parameters={
+                "arg_name": "(type, renderer)",
+                "allowed_values": "[('perf', None), ('perf', 'flameprof')]",
+            },
+        )
+
+        with self.assertRaises(PySparkValueError) as pe:
+            self.spark.profile.render(id, renderer="unknown")
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="VALUE_NOT_ALLOWED",
+            message_parameters={
+                "arg_name": "(type, renderer)",
+                "allowed_values": "[('perf', None), ('perf', 'flameprof')]",
+            },
+        )
+
+        with self.trap_stdout() as io:
+            self.spark.profile.show(id, type="perf")
+        show_value = io.getvalue()
+
+        with self.trap_stdout() as io:
+            self.spark.profile.render(
+                id, renderer=lambda s: s.sort_stats("time", "cumulative").print_stats()
+            )
+        render_value = io.getvalue()
+
+        self.assertIn(render_value, show_value)
+
     def test_perf_profiler_clear(self):
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
             _do_computation(self.spark)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e9c259e68a27a..62524889b66a2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -718,8 +718,8 @@ def wrap_perf_profiler(f, result_id):
     )
 
     def profiling_func(*args, **kwargs):
-        pr = cProfile.Profile()
-        ret = pr.runcall(f, *args, **kwargs)
+        with cProfile.Profile() as pr:
+            ret = f(*args, **kwargs)
         st = pstats.Stats(pr)
         st.stream = None  # make it picklable
         st.strip_dirs()