Skip to content

Commit

Permalink
[SOT][3.12] replace POP_JUMP_{BACKWARD,FORWARD}_IF_{TRUE,FALSE} to …
Browse files Browse the repository at this point in the history
…`POP_JUMP_IF_{TRUE,FALSE}` (#62155)
  • Loading branch information
gouzil authored Mar 4, 2024
1 parent adb8bc2 commit de1777b
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1697,8 +1697,9 @@ def FOR_ITER(self, instr):

self._inline_call_for_loop(iterator, instr)
self._lasti = self.indexof(instr.jump_to)
next_instr = self._instructions[self._lasti]
self._lasti += int(next_instr.opname == 'END_FOR')
if sys.version_info >= (3, 12):
assert self._instructions[self._lasti].opname == "END_FOR"
self._lasti += 1
except BreakGraphError as e:
log(3, f"[BreakGraph] FOR_ITER sim for loop failed for: {e}\n")
if backup_iter_idx:
Expand Down Expand Up @@ -2071,10 +2072,17 @@ def create_after_loop_fn():
return None
pycode_gen = PyCodeGen(self._frame)
origin_instrs = get_instructions(pycode_gen._origin_code)
resume_fn_end_idx = loop_body_end_idx

# skip resume END_FOR in python3.12
if sys.version_info >= (3, 12):
assert origin_instrs[loop_body_end_idx].opname == "END_FOR"
resume_fn_end_idx += 1

pycode_gen.set_function_inputs(
after_loop_fn_inputs, stack_size=len(self.stack) - 1
)
pycode_gen.extend_instrs(origin_instrs[loop_body_end_idx:])
pycode_gen.extend_instrs(origin_instrs[resume_fn_end_idx:])
# the resume_fn contains return code, so we don't need set output here
# global vars are updated correctly, and need local vars will return
after_loop_fn = pycode_gen.create_function()
Expand Down Expand Up @@ -2138,8 +2146,13 @@ def create_after_loop_fn():
self._graph.pycode_gen.gen_jump(
for_iter, direction=JumpDirection.BACKWARD
)

if sys.version_info >= (3, 12):
end_for = self._graph.pycode_gen.add_instr("END_FOR")

nop = self._graph.pycode_gen.add_instr("NOP")
for_iter.jump_to = nop

for_iter.jump_to = end_for if sys.version_info >= (3, 12) else nop
jump_if_break.jump_to = nop

# 9. prepare inputs and call after_loop_fn
Expand Down Expand Up @@ -2209,6 +2222,8 @@ def create_inline_call_fn():
for_iter_instr, direction=JumpDirection.BACKWARD
)

if sys.version_info >= (3, 12):
end_for = pycode_gen.add_instr("END_FOR")
nop_for_break = pycode_gen.add_instr("NOP")

# 2.4. relocate jumps
Expand All @@ -2223,6 +2238,8 @@ def create_inline_call_fn():
instr.jump_to = nop_for_break

jump.jump_to = for_iter_instr
if sys.version_info >= (3, 12):
for_iter_instr.jump_to = end_for

pycode_gen.set_function_outputs(output_var_names)
inline_call_fn = pycode_gen.create_function()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import contextlib
import inspect
import re
import sys
from typing import TYPE_CHECKING

from ...profiler import event_register
Expand Down Expand Up @@ -316,6 +317,9 @@ def FOR_ITER(self, instr: Instruction):
self.stack.pop()
assert isinstance(instr.jump_to, Instruction)
self._lasti = self.indexof(instr.jump_to)
if sys.version_info >= (3, 12):
assert self._instructions[self._lasti].opname == "END_FOR"
self._lasti += 1

