diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 6c98f053c452..e9b018f633a5 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -132,29 +132,35 @@ class OpNames: ) ) -# python functions that we can map to native expressions +# python attrs/funcs that map to native expressions +_PYTHON_ATTRS_MAP = { + "date": "dt.date()", + "day": "dt.day()", + "hour": "dt.hour()", + "microsecond": "dt.microsecond()", + "minute": "dt.minute()", + "month": "dt.month()", + "second": "dt.second()", + "year": "dt.year()", +} _PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "String"} _PYTHON_BUILTINS = frozenset(_PYTHON_CASTS_MAP) | {"abs"} _PYTHON_METHODS_MAP = { + # string "lower": "str.to_lowercase", "title": "str.to_titlecase", "upper": "str.to_uppercase", + # temporal + "isoweekday": "dt.weekday", + "date": "dt.date", + "time": "dt.time", } -_FUNCTION_KINDS: list[dict[str, list[AbstractSet[str]]]] = [ - # lambda x: module.func(CONSTANT) - { - "argument_1_opname": [{"LOAD_CONST"}], - "argument_2_opname": [], - "module_opname": [OpNames.LOAD_ATTR], - "attribute_opname": [], - "module_name": [_NUMPY_MODULE_ALIASES], - "attribute_name": [], - "function_name": [_NUMPY_FUNCTIONS], - }, - # lambda x: module.func(x) +_MODULE_FUNCTIONS: list[dict[str, list[AbstractSet[str]]]] = [ + # lambda x: numpy.func(x) + # lambda x: numpy.func(CONSTANT) { - "argument_1_opname": [{"LOAD_FAST"}], + "argument_1_opname": [{"LOAD_FAST", "LOAD_CONST"}], "argument_2_opname": [], "module_opname": [OpNames.LOAD_ATTR], "attribute_opname": [], @@ -162,6 +168,7 @@ class OpNames: "attribute_name": [], "function_name": [_NUMPY_FUNCTIONS], }, + # lambda x: json.loads(x) { "argument_1_opname": [{"LOAD_FAST"}], "argument_2_opname": [], @@ -171,7 +178,7 @@ class OpNames: "attribute_name": [], "function_name": [{"loads"}], }, - # lambda x: module.func(x, CONSTANT) + # lambda x: datetime.strptime(x, CONSTANT) { "argument_1_opname": [{"LOAD_FAST"}], "argument_2_opname": [{"LOAD_CONST"}], @@ -194,13 +201,12 @@ class OpNames: ] # In addition to `lambda x: func(x)`, also support cases when a unary operation # has been applied to `x`, like `lambda x: func(-x)` or `lambda x: func(~x)`. -_FUNCTION_KINDS = [ - # Dict entry 1 has incompatible type "str": "object"; - # expected "str": "list[AbstractSet[str]]" +_MODULE_FUNCTIONS = [ {**kind, "argument_1_unary_opname": unary} # type: ignore[dict-item] - for kind in _FUNCTION_KINDS + for kind in _MODULE_FUNCTIONS for unary in [[set(OpNames.UNARY)], []] ] +_RE_IMPLICIT_BOOL = re.compile(r'pl\.col\("([^"]*)"\) & pl\.col\("\1"\)\.(.+)') def _get_all_caller_variables() -> dict[str, Any]: @@ -252,6 +258,12 @@ def __init__(self, function: Callable[[Any], Any], map_target: MapTarget): instructions=original_instructions, ) + def _omit_implicit_bool(self, expr: str) -> str: + """Drop extraneous/implied bool (eg: `pl.col("d") & pl.col("d").dt.date()`).""" + while _RE_IMPLICIT_BOOL.search(expr): + expr = _RE_IMPLICIT_BOOL.sub(repl=r'pl.col("\1").\2', string=expr) + return expr + @staticmethod def _get_param_name(function: Callable[[Any], Any]) -> str | None: """Return single function parameter name.""" @@ -415,11 +427,13 @@ def to_expression(self, col: str) -> str | None: # constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn if "pl.col(" not in polars_expr: return None - elif self._map_target == "series": - target_name = self._get_target_name(col, polars_expr) - return polars_expr.replace(f'pl.col("{col}")', target_name) else: - return polars_expr + polars_expr = self._omit_implicit_bool(polars_expr) + if self._map_target == "series": + target_name = self._get_target_name(col, polars_expr) + return polars_expr.replace(f'pl.col("{col}")', target_name) + else: + return polars_expr def warn( self, @@ -656,7 +670,8 @@ def _matches( idx: int, *, opnames: list[AbstractSet[str]], - argvals: list[AbstractSet[Any] | dict[Any, Any]] | None, + argvals: list[AbstractSet[Any] | dict[Any, Any] | None] | None, + is_attr: bool = False, ) -> list[Instruction]: """ Check if a sequence of Instructions matches the specified ops/argvals. @@ -669,9 +684,19 @@ def _matches( The full opname sequence that defines a match. argvals Associated argvals that must also match (in same position as opnames). + is_attr + Indicate if the match is expected to represent attribute access. """ n_required_ops, argvals = len(opnames), argvals or [] - instructions = self._instructions[idx : idx + n_required_ops] + idx_offset = idx + n_required_ops + if ( + is_attr + and (trailing_inst := self._instructions[idx_offset : idx_offset + 1]) + and trailing_inst[0].opname in OpNames.CALL # not pure attr if called + ): + return [] + + instructions = self._instructions[idx:idx_offset] if len(instructions) == n_required_ops and all( inst.opname in match_opnames and (match_argval is None or inst.argval in match_argval) @@ -702,12 +727,30 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: self._rewrite_functions, self._rewrite_methods, self._rewrite_builtins, + self._rewrite_attrs, ) ): updated_instructions.append(inst) idx += increment or 1 return updated_instructions + def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> int: + """Replace python attribute lookup with synthetic POLARS_EXPRESSION op.""" + if matching_instructions := self._matches( + idx, + opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}], + argvals=[None, _PYTHON_ATTRS_MAP], + is_attr=True, + ): + inst = matching_instructions[1] + expr_name = _PYTHON_ATTRS_MAP[inst.argval] + synthetic_call = inst._replace( + opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name + ) + updated_instructions.extend([matching_instructions[0], synthetic_call]) + + return len(matching_instructions) + def _rewrite_builtins( self, idx: int, updated_instructions: list[Instruction] ) -> int: @@ -738,7 +781,7 @@ def _rewrite_functions( self, idx: int, updated_instructions: list[Instruction] ) -> int: """Replace function calls with a synthetic POLARS_EXPRESSION op.""" - for function_kind in _FUNCTION_KINDS: + for function_kind in _MODULE_FUNCTIONS: opnames: list[AbstractSet[str]] = [ {"LOAD_GLOBAL", "LOAD_DEREF"}, *function_kind["module_opname"], diff --git a/py-polars/tests/unit/datatypes/test_temporal.py b/py-polars/tests/unit/datatypes/test_temporal.py index a557cd039786..726629f0114f 100644 --- a/py-polars/tests/unit/datatypes/test_temporal.py +++ b/py-polars/tests/unit/datatypes/test_temporal.py @@ -12,7 +12,11 @@ import polars as pl from polars.datatypes import DATETIME_DTYPES, DTYPE_TEMPORAL_UNITS, TEMPORAL_DTYPES -from polars.exceptions import ComputeError, TimeZoneAwareConstructorWarning +from polars.exceptions import ( + ComputeError, + PolarsInefficientMapWarning, + TimeZoneAwareConstructorWarning, +) from polars.testing import ( assert_frame_equal, assert_series_equal, @@ -947,45 +951,50 @@ def test_temporal_dtypes_map_elements( ) const_dtm = datetime(2010, 9, 12) - assert_frame_equal( - df.with_columns( - [ - # don't actually do any of this; native expressions are MUCH faster ;) - pl.col("timestamp") - .map_elements(lambda x: const_dtm, skip_nulls=skip_nulls) - .alias("const_dtm"), - pl.col("timestamp") - .map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls) - .alias("date"), - pl.col("timestamp") - .map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls) - .alias("time"), - ] - ), - pl.DataFrame( - [ - ( - datetime(2010, 9, 12, 10, 19, 54), - datetime(2010, 9, 12, 0, 0), - date(2010, 9, 12), - time(10, 19, 54), - ), - (None, expected_value, None, None), - ( - datetime(2009, 2, 13, 23, 31, 30), - datetime(2010, 9, 12, 0, 0), - date(2009, 2, 13), - time(23, 31, 30), - ), - ], - schema={ - "timestamp": pl.Datetime("ms"), - "const_dtm": pl.Datetime("us"), - "date": pl.Date, - "time": pl.Time, - }, - ), - ) + with pytest.warns( + PolarsInefficientMapWarning, + match=r"(?s)Replace this expression.*lambda x:", + ): + assert_frame_equal( + df.with_columns( + [ + # don't actually do this; native expressions are MUCH faster ;) + pl.col("timestamp") + .map_elements(lambda x: const_dtm, skip_nulls=skip_nulls) + .alias("const_dtm"), + # note: the below now trigger a PolarsInefficientMapWarning + pl.col("timestamp") + .map_elements(lambda x: x and x.date(), skip_nulls=skip_nulls) + .alias("date"), + pl.col("timestamp") + .map_elements(lambda x: x and x.time(), skip_nulls=skip_nulls) + .alias("time"), + ] + ), + pl.DataFrame( + [ + ( + datetime(2010, 9, 12, 10, 19, 54), + datetime(2010, 9, 12, 0, 0), + date(2010, 9, 12), + time(10, 19, 54), + ), + (None, expected_value, None, None), + ( + datetime(2009, 2, 13, 23, 31, 30), + datetime(2010, 9, 12, 0, 0), + date(2009, 2, 13), + time(23, 31, 30), + ), + ], + schema={ + "timestamp": pl.Datetime("ms"), + "const_dtm": pl.Datetime("us"), + "date": pl.Date, + "time": pl.Time, + }, + ), + ) def test_timelike_init() -> None: diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index 670299f889bf..a5e907306875 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -168,6 +168,19 @@ 'pl.col("d").str.to_datetime(format="%Y-%m-%d")', ), # --------------------------------------------- + # temporal attributes/methods + # --------------------------------------------- + ( + "f", + "lambda x: x.isoweekday()", + 'pl.col("f").dt.weekday()', + ), + ( + "f", + "lambda x: x.hour + x.minute + x.second", + '(pl.col("f").dt.hour() + pl.col("f").dt.minute()) + pl.col("f").dt.second()', + ), + # --------------------------------------------- # Bitwise shifts # --------------------------------------------- ( @@ -244,6 +257,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: "c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'], "d": ["2020-01-01", "2020-01-02", "2020-01-03"], "e": [1.5, 2.4, 3.1], + "f": [ + datetime(1999, 12, 31), + datetime(2024, 5, 6), + datetime(2077, 10, 20), + ], } ) result_frame = df.select( @@ -254,7 +272,11 @@ def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: x=pl.col(col), y=pl.col(col).map_elements(eval(func)), ) - assert_frame_equal(result_frame, expected_frame) + assert_frame_equal( + result_frame, + expected_frame, + check_dtype=(".dt." not in suggested_expression), + ) @pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") @@ -411,6 +433,15 @@ def test_expr_exact_warning_message() -> None: assert len(warnings) == 1 +def test_omit_implicit_bool() -> None: + parser = BytecodeParser( + function=lambda x: x and x and x.date(), + map_target="expr", + ) + suggested_expression = parser.to_expression("d") + assert suggested_expression == 'pl.col("d").dt.date()' + + def test_partial_functions_13523() -> None: def plus(value, amount: int): # type: ignore[no-untyped-def] return value + amount