From b14729578d1c3716c283c925b70092f757e0ce2b Mon Sep 17 00:00:00 2001 From: Wenjun Si Date: Wed, 8 Dec 2021 09:53:26 +0800 Subject: [PATCH] [BACKPORT] Fix tests for cudf 21.10 (#2574) (#2608) --- mars/dataframe/reduction/aggregation.py | 7 +++- mars/serialization/cuda.py | 50 ++++++++++++++++++++++++- mars/serialization/tests/test_serial.py | 19 ++++++---- 3 files changed, 65 insertions(+), 11 deletions(-) diff --git a/mars/dataframe/reduction/aggregation.py b/mars/dataframe/reduction/aggregation.py index f7708f6ba7..e4d42f0bac 100644 --- a/mars/dataframe/reduction/aggregation.py +++ b/mars/dataframe/reduction/aggregation.py @@ -710,7 +710,12 @@ def _wrap_df(cls, op, value, index=None): elif not isinstance(value, xdf.DataFrame): new_index = None if not op.gpu else getattr(value, "index", None) dtype = getattr(value, "dtype", None) - value = xdf.DataFrame(value, columns=index, index=new_index) + if xdf is pd: + value = xdf.DataFrame(value, columns=index, index=new_index) + else: # pragma: no cover + value = xdf.DataFrame(value) + value.index = new_index + value.columns = index else: return value diff --git a/mars/serialization/cuda.py b/mars/serialization/cuda.py index 38226cf401..fbb8e59a7a 100644 --- a/mars/serialization/cuda.py +++ b/mars/serialization/cuda.py @@ -14,6 +14,8 @@ from typing import Any, List, Dict +import pandas as pd + from ..utils import lazy_import from .core import Serializer, buffered @@ -49,13 +51,57 @@ def deserialize(self, header: Dict, buffers: List, context: Dict): class CudfSerializer(Serializer): serializer_name = "cudf" + @staticmethod + def _get_ext_index_type(index_obj): + import cudf + + multi_index_type = None + if isinstance(index_obj, pd.MultiIndex): + multi_index_type = "pandas" + elif isinstance(index_obj, cudf.MultiIndex): + multi_index_type = "cudf" + + if multi_index_type is None: + return None + return { + "index_type": multi_index_type, + "names": list(index_obj.names), + } + + @staticmethod + def _apply_index_type(obj, attr, header): + import cudf + + multi_index_cls = ( + pd.MultiIndex if header["index_type"] == "pandas" else cudf.MultiIndex + ) + original_index = getattr(obj, attr) + if isinstance(original_index, (pd.MultiIndex, cudf.MultiIndex)): + return + new_index = multi_index_cls.from_tuples(original_index, names=header["names"]) + setattr(obj, attr, new_index) + def serialize(self, obj: Any, context: Dict): - return obj.device_serialize() + header, buffers = obj.device_serialize() + if hasattr(obj, "columns"): + header["_ext_columns"] = self._get_ext_index_type(obj.columns) + if hasattr(obj, "index"): + header["_ext_index"] = self._get_ext_index_type(obj.index) + return header, buffers def deserialize(self, header: Dict, buffers: List, context: Dict): from cudf.core.abc import Serializable - return Serializable.device_deserialize(header, buffers) + col_header = header.pop("_ext_columns", None) + index_header = header.pop("_ext_index", None) + + result = Serializable.device_deserialize(header, buffers) + + if col_header is not None: + self._apply_index_type(result, "columns", col_header) + if index_header is not None: + self._apply_index_type(result, "index", index_header) + return result if cupy is not None: diff --git a/mars/serialization/tests/test_serial.py b/mars/serialization/tests/test_serial.py index 8764dd0b1d..70e575e3ff 100644 --- a/mars/serialization/tests/test_serial.py +++ b/mars/serialization/tests/test_serial.py @@ -168,15 +168,18 @@ def test_cupy(np_val): @require_cudf def test_cudf(): - test_df = cudf.DataFrame( - pd.DataFrame( - { - "a": np.random.rand(1000), - "b": np.random.choice(list("abcd"), size=(1000,)), - "c": np.random.randint(0, 100, size=(1000,)), - } - ) + raw_df = pd.DataFrame( + { + "a": np.random.rand(1000), + "b": np.random.choice(list("abcd"), size=(1000,)), + "c": np.random.randint(0, 100, size=(1000,)), + } ) + test_df = cudf.DataFrame(raw_df) + cudf.testing.assert_frame_equal(test_df, deserialize(*serialize(test_df))) + + raw_df.columns = pd.MultiIndex.from_tuples([("a", "a"), ("a", "b"), ("b", "c")]) + test_df = cudf.DataFrame(raw_df) cudf.testing.assert_frame_equal(test_df, deserialize(*serialize(test_df)))