From 8d2ef5d5e72eec58b056daa726ac74f0c1eb2ecb Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Tue, 6 Feb 2024 12:38:16 -0500 Subject: [PATCH 1/4] Preserve name in Enum cast --- .../chunked_array/logical/categorical/mod.rs | 3 +- py-polars/tests/unit/operations/test_cast.py | 36 ++++++++++--------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs index 9f16779189e7..a2e0bd4397dd 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/mod.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/mod.rs @@ -373,7 +373,8 @@ impl LogicalType for CategoricalChunked { Ok(self .to_enum(categories, *hash)? .set_ordering(*ordering, true) - .into_series()) + .into_series() + .with_name(self.name())) }, DataType::Enum(None, _) => { polars_bail!(ComputeError: "can not cast to enum without categories present") diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index 005655ffe0f2..f7e3a756b676 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -579,25 +579,29 @@ def test_strict_cast_string_and_binary( @pytest.mark.parametrize( - "dtype_out", + ("dtype_in", "dtype_out"), [ - (pl.UInt8), - (pl.Int8), - (pl.UInt16), - (pl.Int16), - (pl.UInt32), - (pl.Int32), - (pl.UInt64), - (pl.Int64), - (pl.Date), - (pl.Datetime), - (pl.Time), - (pl.Duration), - (pl.Enum(["1"])), + (pl.Categorical, pl.UInt8), + (pl.Categorical, pl.Int8), + (pl.Categorical, pl.UInt16), + (pl.Categorical, pl.Int16), + (pl.Categorical, pl.UInt32), + (pl.Categorical, pl.Int32), + (pl.Categorical, pl.UInt64), + (pl.Categorical, pl.Int64), + (pl.Categorical, pl.Date), + (pl.Categorical, pl.Datetime), + (pl.Categorical, pl.Time), + (pl.Categorical, pl.Duration), + (pl.Categorical, pl.Enum(["1"])), + (pl.Enum(["1"]), pl.Enum(["1", "2"])), + (pl.Enum(["1"]), pl.Categorical), ], ) -def test_cast_categorical_name_retention(dtype_out: PolarsDataType) -> None: - assert pl.Series("a", ["1"], dtype=pl.Categorical).cast(dtype_out).name == "a" +def test_cast_categorical_name_retention( + dtype_in: PolarsDataType, dtype_out: PolarsDataType +) -> None: + assert pl.Series("a", ["1"], dtype=dtype_in).cast(dtype_out).name == "a" def test_cast_date_to_time() -> None: From f358e59bf734067f3868aeba1ddf49189cc5a918 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Tue, 6 Feb 2024 17:22:34 -0500 Subject: [PATCH 2/4] Move enum cast to separate test --- py-polars/tests/unit/operations/test_cast.py | 61 +++++++++++++------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index f7e3a756b676..5a190e890ec5 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -579,29 +579,50 @@ def test_strict_cast_string_and_binary( @pytest.mark.parametrize( - ("dtype_in", "dtype_out"), + "dtype_out", [ - (pl.Categorical, pl.UInt8), - (pl.Categorical, pl.Int8), - (pl.Categorical, pl.UInt16), - (pl.Categorical, pl.Int16), - (pl.Categorical, pl.UInt32), - (pl.Categorical, pl.Int32), - (pl.Categorical, pl.UInt64), - (pl.Categorical, pl.Int64), - (pl.Categorical, pl.Date), - (pl.Categorical, pl.Datetime), - (pl.Categorical, pl.Time), - (pl.Categorical, pl.Duration), - (pl.Categorical, pl.Enum(["1"])), - (pl.Enum(["1"]), pl.Enum(["1", "2"])), - (pl.Enum(["1"]), pl.Categorical), + (pl.UInt8), + (pl.Int8), + (pl.UInt16), + (pl.Int16), + (pl.UInt32), + (pl.Int32), + (pl.UInt64), + (pl.Int64), + (pl.Date), + (pl.Datetime), + (pl.Time), + (pl.Duration), + (pl.String), + (pl.Enum(["1"])), ], ) -def test_cast_categorical_name_retention( - dtype_in: PolarsDataType, dtype_out: PolarsDataType -) -> None: - assert pl.Series("a", ["1"], dtype=dtype_in).cast(dtype_out).name == "a" +def test_cast_categorical_name_retention(dtype_out: PolarsDataType) -> None: + assert pl.Series("a", ["1"], dtype=pl.Categorical).cast(dtype_out).name == "a" + + +@pytest.mark.parametrize( + "dtype_out", + [ + (pl.UInt8), + (pl.Int8), + (pl.UInt16), + (pl.Int16), + (pl.UInt32), + (pl.Int32), + (pl.UInt64), + (pl.Int64), + (pl.Date), + (pl.Datetime), + (pl.Time), + (pl.Duration), + (pl.String), + (pl.Categorical), + (pl.Enum(["1", "2"])), + ], +) +def test_cast_enum_name_retention(dtype_out: PolarsDataType) -> None: + assert pl.Series("a", ["1"], dtype=pl.Enum(["1"])).cast(dtype_out).name == "a" def test_cast_date_to_time() -> None: From e80507535d1f639a861598f54c4e73bb79b8766d Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Tue, 6 Feb 2024 17:24:15 -0500 Subject: [PATCH 3/4] Use same coverage for both tests --- py-polars/tests/unit/operations/test_cast.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index 5a190e890ec5..a994c0fe1c16 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -594,7 +594,8 @@ def test_strict_cast_string_and_binary( (pl.Time), (pl.Duration), (pl.String), - (pl.Enum(["1"])), + (pl.Categorical), + (pl.Enum(["1", "2"])), ], ) def test_cast_categorical_name_retention(dtype_out: PolarsDataType) -> None: From 28c8674a939528d82bb16020fcfdb0421395f676 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Tue, 6 Feb 2024 17:46:46 -0500 Subject: [PATCH 4/4] Parametrize dtype in --- py-polars/tests/unit/operations/test_cast.py | 30 ++++---------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index a994c0fe1c16..4040c112cb24 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -579,29 +579,9 @@ def test_strict_cast_string_and_binary( @pytest.mark.parametrize( - "dtype_out", - [ - (pl.UInt8), - (pl.Int8), - (pl.UInt16), - (pl.Int16), - (pl.UInt32), - (pl.Int32), - (pl.UInt64), - (pl.Int64), - (pl.Date), - (pl.Datetime), - (pl.Time), - (pl.Duration), - (pl.String), - (pl.Categorical), - (pl.Enum(["1", "2"])), - ], + "dtype_in", + [(pl.Categorical), (pl.Enum(["1"]))], ) -def test_cast_categorical_name_retention(dtype_out: PolarsDataType) -> None: - assert pl.Series("a", ["1"], dtype=pl.Categorical).cast(dtype_out).name == "a" - - @pytest.mark.parametrize( "dtype_out", [ @@ -622,8 +602,10 @@ def test_cast_categorical_name_retention(dtype_out: PolarsDataType) -> None: (pl.Enum(["1", "2"])), ], ) -def test_cast_enum_name_retention(dtype_out: PolarsDataType) -> None: - assert pl.Series("a", ["1"], dtype=pl.Enum(["1"])).cast(dtype_out).name == "a" +def test_cast_categorical_name_retention( + dtype_in: PolarsDataType, dtype_out: PolarsDataType +) -> None: + assert pl.Series("a", ["1"], dtype=dtype_in).cast(dtype_out).name == "a" def test_cast_date_to_time() -> None: