-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import pstats | ||
from typing import Dict | ||
|
||
from pyspark.profiler import CodeMapDict | ||
from pyspark.sql.profiler import ProfilerCollector, ProfileResultsParam | ||
|
||
|
||
class ConnectProfilerCollector(ProfilerCollector):
    """ProfilerCollector for Spark Connect that keeps results in local memory.

    Results are stored in ``self._profile_results`` as a mapping of
    UDF result id -> (perf stats or None, memory code map or None).
    """

    def __init__(self) -> None:
        super().__init__()
        # Start from the accumulator-param zero value (an empty mapping).
        self._profile_results = ProfileResultsParam.zero({})

    @property
    def _perf_profile_results(self) -> Dict[int, pstats.Stats]:
        # Keep only entries that actually carry performance stats.
        return {
            result_id: perf
            for result_id, (perf, _) in self._profile_results.items()
            if perf is not None
        }

    @property
    def _memory_profile_results(self) -> Dict[int, CodeMapDict]:
        # Keep only entries that actually carry memory code maps.
        return {
            result_id: mem
            for result_id, (_, mem) in self._profile_results.items()
            if mem is not None
        }

    def _clear(self) -> None:
        # Required: ProfilerCollector declares _clear as @abstractmethod, so
        # without this override the class cannot be instantiated. Reset to the
        # zero value, mirroring AccumulatorProfilerCollector._clear.
        self._profile_results = ProfileResultsParam.zero({})
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from abc import ABC, abstractmethod | ||
import pstats | ||
from threading import RLock | ||
from typing import Dict, Optional, Tuple, cast | ||
|
||
from pyspark.accumulators import Accumulator, AccumulatorParam, SpecialAccumulatorIds | ||
from pyspark.profiler import CodeMapDict, MemoryProfiler, MemUsageParam, PStatsParam | ||
|
||
|
||
ProfileResults = Dict[int, Tuple[Optional[pstats.Stats], Optional[CodeMapDict]]] | ||
|
||
|
||
class ProfileResultsParam(AccumulatorParam[ProfileResults]):
    """AccumulatorParam that merges per-UDF (perf, memory) profile results."""

    @staticmethod
    def zero(value: ProfileResults) -> ProfileResults:
        # Normalize an absent/falsy value to an empty mapping.
        return value if value else {}

    @staticmethod
    def addInPlace(value1: ProfileResults, value2: ProfileResults) -> ProfileResults:
        # Short-circuit when either side contributes nothing.
        if not value1:
            return value2
        if not value2:
            return value1

        merged = value1.copy()
        for result_id, (perf, mem) in value2.items():
            # Fall back to the component zeros for ids seen only on one side.
            base_perf, base_mem = merged.get(
                result_id, (PStatsParam.zero(None), MemUsageParam.zero(None))
            )
            merged[result_id] = (
                PStatsParam.addInPlace(base_perf, perf),
                MemUsageParam.addInPlace(base_mem, mem),
            )
        return merged
|
||
|
||
class ProfilerCollector(ABC):
    """Base class that collects and renders UDF perf/memory profiles."""

    def __init__(self) -> None:
        # Reentrant lock: the show_* methods call themselves while holding it.
        self._lock = RLock()

    def show_perf_profiles(self, id: Optional[int] = None) -> None:
        """Print cProfile stats for one UDF id, or for every id when None."""
        with self._lock:
            if id is None:
                # Render every collected result in ascending id order.
                for result_id in sorted(self._perf_profile_results):
                    self.show_perf_profiles(result_id)
                return
            stats = self._perf_profile_results.get(id)
            if stats is None:
                return
            banner = "=" * 60
            print(banner)
            print(f"Profile of UDF<id={id}>")
            print(banner)
            stats.sort_stats("time", "cumulative").print_stats()

    @property
    @abstractmethod
    def _perf_profile_results(self) -> Dict[int, pstats.Stats]:
        """Mapping of UDF result id to its cProfile stats."""
        ...

    def show_memory_profiles(self, id: Optional[int] = None) -> None:
        """Print memory-profiler results for one UDF id, or all when None."""
        with self._lock:
            if id is None:
                for result_id in sorted(self._memory_profile_results):
                    self.show_memory_profiles(result_id)
                return
            code_map = self._memory_profile_results.get(id)
            if code_map is None:
                return
            banner = "=" * 60
            print(banner)
            print(f"Profile of UDF<id={id}>")
            print(banner)
            MemoryProfiler._show_results(code_map)

    @property
    @abstractmethod
    def _memory_profile_results(self) -> Dict[int, CodeMapDict]:
        """Mapping of UDF result id to its memory code map."""
        ...

    @abstractmethod
    def _clear(self) -> None:
        """Discard all collected profile results."""
        ...
