[SPARK-49530][PYTHON][CONNECT] Support kde/density plots #48492

Open · wants to merge 4 commits into base: master
30 changes: 30 additions & 0 deletions python/pyspark/sql/pandas/utils.py
@@ -94,3 +94,33 @@ def require_minimum_pyarrow_version() -> None:
        errorClass="ARROW_LEGACY_IPC_FORMAT",
        messageParameters={},
    )


def require_minimum_numpy_version() -> None:
    """Raise ImportError if minimum version of NumPy is not installed"""
    minimum_numpy_version = "1.21"

    try:
        import numpy

        have_numpy = True
    except ImportError as error:
        have_numpy = False
        raised_error = error
    if not have_numpy:
        raise PySparkImportError(
            errorClass="PACKAGE_NOT_INSTALLED",
            messageParameters={
                "package_name": "NumPy",
                "minimum_version": str(minimum_numpy_version),
            },
        ) from raised_error
    if LooseVersion(numpy.__version__) < LooseVersion(minimum_numpy_version):
        raise PySparkImportError(
            errorClass="UNSUPPORTED_PACKAGE_VERSION",
            messageParameters={
                "package_name": "NumPy",
                "minimum_version": str(minimum_numpy_version),
                "current_version": str(numpy.__version__),
            },
        )
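Note: the guard is intended to be called before NumPy is imported, mirroring how PySparkKdePlotBase.get_ind in core.py uses it below. A minimal sketch of the call pattern:

from pyspark.sql.pandas.utils import require_minimum_numpy_version

require_minimum_numpy_version()  # raises PySparkImportError if NumPy is missing or < 1.21
import numpy as np  # safe to import once the guard has passed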
131 changes: 130 additions & 1 deletion python/pyspark/sql/plot/core.py
@@ -15,18 +15,26 @@
# limitations under the License.
#

import math

from typing import Any, TYPE_CHECKING, List, Optional, Union
from types import ModuleType
from pyspark.errors import (
    PySparkRuntimeError,
    PySparkTypeError,
    PySparkValueError,
)
from pyspark.sql import Column, functions as F
from pyspark.sql.types import NumericType
from pyspark.sql.utils import is_remote, require_minimum_plotly_version
from pandas.core.dtypes.inference import is_integer


if TYPE_CHECKING:
    from pyspark.sql import DataFrame, Row
    from pyspark.sql._typing import ColumnOrName
    import pandas as pd
    import numpy as np
    from plotly.graph_objs import Figure


@@ -388,6 +396,127 @@ def box(
        """
        return self(kind="box", column=column, precision=precision, **kwargs)

    def kde(
        self,
        column: Union[str, List[str]],
Review comment (Member Author): Create https://issues.apache.org/jira/browse/SPARK-49999 for a follow-up on optional "column" parameter support in box, kde and hist plots.
Review comment (Member Author): Create https://issues.apache.org/jira/browse/SPARK-50000 for optional "bw_method" in both Pandas on Spark and PySpark.

        bw_method: Union[int, float],
        ind: Union["np.ndarray", int, None] = None,
        **kwargs: Any,
    ) -> "Figure":
        """
        Generate Kernel Density Estimate plot using Gaussian kernels.

        In statistics, kernel density estimation (KDE) is a non-parametric way to
        estimate the probability density function (PDF) of a random variable. This
        function uses Gaussian kernels; the estimator bandwidth is supplied through
        ``bw_method``.

        Parameters
        ----------
        column : str or list of str
            Column name or list of names to be used for creating the kde plot.
        bw_method : int or float
            The bandwidth of the Gaussian kernels, given as a scalar number.
            See KernelDensity in PySpark for more information.
        ind : NumPy array or integer, optional
            Evaluation points for the estimated PDF. If None (default),
            1000 equally spaced points are used. If `ind` is a NumPy array, the
            KDE is evaluated at the points passed. If `ind` is an integer,
            `ind` equally spaced points are used.
        **kwargs : optional
            Additional keyword arguments.

        Returns
        -------
        :class:`plotly.graph_objs.Figure`

        Examples
        --------
        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
        >>> columns = ["length", "width", "species"]
        >>> df = spark.createDataFrame(data, columns)
        >>> df.plot.kde(column=["length", "width"], bw_method=0.3)  # doctest: +SKIP
        >>> df.plot.kde(column="length", bw_method=0.3)  # doctest: +SKIP
        """
        return self(kind="kde", column=column, bw_method=bw_method, ind=ind, **kwargs)


class PySparkKdePlotBase:
    @staticmethod
    def get_ind(sdf: "DataFrame", ind: Union["np.ndarray", int, None]) -> "np.ndarray":
        from pyspark.sql.pandas.utils import require_minimum_numpy_version

        require_minimum_numpy_version()
        import numpy as np

        def calc_min_max() -> "Row":
            if len(sdf.columns) > 1:
                min_col = F.least(*map(F.min, sdf))  # type: ignore
                max_col = F.greatest(*map(F.max, sdf))  # type: ignore
            else:
                min_col = F.min(sdf.columns[-1])
                max_col = F.max(sdf.columns[-1])
            return sdf.select(min_col, max_col).first()  # type: ignore

        if ind is None:
            min_val, max_val = calc_min_max()
            sample_range = max_val - min_val
            ind = np.linspace(
                min_val - 0.5 * sample_range,
                max_val + 0.5 * sample_range,
                1000,
            )
        elif is_integer(ind):
            min_val, max_val = calc_min_max()
            sample_range = max_val - min_val
            ind = np.linspace(
                min_val - 0.5 * sample_range,
                max_val + 0.5 * sample_range,
                ind,
            )
        return ind  # type: ignore
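Note: the evaluation grid spans the data range padded by half the range on each side. A standalone worked example of what get_ind builds, using the length column from the doctest data above (min 4.9, max 7.0) and ind=5; plain NumPy, no Spark needed:

import numpy as np

min_val, max_val = 4.9, 7.0          # from the doctest data
sample_range = max_val - min_val     # 2.1
ind = np.linspace(
    min_val - 0.5 * sample_range,    # 3.85
    max_val + 0.5 * sample_range,    # 8.05
    5,                               # ind=5 -> 5 evaluation points
)
print(ind)                           # [3.85 4.9  5.95 7.   8.05]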

    @staticmethod
    def compute_kde_col(
        input_col: Column,
        bw_method: Union[int, float],
        ind: "np.ndarray",
    ) -> Column:
        # refers to org.apache.spark.mllib.stat.KernelDensity
        assert bw_method is not None and isinstance(
            bw_method, (int, float)
        ), "'bw_method' must be set as a scalar number."

        assert ind is not None, "'ind' must be an array of evaluation points."

        bandwidth = float(bw_method)
        points = [float(i) for i in ind]
        log_std_plus_half_log2_pi = math.log(bandwidth) + 0.5 * math.log(2 * math.pi)

        def norm_pdf(
            mean: Column,
            std: Column,
            log_std_plus_half_log2_pi: Column,
            x: Column,
        ) -> Column:
            x0 = x - mean
            x1 = x0 / std
            log_density = -0.5 * x1 * x1 - log_std_plus_half_log2_pi
            return F.exp(log_density)

        return F.array(
            [
                F.avg(
                    norm_pdf(
                        input_col.cast("double"),
                        F.lit(bandwidth),
                        F.lit(log_std_plus_half_log2_pi),
                        F.lit(point),
                    )
                )
                for point in points
            ]
        )
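Note: compute_kde_col evaluates the standard Gaussian KDE. With bandwidth h and column values x_1, ..., x_n, the density estimate at an evaluation point t is

    \hat{f}_h(t) = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{h \sqrt{2\pi}} \exp\!\left( -\frac{(t - x_i)^2}{2h^2} \right)

Each kernel is computed in log space for numerical stability: \log K_h(t - x_i) = -\frac{1}{2} \left( \frac{t - x_i}{h} \right)^2 - \left( \log h + \frac{1}{2} \log 2\pi \right), where the parenthesized constant is the precomputed log_std_plus_half_log2_pi, and F.avg over the rows supplies the \frac{1}{n} \sum term.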


class PySparkBoxPlotBase:
    @staticmethod
47 changes: 46 additions & 1 deletion python/pyspark/sql/plot/plotly.py
@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING, Any

from pyspark.errors import PySparkValueError
from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase, PySparkKdePlotBase

if TYPE_CHECKING:
    from pyspark.sql import DataFrame
@@ -32,6 +32,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
        return plot_pie(data, **kwargs)
    if kind == "box":
        return plot_box(data, **kwargs)
    if kind in ("kde", "density"):
        return plot_kde(data, **kwargs)

    return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)

@@ -118,3 +120,46 @@ def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":

    fig["layout"]["yaxis"]["title"] = "value"
    return fig


def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
    from pyspark.sql.pandas.utils import require_minimum_pandas_version

    require_minimum_pandas_version()

    import pandas as pd
    from plotly import express

    if "color" not in kwargs:
        kwargs["color"] = "names"

    bw_method = kwargs.pop("bw_method", None)
    colnames = kwargs.pop("column", None)
    if isinstance(colnames, str):
        colnames = [colnames]
    ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind", None))

    kde_cols = [
        PySparkKdePlotBase.compute_kde_col(
            input_col=data[col_name],
            ind=ind,
            bw_method=bw_method,
        ).alias(f"kde_{i}")
        for i, col_name in enumerate(colnames)
    ]
    kde_results = data.select(*kde_cols).first()
    pdf = pd.concat(
        [
            pd.DataFrame(  # type: ignore
                {
                    "Density": kde_result,
                    "names": col_name,
                    "index": ind,
                }
            )
            for col_name, kde_result in zip(colnames, list(kde_results))  # type: ignore[arg-type]
        ]
    )
    fig = express.line(pdf, x="index", y="Density", **kwargs)
    fig["layout"]["xaxis"]["title"] = None
    return fig
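Note: plot_kde stacks the per-column KDE results into one long-format pandas DataFrame so that plotly express can draw one trace per column, colored by the "names" key. A usage sketch, assuming a running SparkSession named spark and the same data as the docstring example:

df = spark.createDataFrame(
    [(5.1, 3.5), (4.9, 3.0), (7.0, 3.2), (6.4, 3.2), (5.9, 3.0)],
    ["length", "width"],
)
fig = df.plot.kde(column=["length", "width"], bw_method=0.3, ind=100)
fig.show()  # one line trace per column; traces are colored by the "names" key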
34 changes: 33 additions & 1 deletion python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -20,7 +20,13 @@

import pyspark.sql.plot # noqa: F401
from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import (
    ReusedSQLTestCase,
    have_plotly,
    have_numpy,
    plotly_requirement_message,
    numpy_requirement_message,
)


@unittest.skipIf(not have_plotly, plotly_requirement_message)
@@ -375,6 +381,32 @@ def test_box_plot(self):
            },
        )

    @unittest.skipIf(not have_numpy, numpy_requirement_message)
    def test_kde_plot(self):
        fig = self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5)
        expected_fig_data = {
            "mode": "lines",
            "name": "math_score",
            "orientation": "v",
            "xaxis": "x",
            "yaxis": "y",
            "type": "scatter",
        }
        self._check_fig_data(fig["data"][0], **expected_fig_data)

        fig = self.sdf4.plot.kde(column=["math_score", "english_score"], bw_method=0.3, ind=5)
        self._check_fig_data(fig["data"][0], **expected_fig_data)
        expected_fig_data = {
            "mode": "lines",
            "name": "english_score",
            "orientation": "v",
            "xaxis": "x",
            "yaxis": "y",
            "type": "scatter",
        }
        self._check_fig_data(fig["data"][1], **expected_fig_data)
        self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"]))


class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
    pass
8 changes: 8 additions & 0 deletions python/pyspark/testing/sqlutils.py
@@ -55,6 +55,13 @@
    plotly_requirement_message = str(e)
have_plotly = plotly_requirement_message is None

numpy_requirement_message = None
try:
    import numpy
except ImportError as e:
    numpy_requirement_message = str(e)
have_numpy = numpy_requirement_message is None

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
@@ -63,6 +70,7 @@
have_pandas = pandas_requirement_message is None
have_pyarrow = pyarrow_requirement_message is None
test_compiled = test_not_compiled_message is None


class UTCOffsetTimezone(datetime.tzinfo):