Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Warn if map_elements is called without return_dtype specified #15188

Merged
merged 5 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/polars-error/src/warning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub unsafe fn set_warning_function(function: WarningFunction) {
/// Categories of warning that Polars can emit.
///
/// The Python bindings translate each variant into a corresponding Python
/// warning class before the warning is raised.
pub enum PolarsWarning {
/// Generic warning; surfaced in Python as the built-in `UserWarning`.
UserWarning,
/// A categorical had to be remapped to be compatible with another categorical.
CategoricalRemappingWarning,
/// `map_elements` was called without specifying a return dtype.
MapWithoutReturnDtypeWarning,
}

fn eprintln(fmt: &str, warning: PolarsWarning) {
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
ComputeError,
DuplicateError,
InvalidOperationError,
MapWithoutReturnDtypeWarning,
NoDataError,
OutOfBoundsError,
PolarsError,
Expand Down Expand Up @@ -244,6 +245,7 @@
"PolarsWarning",
"CategoricalRemappingWarning",
"ChronoFormatWarning",
"MapWithoutReturnDtypeWarning",
"UnstableWarning",
# core classes
"DataFrame",
Expand Down
5 changes: 5 additions & 0 deletions py-polars/polars/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
ComputeError,
DuplicateError,
InvalidOperationError,
MapWithoutReturnDtypeWarning,
NoDataError,
OutOfBoundsError,
PolarsError,
Expand Down Expand Up @@ -64,6 +65,9 @@ class PolarsWarning(Exception): # type: ignore[no-redef]
class CategoricalRemappingWarning(PolarsWarning): # type: ignore[no-redef, misc]
"""Warning raised when a categorical needs to be remapped to be compatible with another categorical.""" # noqa: W505

class MapWithoutReturnDtypeWarning(PolarsWarning): # type: ignore[no-redef, misc]
"""Warning raised when `map_elements` is performed without specifying the return dtype.""" # noqa: W505


class InvalidAssert(PolarsError): # type: ignore[misc]
"""Exception raised when an unsupported testing assert is made."""
Expand Down Expand Up @@ -132,6 +136,7 @@ class CustomUFuncWarning(PolarsWarning): # type: ignore[misc]
"ChronoFormatWarning",
"DuplicateError",
"InvalidOperationError",
"MapWithoutReturnDtypeWarning",
"ModuleUpgradeRequired",
"NoDataError",
"NoRowsReturnedError",
Expand Down
20 changes: 14 additions & 6 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4255,7 +4255,9 @@ def map_elements(
The function is applied to each element of column `'a'`:

>>> df.with_columns( # doctest: +SKIP
... pl.col("a").map_elements(lambda x: x * 2).alias("a_times_2"),
... pl.col("a")
... .map_elements(lambda x: x * 2, return_dtype=pl.Int64)
... .alias("a_times_2"),
... )
shape: (4, 3)
┌─────┬─────┬───────────┐
Expand Down Expand Up @@ -4296,7 +4298,7 @@ def map_elements(
>>> (
... df.lazy()
... .group_by("b")
... .agg(pl.col("a").map_elements(lambda x: x.sum()))
... .agg(pl.col("a").map_elements(lambda x: x.sum(), return_dtype=pl.Int64))
... .collect()
... ) # doctest: +IGNORE_RESULT
shape: (3, 2)
Expand Down Expand Up @@ -4329,7 +4331,9 @@ def map_elements(
... }
... )
>>> df.with_columns(
... scaled=pl.col("val").map_elements(lambda s: s * len(s)).over("key"),
... scaled=pl.col("val")
... .map_elements(lambda s: s * len(s), return_dtype=pl.List(pl.Int64))
... .over("key"),
... ).sort("key")
shape: (6, 3)
┌─────┬─────┬────────┐
Expand Down Expand Up @@ -5315,12 +5319,16 @@ def xor(self, other: Any) -> Self:
... schema={"x": pl.UInt8, "y": pl.UInt8},
... )
>>> df.with_columns(
... pl.col("x").map_elements(binary_string).alias("bin_x"),
... pl.col("y").map_elements(binary_string).alias("bin_y"),
... pl.col("x")
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_x"),
... pl.col("y")
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_y"),
... pl.col("x").xor(pl.col("y")).alias("xor_xy"),
... pl.col("x")
... .xor(pl.col("y"))
... .map_elements(binary_string)
... .map_elements(binary_string, return_dtype=pl.String)
... .alias("bin_xor_xy"),
... )
shape: (4, 6)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/spreadsheet/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ def _read_spreadsheet_pyxlsb(
if schema_overrides:
for idx, s in enumerate(series_data):
if schema_overrides.get(s.name) in (Datetime, Date):
series_data[idx] = s.map_elements(convert_date)
series_data[idx] = s.map_elements(convert_date, return_dtype=Datetime)

df = pl.DataFrame(
{s.name: s for s in series_data},
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5292,7 +5292,7 @@ def map_elements(
Examples
--------
>>> s = pl.Series("a", [1, 2, 3])
>>> s.map_elements(lambda x: x + 10) # doctest: +SKIP
>>> s.map_elements(lambda x: x + 10, return_dtype=pl.Int64) # doctest: +SKIP
shape: (3,)
Series: 'a' [i64]
[
Expand Down
8 changes: 8 additions & 0 deletions py-polars/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ create_exception!(
CategoricalRemappingWarning,
PolarsBaseWarning
);
// Declares the Python-visible class `polars.exceptions.MapWithoutReturnDtypeWarning`
// with `PolarsBaseWarning` as its base class.
create_exception!(
polars.exceptions,
MapWithoutReturnDtypeWarning,
PolarsBaseWarning
);

#[macro_export]
macro_rules! raise_err(
Expand All @@ -109,6 +114,9 @@ impl IntoPy<PyObject> for Wrap<PolarsWarning> {
PolarsWarning::CategoricalRemappingWarning => {
CategoricalRemappingWarning::type_object(py).to_object(py)
},
PolarsWarning::MapWithoutReturnDtypeWarning => {
MapWithoutReturnDtypeWarning::type_object(py).to_object(py)
},
PolarsWarning::UserWarning => PyUserWarning::type_object(py).to_object(py),
}
}
Expand Down
10 changes: 8 additions & 2 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ use crate::conversion::Wrap;
use crate::dataframe::PyDataFrame;
use crate::error::{
CategoricalRemappingWarning, ColumnNotFoundError, ComputeError, DuplicateError,
InvalidOperationError, NoDataError, OutOfBoundsError, PolarsBaseError, PolarsBaseWarning,
PyPolarsErr, SchemaError, SchemaFieldNotFoundError, StructFieldNotFoundError,
InvalidOperationError, MapWithoutReturnDtypeWarning, NoDataError, OutOfBoundsError,
PolarsBaseError, PolarsBaseWarning, PyPolarsErr, SchemaError, SchemaFieldNotFoundError,
StructFieldNotFoundError,
};
use crate::expr::PyExpr;
use crate::functions::PyStringCacheHolder;
Expand Down Expand Up @@ -296,6 +297,11 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> {
py.get_type::<CategoricalRemappingWarning>(),
)
.unwrap();
m.add(
"MapWithoutReturnDtypeWarning",
py.get_type::<MapWithoutReturnDtypeWarning>(),
)
.unwrap();

// Build info
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Expand Down
7 changes: 7 additions & 0 deletions py-polars/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,13 @@ impl PySeries {
) -> PyResult<PySeries> {
let series = &self.series;

if output_type.is_none() {
polars_warn!(
MapWithoutReturnDtypeWarning,
"Calling `map_elements` without specifying `return_dtype` can lead to unpredictable results. \
Specify `return_dtype` to silence this warning.")
}

if skip_nulls && (series.null_count() == series.len()) {
if let Some(output_type) = output_type {
return Ok(Series::full_null(series.name(), series.len(), &output_type.0).into());
Expand Down
18 changes: 15 additions & 3 deletions py-polars/tests/unit/datatypes/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,14 +961,26 @@ def test_temporal_dtypes_map_elements(
[
# don't actually do this; native expressions are MUCH faster ;)
pl.col("timestamp")
.map_elements(lambda x: const_dtm, skip_nulls=skip_nulls)
.map_elements(
lambda x: const_dtm,
skip_nulls=skip_nulls,
return_dtype=pl.Datetime,
)
.alias("const_dtm"),
# note: the expressions below now trigger a PolarsInefficientMapWarning
pl.col("timestamp")
.map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls)
.map_elements(
lambda x: x and x.date(),
skip_nulls=skip_nulls,
return_dtype=pl.Date,
)
.alias("date"),
pl.col("timestamp")
.map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls)
.map_elements(
lambda x: x and x.time(),
skip_nulls=skip_nulls,
return_dtype=pl.Time,
)
.alias("time"),
]
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,10 @@ def test_parse_invalid_function(func: str) -> None:
("col", "func", "expr_repr"),
TEST_CASES,
)
@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
@pytest.mark.filterwarnings(
"ignore:invalid value encountered:RuntimeWarning",
"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning",
)
def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
with pytest.warns(
PolarsInefficientMapWarning,
Expand Down Expand Up @@ -294,7 +297,10 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
)


@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")
@pytest.mark.filterwarnings(
"ignore:invalid value encountered:RuntimeWarning",
"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning",
)
def test_parse_apply_raw_functions() -> None:
lf = pl.LazyFrame({"a": [1.1, 2.0, 3.4]})

Expand Down Expand Up @@ -373,7 +379,9 @@ def x10(self, x: pl.Expr) -> pl.Expr:
):
pl_series = pl.Series("srs", [0, 1, 2, 3, 4])
assert_series_equal(
pl_series.map_elements(lambda x: numpy.cos(3) + x - abs(-1)),
pl_series.map_elements(
lambda x: numpy.cos(3) + x - abs(-1), return_dtype=pl.Float64
),
numpy.cos(3) + pl_series - 1,
)

Expand Down Expand Up @@ -405,6 +413,9 @@ def x10(self, x: pl.Expr) -> pl.Expr:
),
],
)
@pytest.mark.filterwarnings(
"ignore:.*without specifying `return_dtype`:polars.exceptions.MapWithoutReturnDtypeWarning"
)
def test_parse_apply_series(
data: list[Any], func: Callable[[Any], Any], expr_repr: str
) -> None:
Expand Down Expand Up @@ -443,7 +454,7 @@ def test_expr_exact_warning_message() -> None:
# and to keep the assertion on `len(warnings)`.
with pytest.warns(PolarsInefficientMapWarning, match=rf"^{msg}$") as warnings:
df = pl.DataFrame({"a": [1, 2, 3]})
df.select(pl.col("a").map_elements(lambda x: x + 1))
df.select(pl.col("a").map_elements(lambda x: x + 1, return_dtype=pl.Int64))

assert len(warnings) == 1

Expand Down
50 changes: 35 additions & 15 deletions py-polars/tests/unit/operations/map/test_map_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@ def test_map_elements_infer_list() -> None:
def test_map_elements_arithmetic_consistency() -> None:
df = pl.DataFrame({"A": ["a", "a"], "B": [2, 3]})
with pytest.warns(PolarsInefficientMapWarning, match="with this one instead"):
assert df.group_by("A").agg(pl.col("B").map_elements(lambda x: x + 1.0))[
"B"
].to_list() == [[3.0, 4.0]]
assert df.group_by("A").agg(
pl.col("B").map_elements(
lambda x: x + 1.0, return_dtype=pl.List(pl.Float64)
)
)["B"].to_list() == [[3.0, 4.0]]


def test_map_elements_struct() -> None:
Expand Down Expand Up @@ -85,9 +87,12 @@ def test_map_elements_list_any_value_fallback() -> None:
match=r'(?s)with this one instead:.*pl.col\("text"\).str.json_decode()',
):
df = pl.DataFrame({"text": ['[{"x": 1, "y": 2}, {"x": 3, "y": 4}]']})
assert df.select(pl.col("text").map_elements(json.loads)).to_dict(
as_series=False
) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]}
assert df.select(
pl.col("text").map_elements(
json.loads,
return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),
)
).to_dict(as_series=False) == {"text": [[{"x": 1, "y": 2}, {"x": 3, "y": 4}]]}

