diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 641ff283325e8..0f3381f0c7d2f 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -4255,7 +4255,9 @@ def map_elements( The function is applied to each element of column `'a'`: >>> df.with_columns( # doctest: +SKIP - ... pl.col("a").map_elements(lambda x: x * 2).alias("a_times_2"), + ... pl.col("a") + ... .map_elements(lambda x: x * 2, return_dtype=pl.Int64) + ... .alias("a_times_2"), ... ) shape: (4, 3) ┌─────┬─────┬───────────┐ @@ -4296,7 +4298,7 @@ def map_elements( >>> ( ... df.lazy() ... .group_by("b") - ... .agg(pl.col("a").map_elements(lambda x: x.sum())) + ... .agg(pl.col("a").map_elements(lambda x: x.sum(), return_dtype=pl.Int64)) ... .collect() ... ) # doctest: +IGNORE_RESULT shape: (3, 2) @@ -4329,7 +4331,9 @@ def map_elements( ... } ... ) >>> df.with_columns( - ... scaled=pl.col("val").map_elements(lambda s: s * len(s)).over("key"), + ... scaled=pl.col("val") + ... .map_elements(lambda s: s * len(s), return_dtype=pl.List(pl.Int64)) + ... .over("key"), ... ).sort("key") shape: (6, 3) ┌─────┬─────┬────────┐ @@ -5315,12 +5319,16 @@ def xor(self, other: Any) -> Self: ... schema={"x": pl.UInt8, "y": pl.UInt8}, ... ) >>> df.with_columns( - ... pl.col("x").map_elements(binary_string).alias("bin_x"), - ... pl.col("y").map_elements(binary_string).alias("bin_y"), + ... pl.col("x") + ... .map_elements(binary_string, return_dtype=pl.String) + ... .alias("bin_x"), + ... pl.col("y") + ... .map_elements(binary_string, return_dtype=pl.String) + ... .alias("bin_y"), ... pl.col("x").xor(pl.col("y")).alias("xor_xy"), ... pl.col("x") ... .xor(pl.col("y")) - ... .map_elements(binary_string) + ... .map_elements(binary_string, return_dtype=pl.String) ... .alias("bin_xor_xy"), ... 
) shape: (4, 6) diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py index 4fde7f7a9ea9b..772d846b1570e 100644 --- a/py-polars/polars/io/spreadsheet/functions.py +++ b/py-polars/polars/io/spreadsheet/functions.py @@ -899,7 +899,7 @@ def _read_spreadsheet_pyxlsb( if schema_overrides: for idx, s in enumerate(series_data): if schema_overrides.get(s.name) in (Datetime, Date): - series_data[idx] = s.map_elements(convert_date) + series_data[idx] = s.map_elements(convert_date, return_dtype=Datetime) df = pl.DataFrame( {s.name: s for s in series_data}, diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 39877d4dc341a..b2547a22f3c80 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -5253,7 +5253,7 @@ def map_elements( Examples -------- >>> s = pl.Series("a", [1, 2, 3]) - >>> s.map_elements(lambda x: x + 10) # doctest: +SKIP + >>> s.map_elements(lambda x: x + 10, return_dtype=pl.Int64) # doctest: +SKIP shape: (3,) Series: 'a' [i64] [ diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index a523ee6439ea3..7433b1d6a1161 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -342,6 +342,12 @@ impl PySeries { ) -> PyResult { let series = &self.series; + if output_type.is_none() { + polars_warn!( + "calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. 
\ + Please specify `return_dtype` to silence this warning.") + } + if skip_nulls && (series.null_count() == series.len()) { if let Some(output_type) = output_type { return Ok(Series::full_null(series.name(), series.len(), &output_type.0).into()); diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index 4a465aefbb837..ab45bfe3c65d6 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -961,14 +961,26 @@ def test_temporal_dtypes_map_elements( [ # don't actually do this; native expressions are MUCH faster ;) pl.col("timestamp") - .map_elements(lambda x: const_dtm, skip_nulls=skip_nulls) + .map_elements( + lambda x: const_dtm, + skip_nulls=skip_nulls, + return_dtype=pl.Datetime, + ) .alias("const_dtm"), # note: the below now trigger a PolarsInefficientMapWarning pl.col("timestamp") - .map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls) + .map_elements( + lambda x: x and x.date(), + skip_nulls=skip_nulls, + return_dtype=pl.Date, + ) .alias("date"), pl.col("timestamp") - .map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls) + .map_elements( + lambda x: x and x.time(), + skip_nulls=skip_nulls, + return_dtype=pl.Time, + ) .alias("time"), ] ), diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 8f9740a42032a..320de61f8e576 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -255,7 +255,10 @@ def test_parse_invalid_function(func: str) -> None: ("col", "func", "expr_repr"), TEST_CASES, ) -@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +@pytest.mark.filterwarnings( + "ignore:invalid value encountered:RuntimeWarning", + "ignore:.*without specifying `return_dtype`:UserWarning", +) def 
test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: with pytest.warns( PolarsInefficientMapWarning, @@ -294,7 +297,10 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: ) -@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +@pytest.mark.filterwarnings( + "ignore:invalid value encountered:RuntimeWarning", + "ignore:.*without specifying `return_dtype`:UserWarning", +) def test_parse_apply_raw_functions() -> None: lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]}) @@ -373,7 +379,9 @@ def x10(self, x: pl.Expr) -> pl.Expr: ): pl_series = pl.Series("srs", [0, 1, 2, 3, 4]) assert_series_equal( - pl_series.map_elements(lambda x: numpy.cos(3) + x - abs(-1)), + pl_series.map_elements( + lambda x: numpy.cos(3) + x - abs(-1), return_dtype=pl.Float64 + ), numpy.cos(3) + pl_series - 1, ) @@ -405,6 +413,7 @@ def x10(self, x: pl.Expr) -> pl.Expr: ), ], ) +@pytest.mark.filterwarnings("ignore:.*without specifying `return_dtype`:UserWarning") def test_parse_apply_series( data: list[Any], func: Callable[[Any], Any], expr_repr: str ) -> None: @@ -443,7 +452,7 @@ def test_expr_exact_warning_message() -> None: # and to keep the assertion on `len(warnings)`. 
with pytest.warns(PolarsInefficientMapWarning, match=rf"^{msg}$") as warnings: df = pl.DataFrame({"a": [1, 2, 3]}) - df.select(pl.col("a").map_elements(lambda x: x + 1)) + df.select(pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64)) assert len(warnings) == 1 diff --git a/py-polars/tests/unit/operations/map/test_map_elements.py b/py-polars/tests/unit/operations/map/test_map_elements.py index 87ced66a15104..35a276e8e94ca 100644 --- a/py-polars/tests/unit/operations/map/test_map_elements.py +++ b/py-polars/tests/unit/operations/map/test_map_elements.py @@ -25,9 +25,11 @@ def test_map_elements_infer_list() -> None: def test_map_elements_arithmetic_consistency() -> None: df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]}) with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"): - assert df.group_by("A").agg(pl.col("B").map_elements(lambda x: x + 1.0))[ - "B" - ].to_list() == [[3.0, 4.0]] + assert df.group_by("A").agg( + pl.col("B").map_elements( + lambda x: x + 1.0, return_dtype=pl.List(pl.Float64) + ) + )["B"].to_list() == [[3.0, 4.0]] def test_map_elements_struct() -> None: @@ -85,9 +87,12 @@ def test_map_elements_list_any_value_fallback() -> None: match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()', ): df = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']}) - assert df.select(pl.col("text").map_elements(json.loads)).to_dict( - as_series=False - ) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]} + assert df.select( + pl.col("text").map_elements( + json.loads, + return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})), + ) + ).to_dict(as_series=False) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]} # starts with empty list '[]' df = pl.DataFrame( @@ -99,9 +104,14 @@ def test_map_elements_list_any_value_fallback() -> None: ] } ) - assert df.select(pl.col("text").map_elements(json.loads)).to_dict( - as_series=False - ) == {"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 
2}]]} + assert df.select( + pl.col("text").map_elements( + json.loads, + return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})), + ) + ).to_dict(as_series=False) == { + "text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]] + } def test_map_elements_all_types() -> None: @@ -183,7 +193,9 @@ def test_map_elements_object_dtypes() -> None: ) .alias("is_numeric1"), pl.col("a") - .map_elements(lambda x: isinstance(x, (int, float))) + .map_elements( + lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean + ) .alias("is_numeric_infer"), ] ).to_dict(as_series=False) == { @@ -212,12 +224,20 @@ def test_map_elements_dict() -> None: match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()', ): df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']}) - assert df.select(pl.col("abc").map_elements(json.loads)).to_dict( - as_series=False - ) == {"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]} + assert df.select( + pl.col("abc").map_elements( + json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String}) + ) + ).to_dict(as_series=False) == { + "abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}] + } assert pl.DataFrame( {"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']} - ).select(pl.col("abc").map_elements(json.loads)).to_dict(as_series=False) == { + ).select( + pl.col("abc").map_elements( + json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String}) + ) + ).to_dict(as_series=False) == { "abc": [{"A": "Value1", "B": "Value2"}, {"A": None, "B": "Value3"}] } @@ -305,7 +325,7 @@ def test_apply_deprecated() -> None: with pytest.deprecated_call(): pl.col("a").apply(np.abs) with pytest.deprecated_call(): - pl.Series([1, 2, 3]).apply(np.abs) + pl.Series([1, 2, 3]).apply(np.abs, return_dtype=pl.Float64) def test_cabbage_strategy_14396() -> None: diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 5d9178a582b57..aadd6102b1c01 
100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1023,12 +1023,12 @@ def test_fill_nan() -> None: def test_map_elements() -> None: with pytest.warns(PolarsInefficientMapWarning): a = pl.Series("a", [1, 2, None]) - b = a.map_elements(lambda x: x**2) + b = a.map_elements(lambda x: x**2, return_dtype=pl.Int64) assert list(b) == [1, 4, None] with pytest.warns(PolarsInefficientMapWarning): a = pl.Series("a", ["foo", "bar", None]) - b = a.map_elements(lambda x: x + "py") + b = a.map_elements(lambda x: x + "py", return_dtype=pl.String) assert list(b) == ["foopy", "barpy", None] b = a.map_elements(lambda x: len(x), return_dtype=pl.Int32) diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 3832942501093..0a45912894b37 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -95,7 +95,7 @@ def test_apply() -> None: ldf = pl.LazyFrame({"a": [1, 2, 3] * 20, "b": [1.0, 2.0, 3.0] * 20}) new = ldf.with_columns( pl.col("a") - .map_elements(lambda s: s * 2, strategy=strategy) # type: ignore[arg-type] + .map_elements(lambda s: s * 2, strategy=strategy, return_dtype=pl.Int64) # type: ignore[arg-type] .alias("foo") ) expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo"))