Skip to content

Commit

Permalink
[SPARK-49929][PYTHON][CONNECT] Support box plots
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support box plots with the plotly backend on both Spark Connect and Spark Classic.

### Why are the changes needed?
While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments.

See more in the [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing), which is still in progress.

Part of https://issues.apache.org/jira/browse/SPARK-49530.

### Does this PR introduce _any_ user-facing change?
Yes. Box plots are supported as shown below.

```py
>>> data = [
...             ("A", 50, 55),
...             ("B", 55, 60),
...             ("C", 60, 65),
...             ("D", 65, 70),
...             ("E", 70, 75),
...             # outliers
...             ("F", 10, 15),
...             ("G", 85, 90),
...             ("H", 5, 150),
...         ]
>>> columns = ["student", "math_score", "english_score"]
>>> sdf = spark.createDataFrame(data, columns)
>>> fig1 = sdf.plot.box(column=["math_score", "english_score"])
>>> fig1.show()  # see below
>>> fig2 = sdf.plot(kind="box", column="math_score")
>>> fig2.show()  # see below
```

fig1:
![newplot (17)](https://github.com/user-attachments/assets/8c36c344-f6de-47e3-bd63-c0f3b57efc43)

fig2:
![newplot (18)](https://github.com/user-attachments/assets/9b7b60f6-58ec-4eff-9544-d5ab88a88631)

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48447 from xinrong-meng/box.

Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Xinrong Meng <xinrong@apache.org>
  • Loading branch information
xinrong-meng committed Oct 15, 2024
1 parent 74aed77 commit 488f680
Show file tree
Hide file tree
Showing 4 changed files with 307 additions and 5 deletions.
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,11 @@
"`<backend>` is not supported, it should be one of the values from <supported_backends>"
]
},
"UNSUPPORTED_PLOT_BACKEND_PARAM": {
"message": [
"`<backend>` does not support `<param>` set to <value>, it should be one of the values from <supported_values>"
]
},
"UNSUPPORTED_SIGNATURE": {
"message": [
"Unsupported signature: <signature>."
Expand Down
153 changes: 150 additions & 3 deletions python/pyspark/sql/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@
# limitations under the License.
#

from typing import Any, TYPE_CHECKING, Optional, Union
from typing import Any, TYPE_CHECKING, List, Optional, Union
from types import ModuleType
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
from pyspark.sql import Column, functions as F
from pyspark.sql.types import NumericType
from pyspark.sql.utils import require_minimum_plotly_version
from pyspark.sql.utils import is_remote, require_minimum_plotly_version


if TYPE_CHECKING:
from pyspark.sql import DataFrame
from pyspark.sql import DataFrame, Row
from pyspark.sql._typing import ColumnOrName
import pandas as pd
from plotly.graph_objs import Figure

Expand Down Expand Up @@ -338,3 +340,148 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure":
},
)
return self(kind="pie", x=x, y=y, **kwargs)

def box(
    self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Any
) -> "Figure":
    """
    Draw a box-and-whisker plot for one or more DataFrame columns.

    A box plot summarizes numerical data through its quartiles: the box spans
    the first (Q1) to the third (Q3) quartile and a line marks the median (Q2).
    The whiskers start at the edges of the box and show the range of the data.
    By default they reach at most 1.5 * IQR (IQR = Q3 - Q1) away from the box
    and stop at the farthest data point inside that interval. Points outside
    the whiskers are outliers and are drawn as separate dots.

    Parameters
    ----------
    column: str or list of str
        Name of the column, or list of column names, to draw boxes for.
    precision: float, default = 0.01
        Controls the approximation used by pyspark when it computes the
        statistics behind the boxplot.
    **kwargs
        Additional keyword arguments.

    Returns
    -------
    :class:`plotly.graph_objs.Figure`

    Examples
    --------
    >>> data = [
    ...     ("A", 50, 55),
    ...     ("B", 55, 60),
    ...     ("C", 60, 65),
    ...     ("D", 65, 70),
    ...     ("E", 70, 75),
    ...     ("F", 10, 15),
    ...     ("G", 85, 90),
    ...     ("H", 5, 150),
    ... ]
    >>> columns = ["student", "math_score", "english_score"]
    >>> df = spark.createDataFrame(data, columns)
    >>> df.plot.box(column="math_score")  # doctest: +SKIP
    >>> df.plot.box(column=["math_score", "english_score"])  # doctest: +SKIP
    """
    # Delegate to the accessor's generic entry point with the plot kind fixed to "box".
    return self(
        kind="box",
        column=column,
        precision=precision,
        **kwargs,
    )


