diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index c1dc7d2dc621e..2e188b411df17 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -26,6 +26,7 @@ from pyspark.sql import functions as F, Column from pyspark.sql.types import DoubleType +from pyspark.pandas.spark import functions as SF from pyspark.pandas.missing import unsupported_function from pyspark.pandas.config import get_option from pyspark.pandas.utils import name_like_string @@ -437,6 +438,37 @@ def get_fliers(colname, outliers, min_val): return fliers + @staticmethod + def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers): + scols = [] + extract_colnames = [] + for i, colname in enumerate(colnames): + formated_colname = "`{}`".format(colname) + outlier_colname = "__{}_outlier".format(colname) + min_val = multicol_whiskers[colname]["min"] + pair_col = F.struct( + F.abs(F.col(formated_colname) - F.lit(min_val)).alias("ord"), + F.col(formated_colname).alias("val"), + ) + scols.append( + SF.collect_top_k( + F.when(F.col(outlier_colname), pair_col) + .otherwise(F.lit(None)) + .alias(f"pair_{i}"), + 1001, + False, + ).alias(f"top_{i}") + ) + extract_colnames.append(f"top_{i}.val") + + results = multicol_outliers.select(scols).select(extract_colnames).first() + + fliers = {} + for i, colname in enumerate(colnames): + fliers[colname] = results[i] + + return fliers + class KdePlotBase(NumericPlotBase): @staticmethod diff --git a/python/pyspark/pandas/plot/plotly.py b/python/pyspark/pandas/plot/plotly.py index 4de313b1e831d..0afcd6d7e8696 100644 --- a/python/pyspark/pandas/plot/plotly.py +++ b/python/pyspark/pandas/plot/plotly.py @@ -199,11 +199,19 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs): # Computes min and max values of non-outliers - the whiskers whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names, outliers) + fliers = None + if boxpoints: + fliers = BoxPlotBase.get_multicol_fliers(numeric_column_names, outliers, whiskers) + i = 0 for colname in numeric_column_names: col_stats = multicol_stats[colname] col_whiskers = whiskers[colname] + col_fliers = None + if fliers is not None and colname in fliers and len(fliers[colname]) > 0: + col_fliers = [fliers[colname]] + fig.add_trace( go.Box( x=[i], @@ -214,7 +222,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs): mean=[col_stats["mean"]], lowerfence=[col_whiskers["min"]], upperfence=[col_whiskers["max"]], - y=None, # todo: support y=fliers + y=col_fliers, boxpoints=boxpoints, notched=notched, **kwargs, diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index 8abeff655ea50..6bef3d9b87c05 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -174,6 +174,19 @@ def null_index(col: Column) -> Column: return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc)) +def collect_top_k(col: Column, num: int, reverse: bool) -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + + return _invoke_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) + + else: + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse)) + + def make_interval(unit: str, e: Union[Column, int, float]) -> Column: unit_mapping = { "YEAR": "years", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 6b497553dcb0d..c1c9af2ea4273 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -149,6 +149,9 @@ private[sql] object PythonSQLUtils extends Logging { def nullIndex(e: Column): Column = Column.internalFn("null_index", e) + def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = + Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) + def pandasProduct(e: Column, ignoreNA: Boolean): Column = Column.internalFn("pandas_product", e, lit(ignoreNA))