diff --git a/python/pyspark/sql/pandas/utils.py b/python/pyspark/sql/pandas/utils.py
index d080448cfd3a1..5849ae0edd6d9 100644
--- a/python/pyspark/sql/pandas/utils.py
+++ b/python/pyspark/sql/pandas/utils.py
@@ -94,3 +94,33 @@ def require_minimum_pyarrow_version() -> None:
             errorClass="ARROW_LEGACY_IPC_FORMAT",
             messageParameters={},
         )
+
+
+def require_minimum_numpy_version() -> None:
+    """Raise ImportError if minimum version of NumPy is not installed"""
+    minimum_numpy_version = "1.21"
+
+    try:
+        import numpy
+
+        have_numpy = True
+    except ImportError as error:
+        have_numpy = False
+        raised_error = error
+    if not have_numpy:
+        raise PySparkImportError(
+            errorClass="PACKAGE_NOT_INSTALLED",
+            messageParameters={
+                "package_name": "NumPy",
+                "minimum_version": str(minimum_numpy_version),
+            },
+        ) from raised_error
+    if LooseVersion(numpy.__version__) < LooseVersion(minimum_numpy_version):
+        raise PySparkImportError(
+            errorClass="UNSUPPORTED_PACKAGE_VERSION",
+            messageParameters={
+                "package_name": "NumPy",
+                "minimum_version": str(minimum_numpy_version),
+                "current_version": str(numpy.__version__),
+            },
+        )
diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py
index 4bf75474d92c3..b654d4517f81b 100644
--- a/python/pyspark/sql/plot/core.py
+++ b/python/pyspark/sql/plot/core.py
@@ -15,18 +15,26 @@
 # limitations under the License.
 #

+import math
+
 from typing import Any, TYPE_CHECKING, List, Optional, Union
 from types import ModuleType
-from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
+from pyspark.errors import (
+    PySparkRuntimeError,
+    PySparkTypeError,
+    PySparkValueError,
+)
 from pyspark.sql import Column, functions as F
 from pyspark.sql.types import NumericType
 from pyspark.sql.utils import is_remote, require_minimum_plotly_version
+from pandas.core.dtypes.inference import is_integer

 if TYPE_CHECKING:
     from pyspark.sql import DataFrame, Row
     from pyspark.sql._typing import ColumnOrName
     import pandas as pd
+    import numpy as np
     from plotly.graph_objs import Figure
@@ -388,6 +396,127 @@ def box(
         """
         return self(kind="box", column=column, precision=precision, **kwargs)

+    def kde(
+        self,
+        column: Union[str, List[str]],
+        bw_method: Union[int, float],
+        ind: Union["np.ndarray", int, None] = None,
+        **kwargs: Any,
+    ) -> "Figure":
+        """
+        Generate Kernel Density Estimate plot using Gaussian kernels.
+
+        In statistics, kernel density estimation (KDE) is a non-parametric way to
+        estimate the probability density function (PDF) of a random variable. This
+        function uses Gaussian kernels and includes automatic bandwidth determination.
+
+        Parameters
+        ----------
+        column : str or list of str
+            Column name or list of names to be used for creating the kde plot.
+        bw_method : int or float
+            The method used to calculate the estimator bandwidth.
+            See KernelDensity in PySpark for more information.
+        ind : NumPy array or integer, optional
+            Evaluation points for the estimated PDF. If None (default),
+            1000 equally spaced points are used. If `ind` is a NumPy array, the
+            KDE is evaluated at the points passed. If `ind` is an integer,
+            `ind` number of equally spaced points are used.
+        **kwargs : optional
+            Additional keyword arguments.
+
+        Returns
+        -------
+        :class:`plotly.graph_objs.Figure`
+
+        Examples
+        --------
+        >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
+        >>> columns = ["length", "width", "species"]
+        >>> df = spark.createDataFrame(data, columns)
+        >>> df.plot.kde(column=["length", "width"], bw_method=0.3)  # doctest: +SKIP
+        >>> df.plot.kde(column="length", bw_method=0.3)  # doctest: +SKIP
+        """
+        return self(kind="kde", column=column, bw_method=bw_method, ind=ind, **kwargs)
+
+
+class PySparkKdePlotBase:
+    @staticmethod
+    def get_ind(sdf: "DataFrame", ind: Union["np.ndarray", int, None]) -> "np.ndarray":
+        from pyspark.sql.pandas.utils import require_minimum_numpy_version
+
+        require_minimum_numpy_version()
+        import numpy as np
+
+        def calc_min_max() -> "Row":
+            if len(sdf.columns) > 1:
+                min_col = F.least(*map(F.min, sdf))  # type: ignore
+                max_col = F.greatest(*map(F.max, sdf))  # type: ignore
+            else:
+                min_col = F.min(sdf.columns[-1])
+                max_col = F.max(sdf.columns[-1])
+            return sdf.select(min_col, max_col).first()  # type: ignore
+
+        if ind is None:
+            min_val, max_val = calc_min_max()
+            sample_range = max_val - min_val
+            ind = np.linspace(
+                min_val - 0.5 * sample_range,
+                max_val + 0.5 * sample_range,
+                1000,
+            )
+        elif is_integer(ind):
+            min_val, max_val = calc_min_max()
+            sample_range = max_val - min_val
+            ind = np.linspace(
+                min_val - 0.5 * sample_range,
+                max_val + 0.5 * sample_range,
+                ind,
+            )
+        return ind  # type: ignore
+
+    @staticmethod
+    def compute_kde_col(
+        input_col: Column,
+        bw_method: Union[int, float],
+        ind: "np.ndarray",
+    ) -> Column:
+        # refers to org.apache.spark.mllib.stat.KernelDensity
+        assert bw_method is not None and isinstance(
+            bw_method, (int, float)
+        ), "'bw_method' must be set as a scalar number."
+
+        assert ind is not None, "'ind' must be a scalar array."
+
+        bandwidth = float(bw_method)
+        points = [float(i) for i in ind]
+        log_std_plus_half_log2_pi = math.log(bandwidth) + 0.5 * math.log(2 * math.pi)
+
+        def norm_pdf(
+            mean: Column,
+            std: Column,
+            log_std_plus_half_log2_pi: Column,
+            x: Column,
+        ) -> Column:
+            x0 = x - mean
+            x1 = x0 / std
+            log_density = -0.5 * x1 * x1 - log_std_plus_half_log2_pi
+            return F.exp(log_density)
+
+        return F.array(
+            [
+                F.avg(
+                    norm_pdf(
+                        input_col.cast("double"),
+                        F.lit(bandwidth),
+                        F.lit(log_std_plus_half_log2_pi),
+                        F.lit(point),
+                    )
+                )
+                for point in points
+            ]
+        )
+

 class PySparkBoxPlotBase:
     @staticmethod
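For reference, `compute_kde_col` above encodes the textbook Gaussian KDE as a single Spark aggregation: for each evaluation point x it averages, over the rows x_i, the normal pdf with mean x_i and standard deviation equal to the bandwidth h, computed in log space as -0.5 * ((x - x_i) / h)^2 - (log h + 0.5 * log(2 * pi)). A minimal NumPy sketch of the same quantity (the function name and the vectorized shapes are illustrative, not part of the patch):

    import numpy as np

    def gaussian_kde_reference(samples: np.ndarray, h: float, points: np.ndarray) -> np.ndarray:
        # z[p, i] = (point_p - sample_i) / bandwidth, shape (n_points, n_samples)
        z = (points[:, None] - samples[None, :]) / h
        # log N(point; mean=sample, std=h), mirroring norm_pdf above
        log_density = -0.5 * z * z - (np.log(h) + 0.5 * np.log(2 * np.pi))
        # F.avg over rows corresponds to the mean over samples here
        return np.exp(log_density).mean(axis=1)
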
diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py
index 71d40720e874d..884ee1da28aa4 100644
--- a/python/pyspark/sql/plot/plotly.py
+++ b/python/pyspark/sql/plot/plotly.py
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING, Any

 from pyspark.errors import PySparkValueError
-from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase
+from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase, PySparkKdePlotBase

 if TYPE_CHECKING:
     from pyspark.sql import DataFrame
@@ -32,6 +32,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
         return plot_pie(data, **kwargs)
     if kind == "box":
         return plot_box(data, **kwargs)
+    if kind == "kde" or kind == "density":
+        return plot_kde(data, **kwargs)
     return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)


@@ -118,3 +120,46 @@ def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":
     fig["layout"]["yaxis"]["title"] = "value"

     return fig
+
+
+def plot_kde(data: "DataFrame", **kwargs: Any) -> "Figure":
+    from pyspark.sql.pandas.utils import require_minimum_pandas_version
+
+    require_minimum_pandas_version()
+
+    import pandas as pd
+    from plotly import express
+
+    if "color" not in kwargs:
+        kwargs["color"] = "names"
+
+    bw_method = kwargs.pop("bw_method", None)
+    colnames = kwargs.pop("column", None)
+    if isinstance(colnames, str):
+        colnames = [colnames]
+    ind = PySparkKdePlotBase.get_ind(data.select(*colnames), kwargs.pop("ind", None))
+
+    kde_cols = [
+        PySparkKdePlotBase.compute_kde_col(
+            input_col=data[col_name],
+            ind=ind,
+            bw_method=bw_method,
+        ).alias(f"kde_{i}")
+        for i, col_name in enumerate(colnames)
+    ]
+    kde_results = data.select(*kde_cols).first()
+    pdf = pd.concat(
+        [
+            pd.DataFrame(  # type: ignore
+                {
+                    "Density": kde_result,
+                    "names": col_name,
+                    "index": ind,
+                }
+            )
+            for col_name, kde_result in zip(colnames, list(kde_results))  # type: ignore[arg-type]
+        ]
+    )
+    fig = express.line(pdf, x="index", y="Density", **kwargs)
+    fig["layout"]["xaxis"]["title"] = None
+    return fig
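Assuming a live SparkSession bound to `spark`, the dispatch added above can be exercised end to end as follows (a sketch mirroring the docstring example; the data and column names are illustrative):

    data = [(5.1, 3.5), (4.9, 3.0), (7.0, 3.2), (6.4, 3.2), (5.9, 3.0)]
    df = spark.createDataFrame(data, ["length", "width"])
    # Routed through plot_pyspark -> plot_kde; kind="density" is accepted as an alias.
    fig = df.plot.kde(column=["length", "width"], bw_method=0.3, ind=100)
    fig.show()
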
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index d870cdbf9959b..9764b4a277273 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -20,7 +20,13 @@
 import pyspark.sql.plot  # noqa: F401

 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_plotly,
+    have_numpy,
+    plotly_requirement_message,
+    numpy_requirement_message,
+)


 @unittest.skipIf(not have_plotly, plotly_requirement_message)
@@ -375,6 +381,32 @@ def test_box_plot(self):
             },
         )

+    @unittest.skipIf(not have_numpy, numpy_requirement_message)
+    def test_kde_plot(self):
+        fig = self.sdf4.plot.kde(column="math_score", bw_method=0.3, ind=5)
+        expected_fig_data = {
+            "mode": "lines",
+            "name": "math_score",
+            "orientation": "v",
+            "xaxis": "x",
+            "yaxis": "y",
+            "type": "scatter",
+        }
+        self._check_fig_data(fig["data"][0], **expected_fig_data)
+
+        fig = self.sdf4.plot.kde(column=["math_score", "english_score"], bw_method=0.3, ind=5)
+        self._check_fig_data(fig["data"][0], **expected_fig_data)
+        expected_fig_data = {
+            "mode": "lines",
+            "name": "english_score",
+            "orientation": "v",
+            "xaxis": "x",
+            "yaxis": "y",
+            "type": "scatter",
+        }
+        self._check_fig_data(fig["data"][1], **expected_fig_data)
+        self.assertEqual(list(fig["data"][0]["x"]), list(fig["data"][1]["x"]))
+

 class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
     pass
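`test_kde_plot` relies on the `sdf4` fixture defined earlier in this test module, outside the hunk shown. A minimal stand-in with the two numeric score columns the test expects could look like the following (hypothetical rows; the real fixture is not part of this patch):

    @property
    def sdf4(self):
        data = [("A", 50, 55), ("B", 55, 60), ("C", 60, 65), ("D", 65, 70)]
        return self.spark.createDataFrame(data, ["student", "math_score", "english_score"])
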
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 00ad40e68bd7c..dab382c37f42b 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -55,6 +55,13 @@
     plotly_requirement_message = str(e)
 have_plotly = plotly_requirement_message is None

+numpy_requirement_message = None
+try:
+    import numpy
+except ImportError as e:
+    numpy_requirement_message = str(e)
+have_numpy = numpy_requirement_message is None
+
 from pyspark.sql import SparkSession
 from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils
@@ -63,6 +70,7 @@
 have_pandas = pandas_requirement_message is None
 have_pyarrow = pyarrow_requirement_message is None
 test_compiled = test_not_compiled_message is None
+have_numpy = numpy_requirement_message is None


 class UTCOffsetTimezone(datetime.tzinfo):
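As a worked example of `get_ind`'s grid selection: the evaluation points pad the observed range by half its width on each side, so a column with minimum 3.0 and maximum 7.0 (sample range 4.0) is evaluated on [1.0, 9.0]. A sketch with illustrative values:

    import numpy as np

    min_val, max_val = 3.0, 7.0
    sample_range = max_val - min_val      # 4.0
    grid = np.linspace(
        min_val - 0.5 * sample_range,     # 1.0
        max_val + 0.5 * sample_range,     # 9.0
        1000,                             # or the integer passed as ind
    )

Because `plot_kde` computes one shared grid from `data.select(*colnames)`, every series in a multi-column plot is evaluated at identical x values, which is what the final `assertEqual` in `test_kde_plot` checks.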