class PySparkBoxPlotBase:
    """Spark-side computation of the summary statistics needed to draw box plots."""

    @staticmethod
    def compute_box(
        sdf: "DataFrame", colnames: List[str], whis: float, precision: float, showfliers: bool
    ) -> Optional["Row"]:
        """
        Compute box-plot statistics for every column in ``colnames``.

        Parameters
        ----------
        sdf : DataFrame
            The DataFrame holding the columns to summarize.
        colnames : list of str
            Names of the columns to summarize; must not be empty.
        whis : float
            Multiplier of the IQR (Q3 - Q1) used to place the lower and upper fences
            at ``q1 - whis * iqr`` and ``q3 + whis * iqr``.
        precision : float
            ``int(1.0 / precision)`` is passed as the accuracy argument of
            ``percentile_approx``.
        showfliers : bool
            Whether to collect outliers. When False, ``fliers`` is a null literal.

        Returns
        -------
        Row or None
            The first row of the aggregated result. It holds one struct per input
            column, named ``_box_plot_results_<i>``, with the fields ``mean``,
            ``med``, ``q1``, ``q3``, ``upper_whisker``, ``lower_whisker`` and
            ``fliers``.
        """
        assert len(colnames) > 0
        # Back-quote each name so it is taken literally as one column name. A backtick
        # inside the name is escaped by doubling it; without that, the quoted
        # identifier would be malformed.
        formatted_colnames = [
            "`{}`".format(colname.replace("`", "``")) for colname in colnames
        ]

        # First pass: per-column mean, quartiles and fences as a one-row DataFrame.
        stats_scols = []
        for i, colname in enumerate(formatted_colnames):
            percentiles = F.percentile_approx(colname, [0.25, 0.50, 0.75], int(1.0 / precision))
            q1 = F.get(percentiles, 0)
            med = F.get(percentiles, 1)
            q3 = F.get(percentiles, 2)
            iqr = q3 - q1
            lfence = q1 - F.lit(whis) * iqr
            ufence = q3 + F.lit(whis) * iqr

            stats_scols.append(
                F.struct(
                    F.mean(colname).alias("mean"),
                    med.alias("med"),
                    q1.alias("q1"),
                    q3.alias("q3"),
                    lfence.alias("lfence"),
                    ufence.alias("ufence"),
                ).alias(f"_box_plot_stats_{i}")
            )

        sdf_stats = sdf.select(*stats_scols)

        # Second pass: compare every value against its column's fences to derive
        # the whiskers and, optionally, the outliers.
        result_scols = []
        for i, colname in enumerate(formatted_colnames):
            value = F.col(colname)

            lfence = F.col(f"_box_plot_stats_{i}.lfence")
            ufence = F.col(f"_box_plot_stats_{i}.ufence")
            mean = F.col(f"_box_plot_stats_{i}.mean")
            med = F.col(f"_box_plot_stats_{i}.med")
            q1 = F.col(f"_box_plot_stats_{i}.q1")
            q3 = F.col(f"_box_plot_stats_{i}.q3")

            outlier = ~value.between(lfence, ufence)

            # Computes min and max values of non-outliers - the whiskers
            upper_whisker = F.max(F.when(~outlier, value).otherwise(F.lit(None)))
            lower_whisker = F.min(F.when(~outlier, value).otherwise(F.lit(None)))

            # If it shows fliers, collect up to 1001 outliers. Each outlier is paired
            # with its absolute distance from the median.
            # NOTE(review): presumably collect_top_k ranks by that distance, largest
            # first - confirm against the JVM implementation.
            if showfliers:
                pair = F.when(
                    outlier,
                    F.struct(F.abs(value - med), value.alias("val")),
                ).otherwise(F.lit(None))
                topk = collect_top_k(pair, 1001, False)
                fliers = F.when(F.size(topk) > 0, topk["val"]).otherwise(F.lit(None))
            else:
                fliers = F.lit(None)

            result_scols.append(
                F.struct(
                    F.first(mean).alias("mean"),
                    F.first(med).alias("med"),
                    F.first(q1).alias("q1"),
                    F.first(q3).alias("q3"),
                    upper_whisker.alias("upper_whisker"),
                    lower_whisker.alias("lower_whisker"),
                    fliers.alias("fliers"),
                ).alias(f"_box_plot_results_{i}")
            )

        # The statistics DataFrame is joined (with a broadcast hint) onto the data
        # so that each row can be compared with the fences of its column.
        sdf_result = sdf.join(sdf_stats.hint("broadcast")).select(*result_scols)
        return sdf_result.first()


