diff --git a/crates/polars-error/src/warning.rs b/crates/polars-error/src/warning.rs index be18c3fa7172..ee774a928bb8 100644 --- a/crates/polars-error/src/warning.rs +++ b/crates/polars-error/src/warning.rs @@ -14,6 +14,7 @@ pub unsafe fn set_warning_function(function: WarningFunction) { pub enum PolarsWarning { UserWarning, CategoricalRemappingWarning, + MapWithoutReturnDtypeWarning, } fn eprintln(fmt: &str, warning: PolarsWarning) { diff --git a/py-polars/polars/__init__.py b/py-polars/polars/__init__.py index 7a898d0ebc3f..4691a6f13cc2 100644 --- a/py-polars/polars/__init__.py +++ b/py-polars/polars/__init__.py @@ -80,6 +80,7 @@ ComputeError, DuplicateError, InvalidOperationError, + MapWithoutReturnDtypeWarning, NoDataError, OutOfBoundsError, PolarsError, @@ -244,6 +245,7 @@ "PolarsWarning", "CategoricalRemappingWarning", "ChronoFormatWarning", + "MapWithoutReturnDtypeWarning", "UnstableWarning", # core classes "DataFrame", diff --git a/py-polars/polars/exceptions.py b/py-polars/polars/exceptions.py index 1fe2b9db7ea6..5c3bc3e26bd1 100644 --- a/py-polars/polars/exceptions.py +++ b/py-polars/polars/exceptions.py @@ -5,6 +5,7 @@ ComputeError, DuplicateError, InvalidOperationError, + MapWithoutReturnDtypeWarning, NoDataError, OutOfBoundsError, PolarsError, @@ -64,6 +65,9 @@ class PolarsWarning(Exception): # type: ignore[no-redef] class CategoricalRemappingWarning(PolarsWarning): # type: ignore[no-redef, misc] """Warning raised when a categorical needs to be remapped to be compatible with another categorical.""" # noqa: W505 + class MapWithoutReturnDtypeWarning(PolarsWarning): # type: ignore[no-redef, misc] + """Warning raised when `map_elements` is performed without specifying the return dtype.""" # noqa: W505 + class InvalidAssert(PolarsError): # type: ignore[misc] """Exception raised when an unsupported testing assert is made.""" @@ -132,6 +136,7 @@ class CustomUFuncWarning(PolarsWarning): # type: ignore[misc] "ChronoFormatWarning", "DuplicateError", "InvalidOperationError", + "MapWithoutReturnDtypeWarning", "ModuleUpgradeRequired", "NoDataError", "NoRowsReturnedError", diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index fb1ce7ebbb09..0c53ed36d520 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -4255,7 +4255,9 @@ def map_elements( The function is applied to each element of column `'a'`: >>> df.with_columns( # doctest: +SKIP - ... pl.col("a").map_elements(lambda x: x * 2).alias("a_times_2"), + ... pl.col("a") + ... .map_elements(lambda x: x * 2, return_dtype=pl.Int64) + ... .alias("a_times_2"), ... ) shape: (4, 3) ┌─────┬─────┬───────────┐ @@ -4296,7 +4298,7 @@ def map_elements( >>> ( ... df.lazy() ... .group_by("b") - ... .agg(pl.col("a").map_elements(lambda x: x.sum())) + ... .agg(pl.col("a").map_elements(lambda x: x.sum(), return_dtype=pl.Int64)) ... .collect() ... ) # doctest: +IGNORE_RESULT shape: (3, 2) @@ -4329,7 +4331,9 @@ def map_elements( ... } ... ) >>> df.with_columns( - ... scaled=pl.col("val").map_elements(lambda s: s * len(s)).over("key"), + ... scaled=pl.col("val") + ... .map_elements(lambda s: s * len(s), return_dtype=pl.List(pl.Int64)) + ... .over("key"), ... ).sort("key") shape: (6, 3) ┌─────┬─────┬────────┐ @@ -5315,12 +5319,16 @@ def xor(self, other: Any) -> Self: ... schema={"x": pl.UInt8, "y": pl.UInt8}, ... ) >>> df.with_columns( - ... pl.col("x").map_elements(binary_string).alias("bin_x"), - ... pl.col("y").map_elements(binary_string).alias("bin_y"), + ... pl.col("x") + ... .map_elements(binary_string, return_dtype=pl.String) + ... .alias("bin_x"), + ... pl.col("y") + ... .map_elements(binary_string, return_dtype=pl.String) + ... .alias("bin_y"), ... pl.col("x").xor(pl.col("y")).alias("xor_xy"), ... pl.col("x") ... .xor(pl.col("y")) - ... .map_elements(binary_string) + ... .map_elements(binary_string, return_dtype=pl.String) ... .alias("bin_xor_xy"), ... ) shape: (4, 6) diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 4fde7f7a9ea9..772d846b1570 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -899,7 +899,7 @@ def _read_spreadsheet_pyxlsb( if schema_overrides: for idx, s in enumerate(series_data): if schema_overrides.get(s.name) in (Datetime, Date): - series_data[idx] = s.map_elements(convert_date) + series_data[idx] = s.map_elements(convert_date, return_dtype=Datetime) df = pl.DataFrame( {s.name: s for s in series_data}, diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 1b4f6ab3d568..c9ce2200690d 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -5292,7 +5292,7 @@ def map_elements( Examples -------- >>> s = pl.Series("a", [1, 2, 3]) - >>> s.map_elements(lambda x: x + 10) # doctest: +SKIP + >>> s.map_elements(lambda x: x + 10, return_dtype=pl.Int64) # doctest: +SKIP shape: (3,) Series: 'a' [i64] [ diff --git a/py-polars/src/error.rs b/py-polars/src/error.rs index 524286c70a1a..6af1e3997f0a 100644 --- a/py-polars/src/error.rs +++ b/py-polars/src/error.rs @@ -94,6 +94,11 @@ create_exception!( CategoricalRemappingWarning, PolarsBaseWarning ); +create_exception!( + polars.exceptions, + MapWithoutReturnDtypeWarning, + PolarsBaseWarning +); #[macro_export] macro_rules! raise_err( @@ -109,6 +114,9 @@ impl IntoPy for Wrap { PolarsWarning::CategoricalRemappingWarning => { CategoricalRemappingWarning::type_object(py).to_object(py) }, + PolarsWarning::MapWithoutReturnDtypeWarning => { + MapWithoutReturnDtypeWarning::type_object(py).to_object(py) + }, PolarsWarning::UserWarning => PyUserWarning::type_object(py).to_object(py), } } diff --git a/py-polars/src/lib.rs b/py-polars/src/lib.rs index cdd1725f6f9e..b3a2a77c5f08 100644 --- a/py-polars/src/lib.rs +++ b/py-polars/src/lib.rs @@ -54,8 +54,9 @@ use crate::conversion::Wrap; use crate::dataframe::PyDataFrame; use crate::error::{ CategoricalRemappingWarning, ColumnNotFoundError, ComputeError, DuplicateError, - InvalidOperationError, NoDataError, OutOfBoundsError, PolarsBaseError, PolarsBaseWarning, - PyPolarsErr, SchemaError, SchemaFieldNotFoundError, StructFieldNotFoundError, + InvalidOperationError, MapWithoutReturnDtypeWarning, NoDataError, OutOfBoundsError, + PolarsBaseError, PolarsBaseWarning, PyPolarsErr, SchemaError, SchemaFieldNotFoundError, + StructFieldNotFoundError, }; use crate::expr::PyExpr; use crate::functions::PyStringCacheHolder; @@ -296,6 +297,11 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> { py.get_type::(), ) .unwrap(); + m.add( + "MapWithoutReturnDtypeWarning", + py.get_type::(), + ) + .unwrap(); // Build info m.add("__version__", env!("CARGO_PKG_VERSION"))?; diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index a523ee6439ea..ee86741d5d59 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -342,6 +342,13 @@ impl PySeries { ) -> PyResult { let series = &self.series; + if output_type.is_none() { + polars_warn!( + MapWithoutReturnDtypeWarning, + "Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. \ + Specify `return_dtype` to silence this warning.") + } + if skip_nulls && (series.null_count() == series.len()) { if let Some(output_type) = output_type { return Ok(Series::full_null(series.name(), series.len(), &output_type.0).into()); diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index 683c249d1655..ae9faede5645 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -961,14 +961,26 @@ def test_temporal_dtypes_map_elements( [ # don't actually do this; native expressions are MUCH faster ;) pl.col("timestamp") - .map_elements(lambda x: const_dtm, skip_nulls=skip_nulls) + .map_elements( + lambda x: const_dtm, + skip_nulls=skip_nulls, + return_dtype=pl.Datetime, + ) .alias("const_dtm"), # note: the below now trigger a PolarsInefficientMapWarning pl.col("timestamp") - .map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls) + .map_elements( + lambda x: x and x.date(), + skip_nulls=skip_nulls, + return_dtype=pl.Date, + ) .alias("date"), pl.col("timestamp") - .map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls) + .map_elements( + lambda x: x and x.time(), + skip_nulls=skip_nulls, + return_dtype=pl.Time, + ) .alias("time"), ] ), diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 8f9740a42032..fa1663bc146c 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -255,7 +255,10 @@ def test_parse_invalid_function(func: str) -> None: ("col", "func", "expr_repr"), TEST_CASES, ) -@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +@pytest.mark.filterwarnings( + "ignore:invalid value encountered:RuntimeWarning", + "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning", +) def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: with pytest.warns( PolarsInefficientMapWarning, @@ -294,7 +297,10 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: ) -@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +@pytest.mark.filterwarnings( + "ignore:invalid value encountered:RuntimeWarning", + "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning", +) def test_parse_apply_raw_functions() -> None: lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]}) @@ -373,7 +379,9 @@ def x10(self, x: pl.Expr) -> pl.Expr: ): pl_series = pl.Series("srs", [0, 1, 2, 3, 4]) assert_series_equal( - pl_series.map_elements(lambda x: numpy.cos(3) + x - abs(-1)), + pl_series.map_elements( + lambda x: numpy.cos(3) + x - abs(-1), return_dtype=pl.Float64 + ), numpy.cos(3) + pl_series - 1, ) @@ -405,6 +413,9 @@ def x10(self, x: pl.Expr) -> pl.Expr: ), ], ) +@pytest.mark.filterwarnings( + "ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning" +) def test_parse_apply_series( data: list[Any], func: Callable[[Any], Any], expr_repr: str ) -> None: @@ -443,7 +454,7 @@ def test_expr_exact_warning_message() -> None: # and to keep the assertion on `len(warnings)`. with pytest.warns(PolarsInefficientMapWarning, match=rf"^{msg}$") as warnings: df = pl.DataFrame({"a": [1, 2, 3]}) - df.select(pl.col("a").map_elements(lambda x: x + 1)) + df.select(pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)) assert len(warnings) == 1 diff --git a/py-polars/tests/unit/operations/map/test_map_elements.py b/py-polars/tests/unit/operations/map/test_map_elements.py index 87ced66a1510..35a276e8e94c 100644 --- a/py-polars/tests/unit/operations/map/test_map_elements.py +++ b/py-polars/tests/unit/operations/map/test_map_elements.py @@ -25,9 +25,11 @@ def test_map_elements_infer_list() -> None: def test_map_elements_arithmetic_consistency() -> None: df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]}) with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"): - assert df.group_by("A").agg(pl.col("B").map_elements(lambda x: x + 1.0))[ - "B" - ].to_list() == [[3.0, 4.0]] + assert df.group_by("A").agg( + pl.col("B").map_elements( + lambda x: x + 1.0, return_dtype=pl.List(pl.Float64) + ) + )["B"].to_list() == [[3.0, 4.0]] def test_map_elements_struct() -> None: @@ -85,9 +87,12 @@ def test_map_elements_list_any_value_fallback() -> None: match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()', ): df = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']}) - assert df.select(pl.col("text").map_elements(json.loads)).to_dict( - as_series=False - ) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]} + assert df.select( + pl.col("text").map_elements( + json.loads, + return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})), + ) + ).to_dict(as_series=False) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]} # starts with empty list '[]' df = pl.DataFrame( @@ -99,9 +104,14 @@ def test_map_elements_list_any_value_fallback() -> None: ] } ) - assert df.select(pl.col("text").map_elements(json.loads)).to_dict( - as_series=False - ) == {"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]} + assert df.select( + pl.col("text").map_elements( + json.loads, + return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})), + ) + ).to_dict(as_series=False) == { + "text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]] + } def test_map_elements_all_types() -> None: @@ -183,7 +193,9 @@ def test_map_elements_object_dtypes() -> None: ) .alias("is_numeric1"), pl.col("a") - .map_elements(lambda x: isinstance(x, (int, float))) + .map_elements( + lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean + ) .alias("is_numeric_infer"), ] ).to_dict(as_series=False) == { @@ -212,12 +224,20 @@ def test_map_elements_dict() -> None: match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()', ): df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']}) - assert df.select(pl.col("abc").map_elements(json.loads)).to_dict( - as_series=False - ) == {"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]} + assert df.select( + pl.col("abc").map_elements( + json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String}) + ) + ).to_dict(as_series=False) == { + "abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}] + } assert pl.DataFrame( {"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']} - ).select(pl.col("abc").map_elements(json.loads)).to_dict(as_series=False) == { + ).select( + pl.col("abc").map_elements( + json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String}) + ) + ).to_dict(as_series=False) == { "abc": [{"A": "Value1", "B": "Value2"}, {"A": None, "B": "Value3"}] } @@ -305,7 +325,7 @@ def test_apply_deprecated() -> None: with pytest.deprecated_call(): pl.col("a").apply(np.abs) with pytest.deprecated_call(): - pl.Series([1, 2, 3]).apply(np.abs) + pl.Series([1, 2, 3]).apply(np.abs, return_dtype=pl.Float64) def test_cabbage_strategy_14396() -> None: diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 5d9178a582b5..aadd6102b1c0 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1023,12 +1023,12 @@ def test_fill_nan() -> None: def test_map_elements() -> None: with pytest.warns(PolarsInefficientMapWarning): a = pl.Series("a", [1, 2, None]) - b = a.map_elements(lambda x: x**2) + b = a.map_elements(lambda x: x**2, return_dtype=pl.Int64) assert list(b) == [1, 4, None] with pytest.warns(PolarsInefficientMapWarning): a = pl.Series("a", ["foo", "bar", None]) - b = a.map_elements(lambda x: x + "py") + b = a.map_elements(lambda x: x + "py", return_dtype=pl.String) assert list(b) == ["foopy", "barpy", None] b = a.map_elements(lambda x: len(x), return_dtype=pl.Int32) diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 383294250109..0a45912894b3 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -95,7 +95,7 @@ def test_apply() -> None: ldf = pl.LazyFrame({"a": [1, 2, 3] * 20, "b": [1.0, 2.0, 3.0] * 20}) new = ldf.with_columns( pl.col("a") - .map_elements(lambda s: s * 2, strategy=strategy) # type: ignore[arg-type] + .map_elements(lambda s: s * 2, strategy=strategy, return_dtype=pl.Int64) # type: ignore[arg-type] .alias("foo") ) expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo"))