diff --git a/crates/polars-core/src/utils/supertype.rs b/crates/polars-core/src/utils/supertype.rs index 47ecd9c13a4f..1f1d7f9751c8 100644 --- a/crates/polars-core/src/utils/supertype.rs +++ b/crates/polars-core/src/utils/supertype.rs @@ -304,6 +304,7 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option { } } } + UnknownKind::Int(_) if dt.is_decimal() => Some(dt.clone()), _ => Some(Unknown(UnknownKind::Any)) } }, diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index 1e95adb4a3e7..9d14e0b8c549 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -461,3 +461,13 @@ def test_decimal_streaming() -> None: D("161102921617598.363263936811563000"), ], } + + +def test_decimal_supertype() -> None: + with pl.Config() as cfg: + cfg.activate_decimals() + pl.Config.activate_decimals() + q = pl.LazyFrame([0.12345678]).select( + pl.col("column_0").cast(pl.Decimal(scale=6)) * 1 + ) + assert q.collect().dtypes[0].is_decimal()