def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column:
    """
    Build a Column that applies the internal function ``name`` to ``cols``.

    Under Spark Connect (``is_remote()`` is true) the call goes to the Connect
    helper ``_invoke_function_over_columns``. Otherwise it goes through the JVM
    entry point ``PythonSQLUtils.internalFn`` of the active SparkContext.
    """
    if is_remote():
        from pyspark.sql.connect.functions.builtin import (
            _invoke_function_over_columns as invoke_over_columns,
        )

        return invoke_over_columns(name, *cols)

    from pyspark import SparkContext
    from pyspark.sql.classic.column import _to_java_column, _to_seq

    active_context = SparkContext._active_spark_context
    # Resolve the JVM function before converting the arguments, as the original did.
    internal_fn = active_context._jvm.PythonSQLUtils.internalFn  # type: ignore
    java_args = _to_seq(active_context, cols, _to_java_column)  # type: ignore
    return Column(internal_fn(name, java_args))


def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
    """
    Build a Column that invokes the internal ``collect_top_k`` function.

    ``num`` and ``reverse`` are passed to the internal function as literal columns.
    NOTE(review): presumably ``num`` caps how many elements are collected and
    ``reverse`` flips the ordering - confirm against the JVM implementation.
    """
    num_literal = F.lit(num)
    reverse_literal = F.lit(reverse)
    return _invoke_internal_function_over_columns(
        "collect_top_k", col, num_literal, reverse_literal
    )
77 changes: 76 additions & 1 deletion python/pyspark/sql/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from typing import TYPE_CHECKING, Any

from pyspark.sql.plot import PySparkPlotAccessor
from pyspark.errors import PySparkValueError
from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase

if TYPE_CHECKING:
from pyspark.sql import DataFrame
Expand All @@ -29,6 +30,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":

if kind == "pie":
return plot_pie(data, **kwargs)
if kind == "box":
return plot_box(data, **kwargs)

return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)

Expand All @@ -43,3 +46,75 @@ def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
fig = express.pie(pdf, values=y, names=x, **kwargs)

return fig


def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":
    """
    Build a plotly box plot from statistics computed on the Spark side.

    Recognized keyword arguments (all others are forwarded to ``go.Box``):

    - ``column``: a column name or a list of column names to plot.
    - ``whis``: IQR multiplier for the whisker fences, default 1.5.
    - ``precision``: precision of the approximate percentiles, default 0.01.
    - ``boxpoints``: ``"suspectedoutliers"`` (default) or ``False``.
    - ``notched``: only ``False`` (default) is supported.

    Raises
    ------
    PySparkValueError
        With error class ``UNSUPPORTED_PLOT_BACKEND_PARAM`` when ``boxpoints`` or
        ``notched`` is set to an unsupported value.
    """
    import plotly.graph_objs as go

    # 'whis' is a matplotlib argument rather than a plotly one. Plotly does not seem
    # to expose how far the whiskers may reach beyond the first and third quartiles;
    # it appears to use a fixed 1.5, which is therefore the default here.
    whis = kwargs.pop("whis", 1.5)
    # 'precision' is pyspark specific to control precision for approx_percentile
    precision = kwargs.pop("precision", 0.01)
    colnames = kwargs.pop("column", None)
    if isinstance(colnames, str):
        colnames = [colnames]

    # Plotly options
    boxpoints = kwargs.pop("boxpoints", "suspectedoutliers")
    notched = kwargs.pop("notched", False)
    if boxpoints not in ["suspectedoutliers", False]:
        raise PySparkValueError(
            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
            messageParameters={
                "backend": "plotly",
                "param": "boxpoints",
                "value": str(boxpoints),
                "supported_values": ", ".join(["suspectedoutliers", "False"]),
            },
        )
    if notched:
        raise PySparkValueError(
            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
            messageParameters={
                "backend": "plotly",
                "param": "notched",
                "value": str(notched),
                "supported_values": ", ".join(["False"]),
            },
        )

    fig = go.Figure()

    results = PySparkBoxPlotBase.compute_box(
        data,
        colnames,
        whis,
        precision,
        # After validation, boxpoints is either "suspectedoutliers" or False, so an
        # `is not None` check would always be true. Use its truthiness so that the
        # outliers are not collected when boxpoints=False.
        bool(boxpoints),
    )
    assert len(results) == len(colnames)  # type: ignore

    # One trace per column, positioned by its index in `colnames`.
    for i, colname in enumerate(colnames):
        result = results[i]  # type: ignore

        fig.add_trace(
            go.Box(
                x=[i],
                name=colname,
                q1=[result["q1"]],
                median=[result["med"]],
                q3=[result["q3"]],
                mean=[result["mean"]],
                lowerfence=[result["lower_whisker"]],
                upperfence=[result["upper_whisker"]],
                y=[result["fliers"]] if result["fliers"] else None,
                boxpoints=boxpoints,
                notched=notched,
                **kwargs,
            )
        )

    fig["layout"]["yaxis"]["title"] = "value"
    return fig
77 changes: 76 additions & 1 deletion python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from datetime import datetime

import pyspark.sql.plot # noqa: F401
from pyspark.errors import PySparkTypeError
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message


Expand Down Expand Up @@ -48,6 +48,22 @@ def sdf3(self):
columns = ["sales", "signups", "visits", "date"]
return self.spark.createDataFrame(data, columns)

@property
def sdf4(self):
    """DataFrame of student scores: five regular rows followed by three outlier rows."""
    regular_rows = [
        ("A", 50, 55),
        ("B", 55, 60),
        ("C", 60, 65),
        ("D", 65, 70),
        ("E", 70, 75),
    ]
    outlier_rows = [
        ("F", 10, 15),
        ("G", 85, 90),
        ("H", 5, 150),
    ]
    schema = ["student", "math_score", "english_score"]
    return self.spark.createDataFrame(regular_rows + outlier_rows, schema)

def _check_fig_data(self, fig_data, **kwargs):
for key, expected_value in kwargs.items():
if key in ["x", "y", "labels", "values"]:
Expand Down Expand Up @@ -300,6 +316,65 @@ def test_pie_plot(self):
messageParameters={"arg_name": "y", "arg_type": "StringType()"},
)

def test_box_plot(self):
    """Check box-plot traces for one and two columns, then the unsupported options."""
    math_expected = {
        "boxpoints": "suspectedoutliers",
        "lowerfence": (5,),
        "mean": (50.0,),
        "median": (55,),
        "name": "math_score",
        "notched": False,
        "q1": (10,),
        "q3": (65,),
        "upperfence": (85,),
        "x": [0],
        "type": "box",
    }
    english_expected = {
        "boxpoints": "suspectedoutliers",
        "lowerfence": (55,),
        "mean": (72.5,),
        "median": (65,),
        "name": "english_score",
        "notched": False,
        "q1": (55,),
        "q3": (75,),
        "upperfence": (90,),
        "x": [1],
        "y": [[150, 15]],
        "type": "box",
    }

    # Single column through the `plot.box` accessor method.
    single_fig = self.sdf4.plot.box(column="math_score")
    self._check_fig_data(single_fig["data"][0], **math_expected)

    # Two columns through the generic `plot(kind=...)` entry point.
    multi_fig = self.sdf4.plot(kind="box", column=["math_score", "english_score"])
    self._check_fig_data(multi_fig["data"][0], **math_expected)
    self._check_fig_data(multi_fig["data"][1], **english_expected)

    # Each option set to True must be rejected with the supported values listed.
    unsupported_options = [
        ("boxpoints", ["suspectedoutliers", "False"]),
        ("notched", ["False"]),
    ]
    for param, supported in unsupported_options:
        with self.assertRaises(PySparkValueError) as pe:
            self.sdf4.plot.box(column="math_score", **{param: True})
        self.check_error(
            exception=pe.exception,
            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
            messageParameters={
                "backend": "plotly",
                "param": param,
                "value": "True",
                "supported_values": ", ".join(supported),
            },
        )


class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
    """Concrete test case combining the plotly plot test mixin with ReusedSQLTestCase."""
Expand Down

0 comments on commit 488f680

Please sign in to comment.