Skip to content

Commit

Permalink
[mypyc] Optimize dunder methods (python#17934)
Browse files Browse the repository at this point in the history
This change gives mypyc the ability to optionally optimize dunder
methods that guarantee strict adherence to their signature typing. The
optimization allows bypassing the vtable for dunder methods in the cases
where it is applicable.

Currently, mypy has adopted the convention of accepting dunder methods that
return the `NotImplemented` value even when their signatures do not reflect
this possibility. With this change, and by enabling a special flag,
mypyc will expect strict typing to be honored and will unlock more
optimizations, such as native calls without vtable lookup for some
dunder method calls.
For example, it can avoid calls to the RichCompare Python API, so that the
code can be fully optimized by the C compiler when comparisons through
dunder methods are required.

Example:

```python
@final
class A:
    def __init__(self, x: i32) -> None:
        self.x: Final = x

    def __lt__(self, other: "A") -> bool:
        return self.x < other.x

A(1) < A(2)
```

would produce:

```c
char CPyDef_A_____lt__(PyObject *cpy_r_self, PyObject *cpy_r_other) {
    int32_t cpy_r_r0;
    int32_t cpy_r_r1;
    char cpy_r_r2;
    cpy_r_r0 = ((AObject *)cpy_r_self)->_x;
    cpy_r_r1 = ((AObject *)cpy_r_other)->_x;
    cpy_r_r2 = cpy_r_r0 < cpy_r_r1;
    return cpy_r_r2;
}

...
cpy_r_r29 = CPyDef_A_____lt__(cpy_r_r27, cpy_r_r28);
...
```

Instead of:

```c
PyObject *CPyDef_A_____lt__(PyObject *cpy_r_self, PyObject *cpy_r_other) {
    int32_t cpy_r_r0;
    int32_t cpy_r_r1;
    char cpy_r_r2;
    PyObject *cpy_r_r3;
    cpy_r_r0 = ((AObject *)cpy_r_self)->_x;
    cpy_r_r1 = ((AObject *)cpy_r_other)->_x;
    cpy_r_r2 = cpy_r_r0 < cpy_r_r1;
    cpy_r_r3 = cpy_r_r2 ? Py_True : Py_False;
    CPy_INCREF(cpy_r_r3);
    return cpy_r_r3;
}

...
cpy_r_r29 = PyObject_RichCompare(cpy_r_r27, cpy_r_r28, 0);
...
```

The default behavior is unchanged.
Tests run with strict typing both enabled and disabled.
  • Loading branch information
jairov4 authored Oct 18, 2024
1 parent bd2aafc commit c9d4c61
Show file tree
Hide file tree
Showing 11 changed files with 210 additions and 72 deletions.
5 changes: 3 additions & 2 deletions mypyc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mypyc.build import mypycify
setup(name='mypyc_output',
ext_modules=mypycify({}, opt_level="{}", debug_level="{}"),
ext_modules=mypycify({}, opt_level="{}", debug_level="{}", strict_dunder_typing={}),
)
"""

Expand All @@ -38,10 +38,11 @@ def main() -> None:

opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1")
strict_dunder_typing = bool(int(os.getenv("MYPYC_STRICT_DUNDER_TYPING", "0")))

setup_file = os.path.join(build_dir, "setup.py")
with open(setup_file, "w") as f:
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level))
f.write(setup_format.format(sys.argv[1:], opt_level, debug_level, strict_dunder_typing))

# We don't use run_setup (like we do in the test suite) because it throws
# away the error code from distutils, and we don't care about the slight
Expand Down
5 changes: 5 additions & 0 deletions mypyc/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def mypycify(
skip_cgen_input: Any | None = None,
target_dir: str | None = None,
include_runtime_files: bool | None = None,
strict_dunder_typing: bool = False,
) -> list[Extension]:
"""Main entry point to building using mypyc.
Expand Down Expand Up @@ -509,6 +510,9 @@ def mypycify(
should be directly #include'd instead of linked
separately in order to reduce compiler invocations.
Defaults to False in multi_file mode, True otherwise.
strict_dunder_typing: If True, force dunder methods to have the return type
of the method strictly, which can lead to more
optimization opportunities. Defaults to False.
"""

# Figure out our configuration
Expand All @@ -519,6 +523,7 @@ def mypycify(
separate=separate is not False,
target_dir=target_dir,
include_runtime_files=include_runtime_files,
strict_dunder_typing=strict_dunder_typing,
)

# Generate all the actual important C code
Expand Down
55 changes: 34 additions & 21 deletions mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
pytype_from_template_op,
type_object_op,
)
from mypyc.subtype import is_subtype


def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
Expand Down Expand Up @@ -801,30 +802,42 @@ def create_ne_from_eq(builder: IRBuilder, cdef: ClassDef) -> None:

def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None:
"""Generate a "__ne__" method from a "__eq__" method."""
with builder.enter_method(cls, "__ne__", object_rprimitive):
rhs_arg = builder.add_argument("rhs", object_rprimitive)

# If __eq__ returns NotImplemented, then __ne__ should also
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
func_ir = cls.get_method("__eq__")
assert func_ir
eq_sig = func_ir.decl.sig
strict_typing = builder.options.strict_dunders_typing
with builder.enter_method(cls, "__ne__", eq_sig.ret_type):
rhs_type = eq_sig.args[0].type if strict_typing else object_rprimitive
rhs_arg = builder.add_argument("rhs", rhs_type)
eqval = builder.add(MethodCall(builder.self(), "__eq__", [rhs_arg], line))
not_implemented = builder.add(
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
)
builder.add(
Branch(
builder.translate_is_op(eqval, not_implemented, "is", line),
not_implemented_block,
regular_block,
Branch.BOOL,
)
)

builder.activate_block(regular_block)
retval = builder.coerce(builder.unary_op(eqval, "not", line), object_rprimitive, line)
builder.add(Return(retval))
can_return_not_implemented = is_subtype(not_implemented_op.type, eq_sig.ret_type)
return_bool = is_subtype(eq_sig.ret_type, bool_rprimitive)

builder.activate_block(not_implemented_block)
builder.add(Return(not_implemented))
if not strict_typing or can_return_not_implemented:
# If __eq__ returns NotImplemented, then __ne__ should also
not_implemented_block, regular_block = BasicBlock(), BasicBlock()
not_implemented = builder.add(
LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
)
builder.add(
Branch(
builder.translate_is_op(eqval, not_implemented, "is", line),
not_implemented_block,
regular_block,
Branch.BOOL,
)
)
builder.activate_block(regular_block)
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
builder.add(Return(retval))
builder.activate_block(not_implemented_block)
builder.add(Return(not_implemented))
else:
rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
builder.add(Return(retval))


def load_non_ext_class(
Expand Down
14 changes: 8 additions & 6 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@


def transform_func_def(builder: IRBuilder, fdef: FuncDef) -> None:
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, builder.mapper.fdef_to_sig(fdef))
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, fdef.name, sig)

# If the function that was visited was a nested function, then either look it up in our
# current environment or define it if it was not already defined.
Expand All @@ -113,9 +114,8 @@ def transform_overloaded_func_def(builder: IRBuilder, o: OverloadedFuncDef) -> N


def transform_decorator(builder: IRBuilder, dec: Decorator) -> None:
func_ir, func_reg = gen_func_item(
builder, dec.func, dec.func.name, builder.mapper.fdef_to_sig(dec.func)
)
sig = builder.mapper.fdef_to_sig(dec.func, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, dec.func, dec.func.name, sig)
decorated_func: Value | None = None
if func_reg:
decorated_func = load_decorated_func(builder, dec.func, func_reg)
Expand Down Expand Up @@ -416,7 +416,8 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None
# Perform the function of visit_method for methods inside extension classes.
name = fdef.name
class_ir = builder.mapper.type_to_ir[cdef.info]
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
builder.functions.append(func_ir)

if is_decorated(builder, fdef):
Expand Down Expand Up @@ -481,7 +482,8 @@ def handle_non_ext_method(
) -> None:
# Perform the function of visit_method for methods inside non-extension classes.
name = fdef.name
func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef)
sig = builder.mapper.fdef_to_sig(fdef, builder.options.strict_dunders_typing)
func_ir, func_reg = gen_func_item(builder, fdef, name, sig, cdef)
assert func_reg is not None
builder.functions.append(func_ir)

Expand Down
53 changes: 42 additions & 11 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from mypy.argmap import map_actuals_to_formals
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
from mypy.operators import op_methods
from mypy.operators import op_methods, unary_op_methods
from mypy.types import AnyType, TypeOfAny
from mypyc.common import (
BITMAP_BITS,
Expand Down Expand Up @@ -167,6 +167,7 @@
buf_init_item,
fast_isinstance_op,
none_object_op,
not_implemented_op,
var_object_size,
)
from mypyc.primitives.registry import (
Expand Down Expand Up @@ -1398,11 +1399,48 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
if base_op in float_op_to_id:
return self.float_op(lreg, rreg, base_op, line)

dunder_op = self.dunder_op(lreg, rreg, op, line)
if dunder_op:
return dunder_op

primitive_ops_candidates = binary_ops.get(op, [])
target = self.matching_primitive_op(primitive_ops_candidates, [lreg, rreg], line)
assert target, "Unsupported binary operation: %s" % op
return target

def dunder_op(self, lreg: Value, rreg: Value | None, op: str, line: int) -> Value | None:
    """Try to lower an operator expression to a direct dunder method call.

    For instance, `a + b` may be emitted as `a.__add__(b)`. When the method is
    compiled natively this can be faster than dispatching through the generic
    Python C API (such as `PyNumber_Add(a, b)`).

    Args:
        lreg: The left (or only) operand; the receiver of the dunder call.
        rreg: The right operand, or None for a unary operator.
        op: The operator as written in the source (for example "+" or "-").
        line: Source line number attached to the generated call.

    Returns:
        The value produced by the generated method call, or None when this
        shortcut does not apply and the caller has to handle the operation
        some other way.
    """
    receiver_type = lreg.type
    if not isinstance(receiver_type, RInstance):
        return None

    # Binary operators are resolved through op_methods, unary ones through
    # unary_op_methods.
    method_table = op_methods if rreg else unary_op_methods
    method_name = method_table.get(op)
    if method_name is None:
        return None

    class_ir = receiver_type.class_ir
    if not class_ir.has_method(method_name):
        return None

    sig = class_ir.method_decl(method_name).sig
    if rreg:
        # The signature must be (self, other), and the operand must fit the
        # declared type of "other".
        if len(sig.args) != 2:
            return None
        if not is_subtype(rreg.type, sig.args[1].type):
            return None
        # If the method is able to return NotImplemented, we should not
        # optimize it; leave it to be handled through the Python API.
        if is_subtype(not_implemented_op.type, sig.ret_type):
            return None
        call_args = [rreg]
    else:
        # A unary dunder takes only self.
        if len(sig.args) != 1:
            return None
        call_args = []

    return self.gen_method_call(lreg, method_name, call_args, sig.ret_type, line)

def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
"""Check if a tagged integer is a short integer.
Expand Down Expand Up @@ -1558,16 +1596,9 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
if isinstance(value, Float):
return Float(-value.value, value.line)
if isinstance(typ, RInstance):
if expr_op == "-":
method = "__neg__"
elif expr_op == "+":
method = "__pos__"
elif expr_op == "~":
method = "__invert__"
else:
method = ""
if method and typ.class_ir.has_method(method):
return self.gen_method_call(value, method, [], None, line)
result = self.dunder_op(value, None, expr_op, line)
if result is not None:
return result
call_c_ops_candidates = unary_ops.get(expr_op, [])
target = self.matching_call_c(call_c_ops_candidates, [value], line)
assert target, "Unsupported unary operation: %s" % expr_op
Expand Down
15 changes: 9 additions & 6 deletions mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType:
else:
return self.type_to_rtype(typ)

def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignature:
if isinstance(fdef.type, CallableType):
arg_types = [
self.get_arg_rtype(typ, kind)
Expand Down Expand Up @@ -199,11 +199,14 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature:
)
]

# We force certain dunder methods to return objects to support letting them
# return NotImplemented. It also avoids some pointless boxing and unboxing,
# since tp_richcompare needs an object anyways.
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
ret = object_rprimitive
if not strict_dunders_typing:
# We force certain dunder methods to return objects to support letting them
# return NotImplemented. It also avoids some pointless boxing and unboxing,
# since tp_richcompare needs an object anyways.
# However, it also prevents some optimizations.
if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
ret = object_rprimitive

return FuncSignature(args, ret)

def is_native_module(self, module: str) -> bool:
Expand Down
Loading

0 comments on commit c9d4c61

Please sign in to comment.