|
||
|
||
class AccumulatorProfilerCollector(ProfilerCollector):
    """ProfilerCollector backed by a reserved Spark accumulator."""

    def __init__(self) -> None:
        super().__init__()
        # Profile results flow in through a special accumulator id shared
        # with the executors.
        self._accumulator = Accumulator(
            SpecialAccumulatorIds.SQL_UDF_PROFIER, cast(ProfileResults, {}), ProfileResultsParam()
        )

    @property
    def _perf_profile_results(self) -> Dict[int, pstats.Stats]:
        # First element of each pair is the (optional) perf stats.
        collected = self._accumulator.value
        return {rid: pair[0] for rid, pair in collected.items() if pair[0] is not None}

    @property
    def _memory_profile_results(self) -> Dict[int, CodeMapDict]:
        # Second element of each pair is the (optional) memory code map.
        collected = self._accumulator.value
        return {rid: pair[1] for rid, pair in collected.items() if pair[1] is not None}

    def _clear(self) -> None:
        # Reset the accumulator's local value directly, bypassing addInPlace.
        self._accumulator._value = {}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# | ||
Check failure on line 1 in python/pyspark/sql/tests/connect/test_parity_udf_profiler.py GitHub Actions / Report test resultspython/pyspark/sql/tests/connect/test_parity_udf_profiler.py.test_perf_profiler_pandas_udf
Raw output
Check failure on line 1 in python/pyspark/sql/tests/connect/test_parity_udf_profiler.py GitHub Actions / Report test resultspython/pyspark/sql/tests/connect/test_parity_udf_profiler.py.test_perf_profiler_udf
Raw output
Check failure on line 1 in python/pyspark/sql/tests/connect/test_parity_udf_profiler.py GitHub Actions / Report test resultspython/pyspark/sql/tests/connect/test_parity_udf_profiler.py.test_perf_profiler_udf_multiple_actions
Raw output
Check failure on line 1 in python/pyspark/sql/tests/connect/test_parity_udf_profiler.py GitHub Actions / Report test resultspython/pyspark/sql/tests/connect/test_parity_udf_profiler.py.test_perf_profiler_udf_registered
Raw output
Check failure on line 1 in python/pyspark/sql/tests/connect/test_parity_udf_profiler.py GitHub Actions / Report test resultspython/pyspark/sql/tests/connect/test_parity_udf_profiler.py.test_perf_profiler_udf_with_arrow
Raw output
|
||
# Licensed to the Apache Software Foundation (ASF) under one or more | ||
# contributor license agreements. See the NOTICE file distributed with | ||
# this work for additional information regarding copyright ownership. | ||
# The ASF licenses this file to You under the Apache License, Version 2.0 | ||
# (the "License"); you may not use this file except in compliance with | ||
# the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import inspect | ||
import os | ||
|
||
from pyspark.sql.tests.test_udf_profiler import UDFProfiler2TestsMixin, _do_computation | ||
from pyspark.testing.connectutils import ReusedConnectTestCase | ||
from pyspark.testing.utils import eventually | ||
|
||
|
||
class UDFProfilerParityTests(UDFProfiler2TestsMixin, ReusedConnectTestCase):
    """Spark Connect parity tests for the UDF perf profiler."""

    def test_perf_profiler_udf_multiple_actions(self):
        # Run two actions over the same DataFrame so each UDF is profiled twice.
        def action(df):
            df.collect()
            df.show()

        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
            _do_computation(self.spark, action=action)

        profile_results = self.spark._profiler_collector._perf_profile_results

        def check():
            # 6 entries expected — presumably the UDFs in _do_computation,
            # doubled by the two actions; verify against that helper.
            self.assertEqual(6, len(profile_results), str(list(profile_results)))

            for result_id in profile_results:
                with self.trap_stdout() as io:
                    self.spark.show_perf_profiles(result_id)
                output = io.getvalue()

                self.assertIn(f"Profile of UDF<id={result_id}>", output)
                self.assertRegex(
                    output, f"10.*{os.path.basename(inspect.getfile(_do_computation))}"
                )

        # Accumulator updates may lag the action; retry until assertions hold.
        eventually(timeout=1, catch_assertions=True)(check)()
|
||
|
||
if __name__ == "__main__":
    import unittest
    # Re-export the test classes at module top level so unittest discovers them.
    from pyspark.sql.tests.connect.test_parity_udf_profiler import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]

        # Emit JUnit-style XML reports for CI when xmlrunner is available.
        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        # Fall back to the default text runner.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)