Skip to content

Commit

Permalink
feat: warn if map_elements is called without return_dtype specified
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 20, 2024
1 parent 77b8529 commit d379102
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 33 deletions.
20 changes: 14 additions & 6 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4255,7 +4255,9 @@ def map_elements(
The function is applied to each element of column `'a'`:
>>> df.with_columns( # doctest: +SKIP
... pl.col("a").map_elements(lambda x: x * 2).alias("a_times_2"),
... pl.col("a")
... .map_elements(lambda x: x * 2, return_dtype=pl.Int64)
... .alias("a_times_2"),
... )
shape: (4, 3)
┌─────┬─────┬───────────┐
Expand Down Expand Up @@ -4296,7 +4298,7 @@ def map_elements(
>>> (
... df.lazy()
... .group_by("b")
... .agg(pl.col("a").map_elements(lambda x: x.sum()))
... .agg(pl.col("a").map_elements(lambda x: x.sum(), return_dtype=pl.Int64))
... .collect()
... ) # doctest: +IGNORE_RESULT
shape: (3, 2)
Expand Down Expand Up @@ -4329,7 +4331,9 @@ def map_elements(
... }
... )
>>> df.with_columns(
... scaled=pl.col("val").map_elements(lambda s: s * len(s)).over("key"),
... scaled=pl.col("val")
... .map_elements(lambda s: s * len(s), return_dtype=pl.List(pl.Int64))
... .over("key"),
... ).sort("key")
shape: (6, 3)
┌─────┬─────┬────────┐
Expand Down Expand Up @@ -5315,12 +5319,16 @@ def xor(self, other: Any) -> Self:
... schema={"x": pl.UInt8, "y": pl.UInt8},
... )
>>> df.with_columns(
... pl.col("x").map_elements(binary_string).alias("bin_x"),
... pl.col("y").map_elements(binary_string).alias("bin_y"),
... pl.col("x")
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_x"),
... pl.col("y")
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_y"),
... pl.col("x").xor(pl.col("y")).alias("xor_xy"),
... pl.col("x")
... .xor(pl.col("y"))
... .map_elements(binary_string)
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_xor_xy"),
... )
shape: (4, 6)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def _read_spreadsheet_pyxlsb(
if schema_overrides:
for idx, s in enumerate(series_data):
if schema_overrides.get(s.name) in (Datetime, Date):
series_data[idx] = s.map_elements(convert_date)
series_data[idx] = s.map_elements(convert_date, return_dtype=Datetime)

df = pl.DataFrame(
{s.name: s for s in series_data},
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5253,7 +5253,7 @@ def map_elements(
Examples
--------
>>> s = pl.Series("a", [1, 2, 3])
>>> s.map_elements(lambda x: x + 10) # doctest: +SKIP
>>> s.map_elements(lambda x: x + 10, return_dtype=pl.Int64) # doctest: +SKIP
shape: (3,)
Series: 'a' [i64]
[
Expand Down
6 changes: 6 additions & 0 deletions py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,12 @@ impl PySeries {
) -> PyResult<PySeries> {
let series = &self.series;

if output_type.is_none() {
polars_warn!(
"calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. \
Please specify `return_dtype` to silence this warning.")
}

if skip_nulls && (series.null_count() == series.len()) {
if let Some(output_type) = output_type {
return Ok(Series::full_null(series.name(), series.len(), &output_type.0).into());
Expand Down
18 changes: 15 additions & 3 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,14 +961,26 @@ def test_temporal_dtypes_map_elements(
[
# don't actually do this; native expressions are MUCH faster ;)
pl.col("timestamp")
.map_elements(lambda x: const_dtm, skip_nulls=skip_nulls)
.map_elements(
lambda x: const_dtm,
skip_nulls=skip_nulls,
return_dtype=pl.Datetime,
)
.alias("const_dtm"),
# note: the below now trigger a PolarsInefficientMapWarning
pl.col("timestamp")
.map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls)
.map_elements(
lambda x: x and x.date(),
skip_nulls=skip_nulls,
return_dtype=pl.Date,
)
.alias("date"),
pl.col("timestamp")
.map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls)
.map_elements(
lambda x: x and x.time(),
skip_nulls=skip_nulls,
return_dtype=pl.Time,
)
.alias("time"),
]
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,10 @@ def test_parse_invalid_function(func: str) -> None:
("col", "func", "expr_repr"),
TEST_CASES,
)
@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
@pytest.mark.filterwarnings(
"ignore:invalid value encountered:RuntimeWarning",
"ignore:.*without specifying `rerturn_dtype`:UserWarning",
)
def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
with pytest.warns(
PolarsInefficientMapWarning,
Expand Down Expand Up @@ -294,7 +297,10 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
)


@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
@pytest.mark.filterwarnings(
"ignore:invalid value encountered:RuntimeWarning",
"ignore:.*without specifying `rerturn_dtype`:UserWarning",
)
def test_parse_apply_raw_functions() -> None:
lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]})

Expand Down Expand Up @@ -373,7 +379,9 @@ def x10(self, x: pl.Expr) -> pl.Expr:
):
pl_series = pl.Series("srs", [0, 1, 2, 3, 4])
assert_series_equal(
pl_series.map_elements(lambda x: numpy.cos(3) + x - abs(-1)),
pl_series.map_elements(
lambda x: numpy.cos(3) + x - abs(-1), return_dtype=pl.Float64
),
numpy.cos(3) + pl_series - 1,
)

Expand Down Expand Up @@ -405,6 +413,7 @@ def x10(self, x: pl.Expr) -> pl.Expr:
),
],
)
@pytest.mark.filterwarnings("ignore:.*without specifying `rerturn_dtype`:UserWarning")
def test_parse_apply_series(
data: list[Any], func: Callable[[Any], Any], expr_repr: str
) -> None:
Expand Down Expand Up @@ -443,7 +452,7 @@ def test_expr_exact_warning_message() -> None:
# and to keep the assertion on `len(warnings)`.
with pytest.warns(PolarsInefficientMapWarning, match=rf"^{msg}$") as warnings:
df = pl.DataFrame({"a": [1, 2, 3]})
df.select(pl.col("a").map_elements(lambda x: x + 1))
df.select(pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64))

assert len(warnings) == 1

Expand Down
50 changes: 35 additions & 15 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ def test_map_elements_infer_list() -> None:
def test_map_elements_arithmetic_consistency() -> None:
df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})
with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"):
assert df.group_by("A").agg(pl.col("B").map_elements(lambda x: x + 1.0))[
"B"
].to_list() == [[3.0, 4.0]]
assert df.group_by("A").agg(
pl.col("B").map_elements(
lambda x: x + 1.0, return_dtype=pl.List(pl.Float64)
)
)["B"].to_list() == [[3.0, 4.0]]


def test_map_elements_struct() -> None:
Expand Down Expand Up @@ -85,9 +87,12 @@ def test_map_elements_list_any_value_fallback() -> None:
match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()',
):
df = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']})
assert df.select(pl.col("text").map_elements(json.loads)).to_dict(
as_series=False
) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]}
assert df.select(
pl.col("text").map_elements(
json.loads,
return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),
)
).to_dict(as_series=False) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]}

# starts with empty list '[]'
df = pl.DataFrame(
Expand All @@ -99,9 +104,14 @@ def test_map_elements_list_any_value_fallback() -> None:
]
}
)
assert df.select(pl.col("text").map_elements(json.loads)).to_dict(
as_series=False
) == {"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]}
assert df.select(
pl.col("text").map_elements(
json.loads,
return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),
)
).to_dict(as_series=False) == {
"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]
}


def test_map_elements_all_types() -> None:
Expand Down Expand Up @@ -183,7 +193,9 @@ def test_map_elements_object_dtypes() -> None:
)
.alias("is_numeric1"),
pl.col("a")
.map_elements(lambda x: isinstance(x, (int, float)))
.map_elements(
lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean
)
.alias("is_numeric_infer"),
]
).to_dict(as_series=False) == {
Expand Down Expand Up @@ -212,12 +224,20 @@ def test_map_elements_dict() -> None:
match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()',
):
df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']})
assert df.select(pl.col("abc").map_elements(json.loads)).to_dict(
as_series=False
) == {"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]}
assert df.select(
pl.col("abc").map_elements(
json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})
)
).to_dict(as_series=False) == {
"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]
}
assert pl.DataFrame(
{"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']}
).select(pl.col("abc").map_elements(json.loads)).to_dict(as_series=False) == {
).select(
pl.col("abc").map_elements(
json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})
)
).to_dict(as_series=False) == {
"abc": [{"A": "Value1", "B": "Value2"}, {"A": None, "B": "Value3"}]
}

Expand Down Expand Up @@ -305,7 +325,7 @@ def test_apply_deprecated() -> None:
with pytest.deprecated_call():
pl.col("a").apply(np.abs)
with pytest.deprecated_call():
pl.Series([1, 2, 3]).apply(np.abs)
pl.Series([1, 2, 3]).apply(np.abs, return_dtype=pl.Float64)


def test_cabbage_strategy_14396() -> None:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,12 +1023,12 @@ def test_fill_nan() -> None:
def test_map_elements() -> None:
with pytest.warns(PolarsInefficientMapWarning):
a = pl.Series("a", [1, 2, None])
b = a.map_elements(lambda x: x**2)
b = a.map_elements(lambda x: x**2, return_dtype=pl.Int64)
assert list(b) == [1, 4, None]

with pytest.warns(PolarsInefficientMapWarning):
a = pl.Series("a", ["foo", "bar", None])
b = a.map_elements(lambda x: x + "py")
b = a.map_elements(lambda x: x + "py", return_dtype=pl.String)
assert list(b) == ["foopy", "barpy", None]

b = a.map_elements(lambda x: len(x), return_dtype=pl.Int32)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_apply() -> None:
ldf = pl.LazyFrame({"a": [1, 2, 3] * 20, "b": [1.0, 2.0, 3.0] * 20})
new = ldf.with_columns(
pl.col("a")
.map_elements(lambda s: s * 2, strategy=strategy) # type: ignore[arg-type]
.map_elements(lambda s: s * 2, strategy=strategy, return_dtype=pl.Int64) # type: ignore[arg-type]
.alias("foo")
)
expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo"))
Expand Down

0 comments on commit d379102

Please sign in to comment.