else:
self._graph.remove_global_guarded_variable(iterator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def gen_pop_jump(
direction: JumpDirection = JumpDirection.FORWARD,
suffix: PopJumpCond = PopJumpCond.NONE,
) -> Instruction:
if sys.version_info >= (3, 11):
if sys.version_info >= (3, 11) and sys.version_info < (3, 12):
return self.add_instr(
f"POP_JUMP_{direction.value}_IF_{suffix.value}", jump_to=jump_to
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import sys
from typing import TYPE_CHECKING

from paddle.jit.sot.utils import log, log_do

from ...utils import InnerError
from .instruction_utils import instrs_info
from .stack_analyse import StackAnalyser

if TYPE_CHECKING:
from .instruction_utils import Instruction


def apply_instr_pass(instrs, code_options):
def apply_instr_pass(instrs: list[Instruction], code_options):
log(4, f"[Opcode Pass]: Original New Code {code_options['co_name']}:\n")
log_do(4, lambda: print(instrs_info(instrs)))
supported_passes = (
supported_passes = [
remove_load_store_pass,
remove_duplicate_resume,
check_precall_followed_by_call,
)
]

if sys.version_info >= (3, 12):
supported_passes.append(check_for_iter_jump_to)

for instr_pass in supported_passes:
instr_pass(instrs, code_options)
Expand All @@ -38,7 +49,7 @@ def apply_instr_pass(instrs, code_options):
log_do(4, lambda: print(instrs_info(instrs)))


def find_stored_once_local_vars(instrs, code_options):
def find_stored_once_local_vars(instrs: list[Instruction], code_options):
"""
find out the local var names which is only stored once
"""
Expand All @@ -61,13 +72,13 @@ def find_stored_once_local_vars(instrs, code_options):
return stored_once


def find_loaded_once_local_vars(instrs, code_options):
def find_loaded_once_local_vars(instrs: list[Instruction], code_options):
"""
find out the local var names which is only stored once
"""
loaded_vars = {}
for instr in instrs:
if instr.opname == "LOAD_FAST":
if instr.opname in ["LOAD_FAST", "LOAD_FAST_CHECK"]:
if instr.argval in loaded_vars:
loaded_vars[instr.argval] += 1
else:
Expand All @@ -77,14 +88,14 @@ def find_loaded_once_local_vars(instrs, code_options):
return loaded_once


def find_related_local_opcodes(instrs, code_options):
def find_related_local_opcodes(instrs: list[Instruction], code_options):
"""
find out the opcode pairs consist with LOAD_FAST and STORE_FAST
find out the opcode pairs consist with LOAD_FAST and STORE_FAST and LOAD_FAST_CHECK
"""
stack = []
opcode_pairs = []
for instr in instrs:
if instr.opname == "LOAD_FAST":
if instr.opname in ["LOAD_FAST", "LOAD_FAST_CHECK"]:
stack.append(instr)
elif instr.opname == "STORE_FAST":
if len(stack) > 0 and stack[-1] is not None:
Expand All @@ -105,7 +116,7 @@ def find_related_local_opcodes(instrs, code_options):
return opcode_pairs


def remove_load_store_pass(instrs, code_options):
def remove_load_store_pass(instrs: list[Instruction], code_options):
"""
This question is extremely complex, so we just simplify it as
'remove renames which is between var names who only stored once'
Expand Down Expand Up @@ -158,7 +169,8 @@ def code_exist(opname, argval, instrs):
if a_name != b_name:
for instr in instrs:
if (
instr.opname in ("LOAD_FAST", "STORE_FAST")
instr.opname
in ("LOAD_FAST_CHECK", "LOAD_FAST", "STORE_FAST")
and instr.argval == b_name
):
instr.argval = a_name
Expand Down Expand Up @@ -211,7 +223,13 @@ def code_exist(opname, argval, instrs):
code_range = instrs[last_store_idx : instrs.index(store_b)]
if (
not code_exist("STORE_FAST", b_name, code_range)
and not code_exist("LOAD_FAST_CHECK", b_name, code_range)
and not code_exist("LOAD_FAST", b_name, code_range)
and not code_exist(
"LOAD_FAST_CHECK",
a_name,
instrs[instrs.index(store_b) :],
)
and not code_exist(
"LOAD_FAST", a_name, instrs[instrs.index(store_b) :]
)
Expand All @@ -222,7 +240,8 @@ def code_exist(opname, argval, instrs):
instrs.remove(store_b)
for instr in instrs[last_store_idx:]:
if (
instr.opname in ("LOAD_FAST", "STORE_FAST")
instr.opname
in ("LOAD_FAST_CHECK", "LOAD_FAST", "STORE_FAST")
and instr.argval == a_name
):
instr.argval = b_name
Expand All @@ -245,6 +264,7 @@ def code_exist(opname, argval, instrs):
and opcode2 not in jump_target
and opcode1.opname == "STORE_FAST"
and opcode2.opname == "LOAD_FAST"
and opcode2.opname == "LOAD_FAST_CHECK"
and opcode1.argval == opcode2.argval
and opcode1.argval in loaded_once
):
Expand All @@ -255,15 +275,15 @@ def code_exist(opname, argval, instrs):
idx += 1


def remove_duplicate_resume(instrs, code_options):
def remove_duplicate_resume(instrs: list[Instruction], code_options):
resumes = list(filter(lambda instr: instr.opname == "RESUME", instrs))
if not resumes:
return
for resume in resumes[1:]:
instrs.remove(resume)


def check_precall_followed_by_call(instrs, code_options):
def check_precall_followed_by_call(instrs: list[Instruction], code_options):
"""
PRECALL should be followed by CALL, otherwise it will cause a segmentation fault
"""
Expand All @@ -272,3 +292,14 @@ def check_precall_followed_by_call(instrs, code_options):
raise InnerError(
f"PRECALL is not followed by CALL in {code_options['co_name']}"
)


def check_for_iter_jump_to(instrs: list[Instruction], code_options):
"""
Check if the `jump_to` of FOR_ITER is END_FOR, in Python3.12+
"""
for instr in instrs:
if instr.opname == "FOR_ITER":
assert instr.jump_to is not None
if instr.jump_to.opname != "END_FOR":
raise InnerError("FOR_ITER jump_to is not END_FOR")
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from typing import TYPE_CHECKING, Any

from ...utils import InnerError
from .opcode_info import ABS_JUMP, ALL_JUMP, REL_BWD_JUMP, REL_JUMP
from .opcode_info import (
ABS_JUMP,
ALL_JUMP,
PYOPCODE_CACHE_SIZE,
REL_BWD_JUMP,
REL_JUMP,
)

if TYPE_CHECKING:
import types
Expand Down Expand Up @@ -239,7 +245,8 @@ def relocate_jump_target(instructions: list[Instruction]) -> None:
if instr.opname in ABS_JUMP:
new_arg = jump_target
else: # instr.opname in REL_JUMP
new_arg = jump_target - instr.offset - 2
cache_size = PYOPCODE_CACHE_SIZE.get(instr.opname, 0)
new_arg = jump_target - (2 * cache_size) - instr.offset - 2
if instr.opname in REL_BWD_JUMP:
new_arg = -new_arg

Expand Down Expand Up @@ -315,12 +322,12 @@ def bind_ex_arg_with_instr(ex_arg, instr):
return modify_completed


def modify_vars(instructions, code_options):
def modify_vars(instructions: list[Instruction], code_options):
co_names = code_options['co_names']
co_varnames = code_options['co_varnames']
co_freevars = code_options['co_freevars']
for instrs in instructions:
if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST':
if instrs.opname in ['LOAD_FAST', 'LOAD_FAST_CHECK', 'STORE_FAST']:
assert (
instrs.argval in co_varnames
), f"`{instrs.argval}` not in {co_varnames}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class PopJumpCond(Enum):
NOT_NONE = "NOT_NONE"


def get_pyopcode_cache_size() -> dict[str, int]:
def _get_pyopcode_cache_size() -> dict[str, int]:
if sys.version_info >= (3, 11) and sys.version_info < (3, 12):
# Cache for some opcodes, it's for Python 3.11+
# https://github.com/python/cpython/blob/3.11/Include/internal/pycore_opcode.h#L41-L53
Expand Down Expand Up @@ -87,4 +87,4 @@ def get_pyopcode_cache_size() -> dict[str, int]:
return {}


PYOPCODE_CACHE_SIZE = get_pyopcode_cache_size()
PYOPCODE_CACHE_SIZE = _get_pyopcode_cache_size()
5 changes: 0 additions & 5 deletions test/sot/skip_files_py312
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
./test_11_jumps.py
./test_12_for_loop.py
./test_builtin_zip.py
./test_inplace_api.py
./test_min_graph_size.py
./test_side_effects.py
./test_sot_cost_model.py
./test_sot_resnet.py
./test_sot_resnet50_backward.py

0 comments on commit de1777b

Please sign in to comment.