# starts with empty list '[]'
df = pl.DataFrame(
Expand All @@ -99,9 +104,14 @@ def test_map_elements_list_any_value_fallback() -> None:
]
}
)
assert df.select(pl.col("text").map_elements(json.loads)).to_dict(
as_series=False
) == {"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]}
assert df.select(
pl.col("text").map_elements(
json.loads,
return_dtype=pl.List(pl.Struct({"x": pl.Int64, "y": pl.Int64})),
)
).to_dict(as_series=False) == {
"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]
}


def test_map_elements_all_types() -> None:
Expand Down Expand Up @@ -183,7 +193,9 @@ def test_map_elements_object_dtypes() -> None:
)
.alias("is_numeric1"),
pl.col("a")
.map_elements(lambda x: isinstance(x, (int, float)))
.map_elements(
lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean
)
.alias("is_numeric_infer"),
]
).to_dict(as_series=False) == {
Expand Down Expand Up @@ -212,12 +224,20 @@ def test_map_elements_dict() -> None:
match=r'(?s)with this one instead:.*pl.col\("abc"\).str.json_decode()',
):
df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']})
assert df.select(pl.col("abc").map_elements(json.loads)).to_dict(
as_series=False
) == {"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]}
assert df.select(
pl.col("abc").map_elements(
json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})
)
).to_dict(as_series=False) == {
"abc": [{"A": "Value1", "B": None}, {"A": None, "B": "Value2"}]
}
assert pl.DataFrame(
{"abc": ['{"A":"Value1", "B":"Value2"}', '{"B":"Value3"}']}
).select(pl.col("abc").map_elements(json.loads)).to_dict(as_series=False) == {
).select(
pl.col("abc").map_elements(
json.loads, return_dtype=pl.Struct({"A": pl.String, "B": pl.String})
)
).to_dict(as_series=False) == {
"abc": [{"A": "Value1", "B": "Value2"}, {"A": None, "B": "Value3"}]
}

