diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 6c98f053c452..6659b01a7bd7 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -132,29 +132,34 @@ class OpNames: ) ) -# python functions that we can map to native expressions +# python attrs/funcs that map to native expressions +_PYTHON_ATTRS_MAP = { + "date": "dt.date()", + "day": "dt.day()", + "hour": "dt.hour()", + "microsecond": "dt.microsecond()", + "minute": "dt.minute()", + "month": "dt.month()", + "second": "dt.second()", + "year": "dt.year()", +} _PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"} _PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"} _PYTHON_METHODS_MAP = { + # string "lower": "str.to_lowercase", "title": "str.to_titlecase", "upper": "str.to_uppercase", + # temporal + "isoweekday": "dt.weekday", + "time": "dt.time", } -_FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [ - # lambda x: module.func(CONSTANT) - { - "argument_1_opname": [{"LOAD_CONST"}], - "argument_2_opname": [], - "module_opname": [OpNames.LOAD_ATTR], - "attribute_opname": [], - "module_name": [_NUMPY_MODULE_ALIASES], - "attribute_name": [], - "function_name": [_NUMPY_FUNCTIONS], - }, - # lambda x: module.func(x) +_MODULE_FUNCTIONS: list[dict[str, list[AbstractSet[str]]]] = [ + # lambda x: numpy.func(x) + # lambda x: numpy.func(CONSTANT) { - "argument_1_opname": [{"LOAD_FAST"}], + "argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}], "argument_2_opname": [], "module_opname": [OpNames.LOAD_ATTR], "attribute_opname": [], @@ -162,6 +167,7 @@ class OpNames: "attribute_name": [], "function_name": [_NUMPY_FUNCTIONS], }, + # lambda x: json.loads(x) { "argument_1_opname": [{"LOAD_FAST"}], "argument_2_opname": [], @@ -171,7 +177,7 @@ class OpNames: "attribute_name": [], "function_name": [{"loads"}], }, - # lambda x: module.func(x, CONSTANT) + # lambda x: datetime.strptime(x, CONSTANT) { "argument_1_opname": [{"LOAD_FAST"}], "argument_2_opname": [{"LOAD_CONST"}], @@ -194,11 +200,9 @@ class OpNames: ] # In addition to `lambda x: func(x)`, also support cases when a unary operation # has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`. -_FUNCTION_KINDS = [ - # Dict entry 1 has incompatible type "str": "object"; - # expected "str": "list[AbstractSet[str]]" +_MODULE_FUNCTIONS = [ {**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item] - for kind in _FUNCTION_KINDS + for kind in _MODULE_FUNCTIONS for unary in [[set(OpNames.UNARY)], []] ] @@ -656,7 +660,8 @@ def _matches( idx: int, *, opnames: list[AbstractSet[str]], - argvals: list[AbstractSet[Any] | dict[Any, Any]] | None, + argvals: list[AbstractSet[Any] | dict[Any, Any] | None] | None, + is_attr: bool = False, ) -> list[Instruction]: """ Check if a sequence of Instructions matches the specified ops/argvals. @@ -669,9 +674,16 @@ def _matches( The full opname sequence that defines a match. argvals Associated argvals that must also match (in same position as opnames). + is_attr + Indicate if the match is expected to represent attribute access. """ n_required_ops, argvals = len(opnames), argvals or [] - instructions = self._instructions[idx : idx + n_required_ops] + idx_offset = idx + n_required_ops + if is_attr and (trailing_inst := self._instructions[idx_offset:1]): + if trailing_inst[0].opname == "CALL": + return [] + + instructions = self._instructions[idx:idx_offset] if len(instructions) == n_required_ops and all( inst.opname in match_opnames and (match_argval is None or inst.argval in match_argval) @@ -702,12 +714,30 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: self._rewrite_functions, self._rewrite_methods, self._rewrite_builtins, + self._rewrite_attrs, ) ): updated_instructions.append(inst) idx += increment or 1 return updated_instructions + def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> int: + """Replace python attribute lookup with synthetic POLARS_EXPRESSION op.""" + if matching_instructions := self._matches( + idx, + opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}], + argvals=[None, _PYTHON_ATTRS_MAP], + is_attr=True, + ): + inst = matching_instructions[1] + expr_name = _PYTHON_ATTRS_MAP[inst.argval] + synthetic_call = inst._replace( + opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name + ) + updated_instructions.extend([matching_instructions[0], synthetic_call]) + + return len(matching_instructions) + def _rewrite_builtins( self, idx: int, updated_instructions: list[Instruction] ) -> int: @@ -738,7 +768,7 @@ def _rewrite_functions( self, idx: int, updated_instructions: list[Instruction] ) -> int: """Replace function calls with a synthetic POLARS_EXPRESSION op.""" - for function_kind in _FUNCTION_KINDS: + for function_kind in _MODULE_FUNCTIONS: opnames: list[AbstractSet[str]] = [ {"LOAD_GLOBAL", "LOAD_DEREF"}, *function_kind["module_opname"], diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 670299f889bf..4b86c83ed7cb 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -168,6 +168,19 @@ 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', ), # --------------------------------------------- + # temporal attributes/methods + # --------------------------------------------- + ( + "f", + "lambda x: x.isoweekday()", + 'pl.col("f").dt.weekday()', + ), + ( + "f", + "lambda x: x.hour + x.minute + x.second", + '(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()', + ), + # --------------------------------------------- # Bitwise shifts # --------------------------------------------- ( @@ -244,6 +257,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: "c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'], "d": ["2020-01-01", "2020-01-02", "2020-01-03"], "e": [1.5, 2.4, 3.1], + "f": [ + datetime(1999, 12, 31), + datetime(2024, 5, 6), + datetime(2077, 10, 20), + ], } ) result_frame = df.select( @@ -254,7 +272,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: x=pl.col(col), y=pl.col(col).map_elements(eval(func)), ) - assert_frame_equal(result_frame, expected_frame) + assert_frame_equal( + result_frame, + expected_frame, + check_dtype=(".dt." not in suggested_expression), + ) @pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning")