Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Warn if map_elements is called without return_dtype specified #15188

Merged
merged 5 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4255,7 +4255,9 @@ def map_elements(
The function is applied to each element of column `'a'`:

>>> df.with_columns( # doctest: +SKIP
... pl.col("a").map_elements(lambda x: x * 2).alias("a_times_2"),
... pl.col("a")
... .map_elements(lambda x: x * 2, return_dtype=pl.Int64)
... .alias("a_times_2"),
... )
shape: (4, 3)
┌─────┬─────┬───────────┐
Expand Down Expand Up @@ -4296,7 +4298,7 @@ def map_elements(
>>> (
... df.lazy()
... .group_by("b")
... .agg(pl.col("a").map_elements(lambda x: x.sum()))
... .agg(pl.col("a").map_elements(lambda x: x.sum(), return_dtype=pl.Int64))
... .collect()
... ) # doctest: +IGNORE_RESULT
shape: (3, 2)
Expand Down Expand Up @@ -4329,7 +4331,9 @@ def map_elements(
... }
... )
>>> df.with_columns(
... scaled=pl.col("val").map_elements(lambda s: s * len(s)).over("key"),
... scaled=pl.col("val")
... .map_elements(lambda s: s * len(s), return_dtype=pl.List(pl.Int64))
... .over("key"),
... ).sort("key")
shape: (6, 3)
┌─────┬─────┬────────┐
Expand Down Expand Up @@ -5315,12 +5319,16 @@ def xor(self, other: Any) -> Self:
... schema={"x": pl.UInt8, "y": pl.UInt8},
... )
>>> df.with_columns(
... pl.col("x").map_elements(binary_string).alias("bin_x"),
... pl.col("y").map_elements(binary_string).alias("bin_y"),
... pl.col("x")
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_x"),
... pl.col("y")
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_y"),
... pl.col("x").xor(pl.col("y")).alias("xor_xy"),
... pl.col("x")
... .xor(pl.col("y"))
... .map_elements(binary_string)
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_xor_xy"),
... )
shape: (4, 6)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def _read_spreadsheet_pyxlsb(
if schema_overrides:
for idx, s in enumerate(series_data):
if schema_overrides.get(s.name) in (Datetime, Date):
series_data[idx] = s.map_elements(convert_date)
series_data[idx] = s.map_elements(convert_date, return_dtype=Datetime)

df = pl.DataFrame(
{s.name: s for s in series_data},
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5253,7 +5253,7 @@ def map_elements(
Examples
--------
>>> s = pl.Series("a", [1, 2, 3])
>>> s.map_elements(lambda x: x + 10) # doctest: +SKIP
>>> s.map_elements(lambda x: x + 10, return_dtype=pl.Int64) # doctest: +SKIP
shape: (3,)
Series: 'a' [i64]
[
Expand Down
6 changes: 6 additions & 0 deletions py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,12 @@ impl PySeries {
) -> PyResult<PySeries> {
let series = &self.series;

if output_type.is_none() {
polars_warn!(
"calling `map_elements` without specifying `return_dtype` can lead to unpredictable results, \
please specify `return_dtype` to silence this warning")
MarcoGorelli marked this conversation as resolved.
Show resolved Hide resolved
}
stinodego marked this conversation as resolved.
Show resolved Hide resolved

if skip_nulls && (series.null_count() == series.len()) {
if let Some(output_type) = output_type {
return Ok(Series::full_null(series.name(), series.len(), &output_type.0).into());
Expand Down
18 changes: 15 additions & 3 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,14 +961,26 @@ def test_temporal_dtypes_map_elements(
[
# don't actually do this; native expressions are MUCH faster ;)
pl.col("timestamp")
.map_elements(lambda x: const_dtm, skip_nulls=skip_nulls)
.map_elements(
lambda x: const_dtm,
skip_nulls=skip_nulls,
return_dtype=pl.Datetime,
)
.alias("const_dtm"),
# note: the below now trigger a PolarsInefficientMapWarning
pl.col("timestamp")
.map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls)
.map_elements(
lambda x: x and x.date(),
skip_nulls=skip_nulls,
return_dtype=pl.Date,
)
.alias("date"),
pl.col("timestamp")
.map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls)
.map_elements(
lambda x: x and x.time(),
skip_nulls=skip_nulls,
return_dtype=pl.Time,
)
.alias("time"),
]
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,10 @@ def test_parse_invalid_function(func: str) -> None:
("col", "func", "expr_repr"),
TEST_CASES,
)
@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
@pytest.mark.filterwarnings(
"ignore:invalid value encountered:RuntimeWarning",
"ignore:.*without specifying `return_dtype`:UserWarning",
)
def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
with pytest.warns(
PolarsInefficientMapWarning,
Expand Down Expand Up @@ -294,7 +297,10 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
)


@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
@pytest.mark.filterwarnings(
"ignore:invalid value encountered:RuntimeWarning",
"ignore:.*without specifying `return_dtype`:UserWarning",
)
def test_parse_apply_raw_functions() -> None:
lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]})

Expand Down Expand Up @@ -373,7 +379,9 @@ def x10(self, x: pl.Expr) -> pl.Expr:
):
pl_series = pl.Series("srs", [0, 1, 2, 3, 4])
assert_series_equal(
pl_series.map_elements(lambda x: numpy.cos(3) + x - abs(-1)),
pl_series.map_elements(
lambda x: numpy.cos(3) + x - abs(-1), return_dtype=pl.Float64
),
numpy.cos(3) + pl_series - 1,
)

Expand Down Expand Up @@ -405,6 +413,7 @@ def x10(self, x: pl.Expr) -> pl.Expr:
),
],
)
@pytest.mark.filterwarnings("ignore:.*without specifying `return_dtype`:UserWarning")
def test_parse_apply_series(
data: list[Any], func: Callable[[Any], Any], expr_repr: str
) -> None:
Expand Down Expand Up @@ -443,7 +452,7 @@ def test_expr_exact_warning_message() -> None:
# and to keep the assertion on `len(warnings)`.
with pytest.warns(PolarsInefficientMapWarning, match=rf"^{msg}$") as warnings:
df = pl.DataFrame({"a": [1, 2, 3]})
df.select(pl.col("a").map_elements(lambda x: x + 1))
df.select(pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64))

assert len(warnings) == 1

Expand Down
50 changes: 35 additions & 15 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ def test_map_elements_infer_list() -> None:
def test_map_elements_arithmetic_consistency() -> None:
df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})
with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"):
assert df.group_by("A").agg(pl.col("B").map_elements(lambda x: x + 1.0))[
"B"
].to_list() == [[3.0, 4.0]]
assert df.group_by("A").agg(
pl.col("B").map_elements(
lambda x: x + 1.0, return_dtype=pl.List(pl.Float64)
)
)["B"].to_list() == [[3.0, 4.0]]


def test_map_elements_struct() -> None:
Expand Down Expand Up @@ -85,9 +87,12 @@ def test_map_elements_list_any_value_fallback() -> None:
match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()',
):
df = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']})
assert df.select(pl.col("text").map_elements(json.loads)).to_dict(
as_series=False
) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]}
assert df.select(
pl.col("text").map_elements(
json.loads,
return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),
)
).to_dict(as_series=False) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]}

# starts with empty list '[]'
df = pl.DataFrame(
Expand All @@ -99,9 +104,14 @@ def test_map_elements_list_any_value_fallback() -> None:
]
}
)
assert df.select(pl.col("text").map_elements(json.loads)).to_dict(
as_series=False
) == {"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]}
assert df.select(
pl.col("text").map_elements(
json.loads,
return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),
)
).to_dict(as_series=False) == {
"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]
}


def test_map_elements_all_types() -> None:
Expand Down Expand Up @@ -183,7 +193,9 @@ def test_map_elements_object_dtypes() -> None:
)
.alias("is_numeric1"),
pl.col("a")
.map_elements(lambda x: isinstance(x, (int, float)))
.map_elements(
lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean
)
.alias("is_numeric_infer"),
]
).to_dict(as_series=False) == {
Expand Down Expand Up @@ -212,12 +224,20 @@ def test_map_elements_dict() -> None:
match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()',
):
df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']})
assert df.select(pl.col("abc").map_elements(json.loads)).to_dict(
as_series=False
) == {"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]}
assert df.select(
pl.col("abc").map_elements(
json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})
)
).to_dict(as_series=False) == {
"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]
}
assert pl.DataFrame(
{"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']}
).select(pl.col("abc").map_elements(json.loads)).to_dict(as_series=False) == {
).select(
pl.col("abc").map_elements(
json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})
)
).to_dict(as_series=False) == {
"abc": [{"A": "Value1", "B": "Value2"}, {"A": None, "B": "Value3"}]
}

Expand Down Expand Up @@ -305,7 +325,7 @@ def test_apply_deprecated() -> None:
with pytest.deprecated_call():
pl.col("a").apply(np.abs)
with pytest.deprecated_call():
pl.Series([1, 2, 3]).apply(np.abs)
pl.Series([1, 2, 3]).apply(np.abs, return_dtype=pl.Float64)


def test_cabbage_strategy_14396() -> None:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,12 +1023,12 @@ def test_fill_nan() -> None:
def test_map_elements() -> None:
with pytest.warns(PolarsInefficientMapWarning):
a = pl.Series("a", [1, 2, None])
b = a.map_elements(lambda x: x**2)
b = a.map_elements(lambda x: x**2, return_dtype=pl.Int64)
assert list(b) == [1, 4, None]

with pytest.warns(PolarsInefficientMapWarning):
a = pl.Series("a", ["foo", "bar", None])
b = a.map_elements(lambda x: x + "py")
b = a.map_elements(lambda x: x + "py", return_dtype=pl.String)
assert list(b) == ["foopy", "barpy", None]

b = a.map_elements(lambda x: len(x), return_dtype=pl.Int32)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_apply() -> None:
ldf = pl.LazyFrame({"a": [1, 2, 3] * 20, "b": [1.0, 2.0, 3.0] * 20})
new = ldf.with_columns(
pl.col("a")
.map_elements(lambda s: s * 2, strategy=strategy) # type: ignore[arg-type]
.map_elements(lambda s: s * 2, strategy=strategy, return_dtype=pl.Int64) # type: ignore[arg-type]
.alias("foo")
)
expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo"))
Expand Down
Loading