From 6713cf6fb5f72366addb1b5e16ff06aeab98a512 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 16 Feb 2024 08:24:41 +0000 Subject: [PATCH] tighten attribute detection --- py-polars/polars/utils/udfs.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 46c9850c77a3..5179c8785a86 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -659,6 +659,7 @@ def _matches( *, opnames: list[AbstractSet[str]], argvals: list[AbstractSet[Any] | dict[Any, Any] | None] | None, + is_attr: bool = False, ) -> list[Instruction]: """ Check if a sequence of Instructions matches the specified ops/argvals. @@ -671,9 +672,17 @@ def _matches( The full opname sequence that defines a match. argvals Associated argvals that must also match (in same position as opnames). + is_attr + Indicate if the match is expected to represent attribute access. """ n_required_ops, argvals = len(opnames), argvals or [] - instructions = self._instructions[idx : idx + n_required_ops] + idx_offset = idx + n_required_ops + if is_attr and (trailing_inst := self._instructions[idx_offset:1]): + # if there's a trailing CALL, it's an implicit method invocation + if trailing_inst.opname == "CALL": + return [] + + instructions = self._instructions[idx:idx_offset] if len(instructions) == n_required_ops and all( inst.opname in match_opnames and (match_argval is None or inst.argval in match_argval) @@ -717,6 +726,7 @@ def _rewrite_attrs(self, idx: int, updated_instructions: list[Instruction]) -> i idx, opnames=[{"LOAD_FAST"}, {"LOAD_ATTR"}], argvals=[None, _PYTHON_ATTRS_MAP], + is_attr=True, ): inst = matching_instructions[1] expr_name = _PYTHON_ATTRS_MAP[inst.argval]