Skip to content

Commit

Permalink
[SPARK-49382][PS] Make frame box plot properly render the fliers/outliers
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
fliers/outliers were ignored in the initial implementation apache#36317

### Why are the changes needed?
feature parity for Pandas and Series box plot

### Does this PR introduce _any_ user-facing change?

```
import pyspark.pandas as ps
df = ps.DataFrame([[5.1, 3.5, 0], [4.9, 3.0, 0], [7.0, 3.2, 1], [6.4, 3.2, 1], [5.9, 3.0, 2], [100, 200, 300]], columns=['length', 'width', 'species'])
df.boxplot()
```

`df.length.plot.box()`
![image](https://github.com/user-attachments/assets/43da563c-5f68-4305-ad27-a4f04815dfd1)

before:
`df.boxplot()`
![image](https://github.com/user-attachments/assets/e25c2760-c12a-4801-a730-3987a020f889)

after:
`df.boxplot()`
![image](https://github.com/user-attachments/assets/c19f13b1-b9e4-423e-bcec-0c47c1c8df32)

### How was this patch tested?
CI and manually check

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47866 from zhengruifeng/plot_hist_fly.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
zhengruifeng authored and IvanK-db committed Sep 19, 2024
1 parent f3a430c commit 45b864a
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 1 deletion.
32 changes: 32 additions & 0 deletions python/pyspark/pandas/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from pyspark.sql import functions as F, Column
from pyspark.sql.types import DoubleType
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.missing import unsupported_function
from pyspark.pandas.config import get_option
from pyspark.pandas.utils import name_like_string
Expand Down Expand Up @@ -437,6 +438,37 @@ def get_fliers(colname, outliers, min_val):

return fliers

@staticmethod
def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers, max_fliers=1001):
    """
    Collect the fliers (outlier points) of multiple columns in a single Spark job.

    Parameters
    ----------
    colnames : list of str
        Names of the numeric columns to collect fliers for.
    multicol_outliers : DataFrame
        DataFrame containing the original columns plus, for each column, a
        boolean ``__<colname>_outlier`` flag marking outlier rows.
    multicol_whiskers : dict
        Mapping of column name to ``{"min": ..., "max": ...}`` whisker bounds.
    max_fliers : int, default 1001
        Upper bound on the number of fliers collected per column; the points
        farthest from the lower whisker are kept.

    Returns
    -------
    dict
        Mapping of column name to the collected list of flier values.
    """
    scols = []
    extract_colnames = []
    for i, colname in enumerate(colnames):
        formatted_colname = "`{}`".format(colname)
        outlier_colname = "__{}_outlier".format(colname)
        min_val = multicol_whiskers[colname]["min"]
        # Pair each value with its distance from the lower whisker ("ord") so
        # that collect_top_k keeps the most extreme points when truncating.
        pair_col = F.struct(
            F.abs(F.col(formatted_colname) - F.lit(min_val)).alias("ord"),
            F.col(formatted_colname).alias("val"),
        )
        scols.append(
            SF.collect_top_k(
                F.when(F.col(outlier_colname), pair_col)
                .otherwise(F.lit(None))
                .alias(f"pair_{i}"),
                max_fliers,
                False,
            ).alias(f"top_{i}")
        )
        # Strip the ordering key afterwards, keeping only the values.
        extract_colnames.append(f"top_{i}.val")

    # One aggregation over all columns at once, then a projection to drop "ord".
    results = multicol_outliers.select(scols).select(extract_colnames).first()

    return {colname: results[i] for i, colname in enumerate(colnames)}


class KdePlotBase(NumericPlotBase):
@staticmethod
Expand Down
10 changes: 9 additions & 1 deletion python/pyspark/pandas/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,19 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names, outliers)

fliers = None
if boxpoints:
fliers = BoxPlotBase.get_multicol_fliers(numeric_column_names, outliers, whiskers)

i = 0
for colname in numeric_column_names:
col_stats = multicol_stats[colname]
col_whiskers = whiskers[colname]

col_fliers = None
if fliers is not None and colname in fliers and len(fliers[colname]) > 0:
col_fliers = [fliers[colname]]

fig.add_trace(
go.Box(
x=[i],
Expand All @@ -214,7 +222,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
mean=[col_stats["mean"]],
lowerfence=[col_whiskers["min"]],
upperfence=[col_whiskers["max"]],
y=None, # todo: support y=fliers
y=col_fliers,
boxpoints=boxpoints,
notched=notched,
**kwargs,
Expand Down
13 changes: 13 additions & 0 deletions python/pyspark/pandas/spark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,19 @@ def null_index(col: Column) -> Column:
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))


def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
    """Return an aggregate Column collecting the top ``num`` values of ``col``.

    Dispatches to the Spark Connect implementation when running against a
    remote session; otherwise invokes the JVM-side helper directly.
    """
    if is_remote():
        from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

        return _invoke_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse))

    from pyspark import SparkContext

    sc = SparkContext._active_spark_context
    jcol = sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse)
    return Column(jcol)


def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
unit_mapping = {
"YEAR": "years",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ private[sql] object PythonSQLUtils extends Logging {

def nullIndex(e: Column): Column = Column.internalFn("null_index", e)

// Aggregate collecting up to `num` values of `e` via the internal
// "collect_top_k" function; presumably `reverse` flips the ordering of the
// collected values — confirm against the CollectTopK expression.
def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
Column.internalFn("collect_top_k", e, lit(num), lit(reverse))

def pandasProduct(e: Column, ignoreNA: Boolean): Column =
Column.internalFn("pandas_product", e, lit(ignoreNA))

Expand Down

0 comments on commit 45b864a

Please sign in to comment.