Expand Down Expand Up @@ -305,7 +325,7 @@ def test_apply_deprecated() -> None:
with pytest.deprecated_call():
pl.col("a").apply(np.abs)
with pytest.deprecated_call():
pl.Series([1, 2, 3]).apply(np.abs)
pl.Series([1, 2, 3]).apply(np.abs, return_dtype=pl.Float64)


def test_cabbage_strategy_14396() -> None:
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,12 +1023,12 @@ def test_fill_nan() -> None:
def test_map_elements() -> None:
with pytest.warns(PolarsInefficientMapWarning):
a = pl.Series("a", [1, 2, None])
b = a.map_elements(lambda x: x**2)
b = a.map_elements(lambda x: x**2, return_dtype=pl.Int64)
assert list(b) == [1, 4, None]

with pytest.warns(PolarsInefficientMapWarning):
a = pl.Series("a", ["foo", "bar", None])
b = a.map_elements(lambda x: x + "py")
b = a.map_elements(lambda x: x + "py", return_dtype=pl.String)
assert list(b) == ["foopy", "barpy", None]

b = a.map_elements(lambda x: len(x), return_dtype=pl.Int32)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_apply() -> None:
ldf = pl.LazyFrame({"a": [1, 2, 3] * 20, "b": [1.0, 2.0, 3.0] * 20})
new = ldf.with_columns(
pl.col("a")
.map_elements(lambda s: s * 2, strategy=strategy) # type: ignore[arg-type]
.map_elements(lambda s: s * 2, strategy=strategy, return_dtype=pl.Int64) # type: ignore[arg-type]
.alias("foo")
)
expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo"))
Expand Down
Loading