[SPARK-35343][PYTHON] Make the conversion from/to pandas data-type-based for non-ExtensionDtypes #32592

Closed
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -615,8 +615,10 @@ def __hash__(self):
"pyspark.pandas.tests.data_type_ops.test_complex_ops",
"pyspark.pandas.tests.data_type_ops.test_date_ops",
"pyspark.pandas.tests.data_type_ops.test_datetime_ops",
"pyspark.pandas.tests.data_type_ops.test_null_ops",
"pyspark.pandas.tests.data_type_ops.test_num_ops",
"pyspark.pandas.tests.data_type_ops.test_string_ops",
"pyspark.pandas.tests.data_type_ops.test_udt_ops",
"pyspark.pandas.tests.indexes.test_category",
"pyspark.pandas.tests.plot.test_frame_plot",
"pyspark.pandas.tests.plot.test_frame_plot_matplotlib",
26 changes: 23 additions & 3 deletions python/pyspark/pandas/data_type_ops/base.py
@@ -16,9 +16,11 @@
#

import numbers
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from typing import Any, TYPE_CHECKING, Union

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype

from pyspark.sql.types import (
@@ -30,14 +32,15 @@
FractionalType,
IntegralType,
MapType,
NullType,
NumericType,
StringType,
StructType,
TimestampType,
UserDefinedType,
)

import pyspark.sql.types as types
from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.typedef import Dtype

if TYPE_CHECKING:
@@ -47,6 +50,8 @@

def is_valid_operand_for_numeric_arithmetic(operand: Any, *, allow_bool: bool = True) -> bool:
"""Check whether the operand is valid for arithmetic operations against numerics."""
from pyspark.pandas.base import IndexOpsMixin

if isinstance(operand, numbers.Number) and not isinstance(operand, bool):
return True
elif isinstance(operand, IndexOpsMixin):
@@ -66,6 +71,8 @@ def transform_boolean_operand_to_numeric(operand: Any, spark_type: types.DataTyp
Return the transformed operand if the operand is a boolean IndexOpsMixin,
otherwise return the original operand.
"""
from pyspark.pandas.base import IndexOpsMixin

if isinstance(operand, IndexOpsMixin) and isinstance(operand.spark.data_type, BooleanType):
return operand.spark.transform(lambda scol: scol.cast(spark_type))
else:
@@ -82,11 +89,13 @@ def __new__(cls, dtype: Dtype, spark_type: DataType):
from pyspark.pandas.data_type_ops.complex_ops import ArrayOps, MapOps, StructOps
from pyspark.pandas.data_type_ops.date_ops import DateOps
from pyspark.pandas.data_type_ops.datetime_ops import DatetimeOps
from pyspark.pandas.data_type_ops.null_ops import NullOps
from pyspark.pandas.data_type_ops.num_ops import (
IntegralOps,
FractionalOps,
)
from pyspark.pandas.data_type_ops.string_ops import StringOps
from pyspark.pandas.data_type_ops.udt_ops import UDTOps

if isinstance(dtype, CategoricalDtype):
return object.__new__(CategoricalOps)
@@ -110,6 +119,10 @@ def __new__(cls, dtype: Dtype, spark_type: DataType):
return object.__new__(MapOps)
elif isinstance(spark_type, StructType):
return object.__new__(StructOps)
elif isinstance(spark_type, NullType):
return object.__new__(NullOps)
elif isinstance(spark_type, UserDefinedType):
return object.__new__(UDTOps)
else:
raise TypeError("Type %s was not understood." % dtype)

@@ -118,7 +131,6 @@ def __init__(self, dtype: Dtype, spark_type: DataType):
self.spark_type = spark_type

@property
@abstractmethod
Member Author commented:

@abstractmethod is removed in order to pass mypy checks. Otherwise, the mypy checks fail with:
python/pyspark/pandas/internal.py:1050: error: Cannot instantiate abstract class 'DataTypeOps' with abstract attribute 'pretty_name'
python/pyspark/pandas/internal.py:1441: error: Cannot instantiate abstract class 'DataTypeOps' with abstract attribute 'pretty_name'

Reference mypy issue: python/mypy#1843.

def pretty_name(self) -> str:
raise NotImplementedError()

@@ -163,3 +175,11 @@ def rmod(self, left, right) -> Union["Series", "Index"]:

def rpow(self, left, right) -> Union["Series", "Index"]:
raise TypeError("Exponentiation can not be applied to %s." % self.pretty_name)

def restore(self, col: pd.Series) -> pd.Series:
"""Restore column when to_pandas."""
return col

def prepare(self, col: pd.Series) -> pd.Series:
"""Prepare column when from_pandas."""
return col.replace({np.nan: None})
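Side note: with these hooks on DataTypeOps, the per-dtype conversion logic becomes a plain method call on the (dtype, spark_type) pair. A minimal usage sketch of the new API, assuming this patch is applied (the output comments are my expectation, not captured test output):

```python
import numpy as np
import pandas as pd

from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.sql.types import LongType

# DataTypeOps.__new__ dispatches on the Spark type: LongType -> IntegralOps,
# NullType -> NullOps, UserDefinedType -> UDTOps, and so on.
ops = DataTypeOps(np.dtype("int64"), LongType())

# from_pandas direction: the base prepare() swaps NaN for None before the
# data is handed to Spark.
prepared = ops.prepare(pd.Series([1.0, np.nan, 3.0]))
print(prepared.tolist())        # [1.0, None, 3.0]

# to_pandas direction: the base restore() is a pass-through.
print(ops.restore(prepared).tolist())
```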
12 changes: 12 additions & 0 deletions python/pyspark/pandas/data_type_ops/categorical_ops.py
@@ -15,6 +15,8 @@
# limitations under the License.
#

import pandas as pd

from pyspark.pandas.data_type_ops.base import DataTypeOps


@@ -26,3 +28,13 @@ class CategoricalOps(DataTypeOps):
@property
def pretty_name(self) -> str:
return "categoricals"

def restore(self, col: pd.Series) -> pd.Series:
"""Restore column when to_pandas."""
return pd.Categorical.from_codes(
col, categories=self.dtype.categories, ordered=self.dtype.ordered
)

def prepare(self, col: pd.Series) -> pd.Series:
"""Prepare column when from_pandas."""
return col.cat.codes
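The categorical pair is the interesting one: only the codes go to Spark, and the remembered CategoricalDtype is what lets restore rebuild the column. A pure-pandas illustration of the same round trip (hypothetical data, not part of this diff):

```python
import pandas as pd
from pandas.api.types import CategoricalDtype

dtype = CategoricalDtype(categories=["a", "b", "c"], ordered=True)
original = pd.Series(["b", "a", "c"], dtype=dtype)

# prepare(): only the integer codes are stored on the Spark side.
codes = original.cat.codes                     # [1, 0, 2]

# restore(): rebuild the categorical from the codes plus the saved dtype.
restored = pd.Series(
    pd.Categorical.from_codes(codes, categories=dtype.categories, ordered=dtype.ordered)
)
assert restored.equals(original)
```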
4 changes: 4 additions & 0 deletions python/pyspark/pandas/data_type_ops/datetime_ops.py
@@ -74,3 +74,7 @@ def rsub(self, left, right) -> Union["Series", "Index"]:
)
else:
raise TypeError("datetime subtraction can only be applied to datetime series.")

def prepare(self, col):
"""Prepare column when from_pandas."""
return col
28 changes: 28 additions & 0 deletions python/pyspark/pandas/data_type_ops/null_ops.py
@@ -0,0 +1,28 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.pandas.data_type_ops.base import DataTypeOps


class NullOps(DataTypeOps):
"""
The class for binary operations of pandas-on-Spark objects with Spark type: NullType.
"""

@property
def pretty_name(self) -> str:
return "nulls"
29 changes: 29 additions & 0 deletions python/pyspark/pandas/data_type_ops/udt_ops.py
@@ -0,0 +1,29 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.pandas.data_type_ops.base import DataTypeOps


class UDTOps(DataTypeOps):
"""
The class for binary operations of pandas-on-Spark objects with Spark type:
UserDefinedType or its subclasses.
"""

@property
def pretty_name(self) -> str:
return "user defined types"
67 changes: 37 additions & 30 deletions python/pyspark/pandas/internal.py
@@ -25,12 +25,20 @@

import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype, is_datetime64_dtype, is_datetime64tz_dtype
from pandas.api.types import CategoricalDtype # noqa: F401
from pyspark import sql as spark
from pyspark._globals import _NoValue, _NoValueType
from pyspark.sql import functions as F, Window
from pyspark.sql.functions import PandasUDFType, pandas_udf
from pyspark.sql.types import BooleanType, DataType, StructField, StructType, LongType
from pyspark.sql.types import ( # noqa: F401
BooleanType,
DataType,
IntegralType,
LongType,
StructField,
StructType,
StringType,
)

# For running doctests and reference resolution in PyCharm.
from pyspark import pandas as ps # noqa: F401
@@ -39,6 +47,7 @@
# This is required in old Python 3.5 to prevent circular reference.
from pyspark.pandas.series import Series # noqa: F401 (SPARK-34943)
from pyspark.pandas.config import get_option
from pyspark.pandas.data_type_ops.base import DataTypeOps
from pyspark.pandas.typedef import (
Dtype,
as_spark_type,
@@ -951,11 +960,11 @@ def arguments_for_restore_index(self) -> Dict:
for col, dtype in zip(self.index_spark_column_names, self.index_dtypes)
if isinstance(dtype, extension_dtypes)
}
categorical_dtypes = {
col: dtype
for col, dtype in zip(self.index_spark_column_names, self.index_dtypes)
if isinstance(dtype, CategoricalDtype)
}
dtypes = [dtype for dtype in self.index_dtypes]
spark_types = [
self.spark_frame.select(scol).schema[0].dataType for scol in self.index_spark_columns
]

for spark_column, column_name, dtype in zip(
self.data_spark_columns, self.data_spark_column_names, self.data_dtypes
):
@@ -969,8 +978,8 @@ def arguments_for_restore_index(self) -> Dict:
column_names.append(column_name)
if isinstance(dtype, extension_dtypes):
ext_dtypes[column_name] = dtype
elif isinstance(dtype, CategoricalDtype):
categorical_dtypes[column_name] = dtype
dtypes.append(dtype)
spark_types.append(self.spark_frame.select(spark_column).schema[0].dataType)

return dict(
index_columns=self.index_spark_column_names,
Expand All @@ -979,7 +988,8 @@ def arguments_for_restore_index(self) -> Dict:
column_labels=self.column_labels,
column_label_names=self.column_label_names,
ext_dtypes=ext_dtypes,
categorical_dtypes=categorical_dtypes,
dtypes=dtypes,
spark_types=spark_types,
)

@staticmethod
@@ -991,8 +1001,9 @@ def restore_index(
data_columns: List[str],
column_labels: List[Tuple],
column_label_names: List[Tuple],
dtypes: List[Dtype],
spark_types: List[DataType],
ext_dtypes: Dict[str, Dtype] = None,
categorical_dtypes: Dict[str, CategoricalDtype] = None
) -> pd.DataFrame:
"""
Restore pandas DataFrame indices using the metadata.
@@ -1003,10 +1014,12 @@ def restore_index(
:param data_columns: the original column names for data columns.
:param column_labels: the column labels after restored.
:param column_label_names: the column label names after restored.
:param dtypes: the dtypes after restored.
:param spark_types: the Spark data types of the restored columns.
:param ext_dtypes: the map from the original column names to extension data types.
:param categorical_dtypes: the map from the original column names to categorical types.
:return: the restored pandas DataFrame

>>> from numpy import dtype
>>> pdf = pd.DataFrame({"index": [10, 20, 30], "a": ['a', 'b', 'c'], "b": [0, 2, 1]})
>>> InternalFrame.restore_index(
... pdf,
@@ -1015,8 +1028,9 @@
... data_columns=["a", "b", "index"],
... column_labels=[("x",), ("y",), ("z",)],
... column_label_names=[("lv1",)],
... ext_dtypes=None,
... categorical_dtypes={"b": CategoricalDtype(categories=["i", "j", "k"])}
... dtypes=[dtype('int64'), dtype('object'),
... CategoricalDtype(categories=["i", "j", "k"]), dtype('int64')],
... spark_types=[LongType(), StringType(), StringType(), LongType()]
... ) # doctest: +NORMALIZE_WHITESPACE
lv1 x y z
idx
@@ -1027,11 +1041,8 @@
if ext_dtypes is not None and len(ext_dtypes) > 0:
pdf = pdf.astype(ext_dtypes, copy=True)

if categorical_dtypes is not None:
for col, dtype in categorical_dtypes.items():
pdf[col] = pd.Categorical.from_codes(
pdf[col], categories=dtype.categories, ordered=dtype.ordered
)
for col, expected_dtype, spark_type in zip(pdf.columns, dtypes, spark_types):
pdf[col] = DataTypeOps(expected_dtype, spark_type).restore(pdf[col])

append = False
for index_field in index_columns:
@@ -1071,7 +1082,7 @@ def with_new_sdf(
*,
index_dtypes: Optional[List[Dtype]] = None,
data_columns: Optional[List[str]] = None,
data_dtypes: Optional[List[Dtype]] = None
data_dtypes: Optional[List[Dtype]] = None,
) -> "InternalFrame":
"""Copy the immutable InternalFrame with the updates by the specified Spark DataFrame.

@@ -1121,7 +1132,7 @@ def with_new_columns(
column_labels: Optional[List[Tuple]] = None,
data_dtypes: Optional[List[Dtype]] = None,
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue,
keep_order: bool = True
keep_order: bool = True,
) -> "InternalFrame":
"""
Copy the immutable InternalFrame with the updates by the specified Spark Columns or Series.
@@ -1225,7 +1236,7 @@ def with_new_spark_column(
scol: spark.Column,
*,
dtype: Optional[Dtype] = None,
keep_order: bool = True
keep_order: bool = True,
) -> "InternalFrame":
"""
Copy the immutable InternalFrame with the updates by the specified Spark Column.
@@ -1273,7 +1284,7 @@ def copy(
column_labels: Union[Optional[List[Tuple]], _NoValueType] = _NoValue,
data_spark_columns: Union[Optional[List[spark.Column]], _NoValueType] = _NoValue,
data_dtypes: Union[Optional[List[Dtype]], _NoValueType] = _NoValue,
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue
column_label_names: Union[Optional[List[Optional[Tuple]]], _NoValueType] = _NoValue,
) -> "InternalFrame":
"""Copy the immutable InternalFrame.

@@ -1423,13 +1434,9 @@ def prepare_pandas_frame(
index_dtypes = list(reset_index.dtypes)[:index_nlevels]
data_dtypes = list(reset_index.dtypes)[index_nlevels:]

for name, col in reset_index.iteritems():
dt = col.dtype
if is_datetime64_dtype(dt) or is_datetime64tz_dtype(dt):
continue
elif isinstance(dt, CategoricalDtype):
col = col.cat.codes
reset_index[name] = col.replace({np.nan: None})
for col, dtype in zip(reset_index.columns, reset_index.dtypes):
spark_type = infer_pd_series_spark_type(reset_index[col], dtype)
reset_index[col] = DataTypeOps(dtype, spark_type).prepare(reset_index[col])

return reset_index, index_columns, index_dtypes, data_columns, data_dtypes

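With arguments_for_restore_index, restore_index, and prepare_pandas_frame all routed through DataTypeOps, the user-visible effect is that columns keep their pandas dtype across a round trip. A hedged end-to-end check (needs an active Spark session; a categorical is shown here, and nulls/UDTs follow the same path):

```python
import pandas as pd
import pyspark.pandas as ps

pdf = pd.DataFrame({"c": pd.Categorical(["a", "b", "a"], categories=["a", "b", "c"])})

psdf = ps.from_pandas(pdf)        # prepare_pandas_frame(): categorical -> codes
restored = psdf.to_pandas()       # restore_index(): codes -> categorical

print(restored["c"].dtype)        # category
pd.testing.assert_frame_equal(pdf, restored)
```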
7 changes: 7 additions & 0 deletions python/pyspark/pandas/tests/data_type_ops/test_binary_ops.py
@@ -122,6 +122,13 @@ def test_rpow(self):
self.assertRaises(TypeError, lambda: "x" ** self.psser)
self.assertRaises(TypeError, lambda: 1 ** self.psser)

def test_from_to_pandas(self):
data = [b"1", b"2", b"3"]
pser = pd.Series(data)
psser = ps.Series(data)
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)


if __name__ == "__main__":
import unittest
7 changes: 7 additions & 0 deletions python/pyspark/pandas/tests/data_type_ops/test_boolean_ops.py
@@ -229,6 +229,13 @@ def test_rmod(self):
self.assertRaises(TypeError, lambda: datetime.date(1994, 1, 1) % self.psser)
self.assertRaises(TypeError, lambda: True % self.psser)

def test_from_to_pandas(self):
data = [True, True, False]
pser = pd.Series(data)
psser = ps.Series(data)
self.assert_eq(pser, psser.to_pandas())
self.assert_eq(ps.from_pandas(pser), psser)


if __name__ == "__main__":
import unittest