From d8d9b79fee89568316299d86a7cf75b56de3ffdf Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 19 Jun 2024 19:43:33 +0200 Subject: [PATCH 1/2] Fix PandasBlocks implementation for mismatching categories --- partd/pandas.py | 13 ++++++++++++- partd/tests/test_pandas.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/partd/pandas.py b/partd/pandas.py index 880558c..b0c738d 100644 --- a/partd/pandas.py +++ b/partd/pandas.py @@ -211,6 +211,17 @@ def join(dfs): if not dfs: return pd.DataFrame() else: - return pd.concat(dfs) + result = pd.concat(dfs) + dtypes = { + col: "category" + for col in result.columns + if ( + pd.api.types.is_categorical_dtype(dfs[0][col].dtype) + and not pd.api.types.is_categorical_dtype(result[col].dtype) + ) + } + if dtypes: + result = result.astype(dtypes) + return result PandasBlocks = partial(Encode, serialize, deserialize, join) diff --git a/partd/tests/test_pandas.py b/partd/tests/test_pandas.py index 72c37dc..f64804b 100644 --- a/partd/tests/test_pandas.py +++ b/partd/tests/test_pandas.py @@ -146,3 +146,17 @@ def test_index_non_numeric_extension_types(dtype): df.index = df.index.astype(dtype) df2 = deserialize(serialize(df)) tm.assert_frame_equal(df, df2) + + +def test_categorical_concat(): + pytest.importorskip("pandas", minversion="2") + + df1 = pd.DataFrame({"a": ["x", "y"]}, dtype="category") + df2 = pd.DataFrame({"a": ["y", "z"]}, dtype="category") + + with PandasBlocks() as p: + p.append({'x': df1}) + p.append({'x': df2}) + + result = p.get(["x"]) + pd.testing.assert_frame_equal(result[0], pd.concat([df1, df2]).astype("category")) From 8ab3cb26d17005176fcbac3b6bd71feb077ccb75 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Wed, 19 Jun 2024 19:56:33 +0200 Subject: [PATCH 2/2] Update --- partd/pandas.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/partd/pandas.py b/partd/pandas.py 
index b0c738d..36c1b01 100644 --- a/partd/pandas.py +++ b/partd/pandas.py @@ -216,8 +216,8 @@ def join(dfs): col: "category" for col in result.columns if ( - pd.api.types.is_categorical_dtype(dfs[0][col].dtype) - and not pd.api.types.is_categorical_dtype(result[col].dtype) + isinstance(dfs[0][col].dtype, pd.CategoricalDtype) + and not isinstance(result[col].dtype, pd.CategoricalDtype) ) } if dtypes: