From ef4f58dad3dd283038a6926976b1b7f9440552d6 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Mon, 26 Feb 2024 16:06:47 +0800 Subject: [PATCH 1/8] debug END_FOR --- .../executor/opcode_executor.py | 17 ++++++++++++++++- .../executor/pycode_generator.py | 2 +- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index e9a985e5b728c7..ebe788ae3cbbb3 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -767,6 +767,9 @@ def LOAD_FAST(self, instr: Instruction): var = self._locals[instr.argval] self.stack.push(var) + def LOAD_FAST_CHECK(self, instr: Instruction): + self.LOAD_FAST(instr) + def DELETE_FAST(self, instr: Instruction): varname = self._code.co_varnames[instr.arg] del self._locals[varname] @@ -1493,6 +1496,10 @@ def LIST_TO_TUPLE(self, instr: Instruction): ) ) + def END_FOR(self, instr: Instruction): + # breakpoint() + pass + class OpcodeExecutor(OpcodeExecutorBase): """ @@ -1811,6 +1818,9 @@ def _gen_loop_body_between( pycode_gen.extend_instrs(origin_instrs[start:end]) # break should jump to this nop + # if sys.version_info >= (3,12): + # nop_for_break = pycode_gen._add_instr("END_FOR") + # else: nop_for_break = pycode_gen._add_instr("NOP") # need do additional operates when break @@ -1971,6 +1981,9 @@ def _break_graph_when_for_loop( for_iter, direction=JumpDirection.BACKWARD ) nop = self._graph.pycode_gen._add_instr("NOP") + # if sys.version_info >= (3,12): + # for_iter.jump_to = self._graph.pycode_gen._add_instr("END_FOR") + # else: for_iter.jump_to = nop jump_if_break.jump_to = nop @@ -2006,6 +2019,8 @@ def _inline_call_for_loop( start_idx = self.indexof(for_iter) end_idx = self.indexof(for_iter.jump_to) + # breakpoint() + all_used_vars = analysis_used_names_with_space( origin_instrs, start_idx, end_idx ) @@ -2020,7 +2035,7 @@ def _inline_call_for_loop( pycode_gen.gen_load_fast(iterator.id) # 2. copy main logic - pycode_gen.extend_instrs(origin_instrs[start_idx:end_idx]) + pycode_gen.extend_instrs(origin_instrs[start_idx : end_idx + 1]) # 3. add break, continue marker and relocate jump for_iter_instr = origin_instrs[start_idx] diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index 69e174818d6627..06fddd6875ae84 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -1003,7 +1003,7 @@ def gen_pop_jump( direction: JumpDirection = JumpDirection.FORWARD, suffix: PopJumpCond = PopJumpCond.NONE, ) -> Instruction: - if sys.version_info >= (3, 11): + if sys.version_info >= (3, 11) and sys.version_info < (3, 12): return self._add_instr( f"POP_JUMP_{direction.value}_IF_{suffix.value}", jump_to=jump_to ) From 601a67a3a0f72440365dc99a545be00ec94f4b4c Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Mon, 26 Feb 2024 19:45:17 +0800 Subject: [PATCH 2/8] fix --- .../jit/sot/opcode_translator/executor/opcode_executor.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 60f4fe877fa1c2..0c215e2e9bcfb9 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -2108,6 +2108,9 @@ def create_after_loop_fn(): for_iter, direction=JumpDirection.BACKWARD ) nop = self._graph.pycode_gen.add_instr("NOP") + if sys.version_info >= (3, 12): + self._graph.pycode_gen.add_instr("END_FOR") + for_iter.jump_to = nop jump_if_break.jump_to = nop @@ -2179,6 +2182,8 @@ def create_inline_call_fn(): ) nop_for_break = pycode_gen.add_instr("NOP") + if sys.version_info >= (3, 12): + pycode_gen.add_instr("END_FOR") # 2.4. relocate jumps for instr in pycode_gen._instructions: From c4e53a30a3ea8f30e54352ac663d2ab5107f42d1 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 27 Feb 2024 12:46:30 +0800 Subject: [PATCH 3/8] tmp storage --- .../executor/opcode_executor.py | 22 ++++++++++++++++--- .../instruction_utils/instruction_pass.py | 3 +++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 0c215e2e9bcfb9..9d53d1d07b6e36 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -1549,8 +1549,11 @@ def LIST_TO_TUPLE(self, instr: Instruction): ) def END_FOR(self, instr: Instruction): - # breakpoint() + # 我们不应该跑到这个字节码 pass + # breakpoint() + # self.POP_TOP(instr) + # self.POP_TOP(instr) class OpcodeExecutor(OpcodeExecutorBase): @@ -2040,13 +2043,26 @@ def create_after_loop_fn(): return None pycode_gen = PyCodeGen(self._frame) origin_instrs = get_instructions(pycode_gen._origin_code) + resume_fn_end_idx = loop_body_end_idx + + # skip END_FOR in python3.12 + if ( + sys.version_info >= (3, 12) + and origin_instrs[loop_body_end_idx].opname == "END_FOR" + ): + pycode_gen.add_instr("NOP") + resume_fn_end_idx += 1 + # origin_instrs[loop_body_end_idx] = self._graph.pycode_gen.add_instr("NOP") + pycode_gen.set_function_inputs( after_loop_fn_inputs, stack_size=len(self.stack) - 1 ) - pycode_gen.extend_instrs(origin_instrs[loop_body_end_idx:]) + pycode_gen.extend_instrs(origin_instrs[resume_fn_end_idx:]) # the resume_fn contains return code, so we don't need set output here # global vars are updated correctly, and need local vars will return after_loop_fn = pycode_gen.create_function() + # 在这里去除 resume 后的 END_FOR + # breakpoint() return after_loop_fn after_loop_fn = create_after_loop_fn() @@ -2112,7 +2128,7 @@ def create_after_loop_fn(): self._graph.pycode_gen.add_instr("END_FOR") for_iter.jump_to = nop - jump_if_break.jump_to = nop + jump_if_break.jump_to = self._graph.pycode_gen.add_instr("NOP") # 9. prepare inputs and call after_loop_fn if after_loop_fn is not None: diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py index 5b0cc17fc808f2..cb3642fde038f2 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py @@ -272,3 +272,6 @@ def check_precall_followed_by_call(instrs, code_options): raise InnerError( f"PRECALL is not followed by CALL in {code_options['co_name']}" ) + + +# TODO(gouzil): 检查END_FOR, 看到这个别LGTM From b1e14779f6eaf694ee697cba79f576a455ec2832 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 27 Feb 2024 20:53:28 +0800 Subject: [PATCH 4/8] supplement `LOAD_FAST_CHECK` --- .../executor/opcode_executor.py | 4 ---- .../instruction_utils/instruction_pass.py | 19 ++++++++++++++----- .../instruction_utils/instruction_utils.py | 6 +++++- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index a9879ebadc51ea..d133785e55542d 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -2072,9 +2072,7 @@ def create_after_loop_fn(): sys.version_info >= (3, 12) and origin_instrs[loop_body_end_idx].opname == "END_FOR" ): - pycode_gen.add_instr("NOP") resume_fn_end_idx += 1 - # origin_instrs[loop_body_end_idx] = self._graph.pycode_gen.add_instr("NOP") pycode_gen.set_function_inputs( after_loop_fn_inputs, stack_size=len(self.stack) - 1 @@ -2083,8 +2081,6 @@ def create_after_loop_fn(): # the resume_fn contains return code, so we don't need set output here # global vars are updated correctly, and need local vars will return after_loop_fn = pycode_gen.create_function() - # 在这里去除 resume 后的 END_FOR - # breakpoint() return after_loop_fn after_loop_fn = create_after_loop_fn() diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py index cb3642fde038f2..be1b13fc269c7a 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py @@ -67,7 +67,7 @@ def find_loaded_once_local_vars(instrs, code_options): """ loaded_vars = {} for instr in instrs: - if instr.opname == "LOAD_FAST": + if instr.opname == "LOAD_FAST" or instr.opname == "LOAD_FAST_CHECK": if instr.argval in loaded_vars: loaded_vars[instr.argval] += 1 else: @@ -79,12 +79,12 @@ def find_loaded_once_local_vars(instrs, code_options): def find_related_local_opcodes(instrs, code_options): """ - find out the opcode pairs consist with LOAD_FAST and STORE_FAST + find out the opcode pairs consist with LOAD_FAST and STORE_FAST and LOAD_FAST_CHECK """ stack = [] opcode_pairs = [] for instr in instrs: - if instr.opname == "LOAD_FAST": + if instr.opname == "LOAD_FAST" or instr.opname == "LOAD_FAST_CHECK": stack.append(instr) elif instr.opname == "STORE_FAST": if len(stack) > 0 and stack[-1] is not None: @@ -158,7 +158,8 @@ def code_exist(opname, argval, instrs): if a_name != b_name: for instr in instrs: if ( - instr.opname in ("LOAD_FAST", "STORE_FAST") + instr.opname + in ("LOAD_FAST_CHECK", "LOAD_FAST", "STORE_FAST") and instr.argval == b_name ): instr.argval = a_name @@ -211,7 +212,13 @@ def code_exist(opname, argval, instrs): code_range = instrs[last_store_idx : instrs.index(store_b)] if ( not code_exist("STORE_FAST", b_name, code_range) + and not code_exist("LOAD_FAST_CHECK", b_name, code_range) and not code_exist("LOAD_FAST", b_name, code_range) + and not code_exist( + "LOAD_FAST_CHECK", + a_name, + instrs[instrs.index(store_b) :], + ) and not code_exist( "LOAD_FAST", a_name, instrs[instrs.index(store_b) :] ) @@ -222,7 +229,8 @@ def code_exist(opname, argval, instrs): instrs.remove(store_b) for instr in instrs[last_store_idx:]: if ( - instr.opname in ("LOAD_FAST", "STORE_FAST") + instr.opname + in ("LOAD_FAST_CHECK", "LOAD_FAST", "STORE_FAST") and instr.argval == a_name ): instr.argval = b_name @@ -245,6 +253,7 @@ def code_exist(opname, argval, instrs): and opcode2 not in jump_target and opcode1.opname == "STORE_FAST" and opcode2.opname == "LOAD_FAST" + and opcode2.opname == "LOAD_FAST_CHECK" and opcode1.argval == opcode2.argval and opcode1.argval in loaded_once ): diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py index 2965c8e6bc056e..3b5b1be8e85c96 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -320,7 +320,11 @@ def modify_vars(instructions, code_options): co_varnames = code_options['co_varnames'] co_freevars = code_options['co_freevars'] for instrs in instructions: - if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST': + if ( + instrs.opname == 'LOAD_FAST' + or instrs.opname == 'LOAD_FAST_CHECK' + or instrs.opname == 'STORE_FAST' + ): assert ( instrs.argval in co_varnames ), f"`{instrs.argval}` not in {co_varnames}" From cd1ec5c4e8ad824d18c30b2182e4c36f59c49cfa Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sat, 2 Mar 2024 10:58:31 +0800 Subject: [PATCH 5/8] fix args and fix `FOR_ITER` jump_to error --- .../executor/opcode_executor.py | 29 ++++++++++--------- .../executor/opcode_inline_executor.py | 2 ++ .../instruction_utils/instruction_utils.py | 7 ++++- test/sot/skip_files_py312 | 5 ---- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 26737b93be66fb..d1f055408f2e3f 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -640,6 +640,10 @@ def _rot_top_n(self, n: int): def POP_TOP(self, instr: Instruction): self.stack.pop() + def END_FOR(self, instr: Instruction): + self.POP_TOP(instr) + self.POP_TOP(instr) + def PUSH_NULL(self, instr: Instruction): self.stack.push(NullVariable()) @@ -1570,13 +1574,6 @@ def CALL_INTRINSIC_1(self, instr: Instruction): else: raise FallbackError(f"No support Intrinsics, {intrinsic_func.name}") - def END_FOR(self, instr: Instruction): - # 我们不应该跑到这个字节码 - pass - # breakpoint() - # self.POP_TOP(instr) - # self.POP_TOP(instr) - class OpcodeExecutor(OpcodeExecutorBase): """ @@ -2069,7 +2066,7 @@ def create_after_loop_fn(): origin_instrs = get_instructions(pycode_gen._origin_code) resume_fn_end_idx = loop_body_end_idx - # skip END_FOR in python3.12 + # skip resume END_FOR in python3.12 if ( sys.version_info >= (3, 12) and origin_instrs[loop_body_end_idx].opname == "END_FOR" @@ -2143,12 +2140,14 @@ def create_after_loop_fn(): self._graph.pycode_gen.gen_jump( for_iter, direction=JumpDirection.BACKWARD ) - nop = self._graph.pycode_gen.add_instr("NOP") + if sys.version_info >= (3, 12): - self._graph.pycode_gen.add_instr("END_FOR") + end_for = self._graph.pycode_gen.add_instr("END_FOR") + + nop = self._graph.pycode_gen.add_instr("NOP") - for_iter.jump_to = nop - jump_if_break.jump_to = self._graph.pycode_gen.add_instr("NOP") + for_iter.jump_to = end_for if sys.version_info >= (3, 12) else nop + jump_if_break.jump_to = nop # 9. prepare inputs and call after_loop_fn if after_loop_fn is not None: @@ -2217,9 +2216,9 @@ def create_inline_call_fn(): for_iter_instr, direction=JumpDirection.BACKWARD ) - nop_for_break = pycode_gen.add_instr("NOP") if sys.version_info >= (3, 12): - pycode_gen.add_instr("END_FOR") + end_for = pycode_gen.add_instr("END_FOR") + nop_for_break = pycode_gen.add_instr("NOP") # 2.4. relocate jumps for instr in pycode_gen._instructions: @@ -2233,6 +2232,8 @@ def create_inline_call_fn(): instr.jump_to = nop_for_break jump.jump_to = for_iter_instr + if sys.version_info >= (3, 12): + for_iter_instr.jump_to = end_for pycode_gen.set_function_outputs(output_var_names) inline_call_fn = pycode_gen.create_function() diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py index 306166aa7d872c..5c723d39b3cca0 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py @@ -316,6 +316,8 @@ def FOR_ITER(self, instr: Instruction): self.stack.pop() assert isinstance(instr.jump_to, Instruction) self._lasti = self.indexof(instr.jump_to) + next_instr = self._instructions[self._lasti] + self._lasti += int(next_instr.opname == 'END_FOR') else: self._graph.remove_global_guarded_variable(iterator) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py index 3b5b1be8e85c96..810eaeff22b090 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -315,7 +315,7 @@ def bind_ex_arg_with_instr(ex_arg, instr): return modify_completed -def modify_vars(instructions, code_options): +def modify_vars(instructions: list[Instruction], code_options): co_names = code_options['co_names'] co_varnames = code_options['co_varnames'] co_freevars = code_options['co_freevars'] @@ -336,6 +336,11 @@ def modify_vars(instructions, code_options): instrs.argval in namemap ), f"`{instrs.argval}` not in {namemap}" instrs.arg = namemap.index(instrs.argval) + elif instrs.opname == "FOR_ITER": + if sys.version_info >= (3, 12): + assert instrs.arg is not None + # END_FOR offset = ((FOR_ITER arg) + (FOR_ITER offset) - 1) << 1 + instrs.arg -= 1 def calc_offset_from_bytecode_offset( diff --git a/test/sot/skip_files_py312 b/test/sot/skip_files_py312 index 4d3ee9050ad6cc..82cabe1866d19d 100644 --- a/test/sot/skip_files_py312 +++ b/test/sot/skip_files_py312 @@ -1,9 +1,4 @@ ./test_11_jumps.py -./test_12_for_loop.py -./test_builtin_zip.py -./test_inplace_api.py -./test_min_graph_size.py ./test_side_effects.py -./test_sot_cost_model.py ./test_sot_resnet.py ./test_sot_resnet50_backward.py From fc254dea94bd40a61af93c5835b89f8cca0ea084 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sat, 2 Mar 2024 12:53:46 +0800 Subject: [PATCH 6/8] add check --- .../instruction_utils/instruction_pass.py | 39 ++++++++++++++----- .../instruction_utils/instruction_utils.py | 10 ++++- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py index be1b13fc269c7a..586665c20e55b0 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py @@ -12,21 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + from paddle.jit.sot.utils import log, log_do from ...utils import InnerError from .instruction_utils import instrs_info from .stack_analyse import StackAnalyser +if TYPE_CHECKING: + from .instruction_utils import Instruction -def apply_instr_pass(instrs, code_options): + +def apply_instr_pass(instrs: list[Instruction], code_options): log(4, f"[Opcode Pass]: Original New Code {code_options['co_name']}:\n") log_do(4, lambda: print(instrs_info(instrs))) - supported_passes = ( + supported_passes = [ remove_load_store_pass, remove_duplicate_resume, check_precall_followed_by_call, - ) + ] + + if sys.version_info >= (3, 12): + supported_passes.append(check_for_iter_jump_to) for instr_pass in supported_passes: instr_pass(instrs, code_options) @@ -38,7 +49,7 @@ def apply_instr_pass(instrs, code_options): log_do(4, lambda: print(instrs_info(instrs))) -def find_stored_once_local_vars(instrs, code_options): +def find_stored_once_local_vars(instrs: list[Instruction], code_options): """ find out the local var names which is only stored once """ @@ -61,7 +72,7 @@ def find_stored_once_local_vars(instrs, code_options): return stored_once -def find_loaded_once_local_vars(instrs, code_options): +def find_loaded_once_local_vars(instrs: list[Instruction], code_options): """ find out the local var names which is only stored once """ @@ -77,7 +88,7 @@ def find_loaded_once_local_vars(instrs, code_options): return loaded_once -def find_related_local_opcodes(instrs, code_options): +def find_related_local_opcodes(instrs: list[Instruction], code_options): """ find out the opcode pairs consist with LOAD_FAST and STORE_FAST and LOAD_FAST_CHECK """ @@ -105,7 +116,7 @@ def find_related_local_opcodes(instrs, code_options): return opcode_pairs -def remove_load_store_pass(instrs, code_options): +def remove_load_store_pass(instrs: list[Instruction], code_options): """ This question is extremely complex, so we just simplify it as 'remove renames which is between var names who only stored once' @@ -264,7 +275,7 @@ def code_exist(opname, argval, instrs): idx += 1 -def remove_duplicate_resume(instrs, code_options): +def remove_duplicate_resume(instrs: list[Instruction], code_options): resumes = list(filter(lambda instr: instr.opname == "RESUME", instrs)) if not resumes: return @@ -272,7 +283,7 @@ def remove_duplicate_resume(instrs, code_options): instrs.remove(resume) -def check_precall_followed_by_call(instrs, code_options): +def check_precall_followed_by_call(instrs: list[Instruction], code_options): """ PRECALL should be followed by CALL, otherwise it will cause a segmentation fault """ @@ -283,4 +294,12 @@ def check_precall_followed_by_call(instrs, code_options): ) -# TODO(gouzil): 检查END_FOR, 看到这个别LGTM +def check_for_iter_jump_to(instrs: list[Instruction], code_options): + """ + Check if the `jump_to` of FOR_ITER is END_FOR, in Python3.12+ + """ + for instr in instrs: + if instr.opname == "FOR_ITER": + assert instr.jump_to is not None + if instr.jump_to.opname != "END_FOR": + raise InnerError("FOR_ITER jump_to is not END_FOR") diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py index 810eaeff22b090..16e3f25397b7fd 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -338,9 +338,17 @@ def modify_vars(instructions: list[Instruction], code_options): instrs.arg = namemap.index(instrs.argval) elif instrs.opname == "FOR_ITER": if sys.version_info >= (3, 12): + assert instrs.jump_to is not None assert instrs.arg is not None - # END_FOR offset = ((FOR_ITER arg) + (FOR_ITER offset) - 1) << 1 + assert instrs.offset is not None instrs.arg -= 1 + # END_FOR offset = (((FOR_ITER arg + 2) << 1) + FOR_ITER offset) + if instrs.jump_to.offset != ( + ((instrs.arg + 2) << 1) + instrs.offset + ): + raise InnerError( + 'FOR_ITER jump_to offset is not equal to "(((FOR_ITER arg + 2) << 1) + FOR_ITER offset)" ' + ) def calc_offset_from_bytecode_offset( From 29d396021089ca9a867d730cc335736453a5bfe1 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sun, 3 Mar 2024 20:47:25 +0800 Subject: [PATCH 7/8] fix relocate_jump_target and add END_FOR check --- .../executor/opcode_executor.py | 15 ++++-------- .../executor/opcode_inline_executor.py | 6 +++-- .../instruction_utils/instruction_utils.py | 24 ++++++++----------- .../instruction_utils/opcode_info.py | 4 ++-- 4 files changed, 21 insertions(+), 28 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index d1f055408f2e3f..a5527255f98f57 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -640,10 +640,6 @@ def _rot_top_n(self, n: int): def POP_TOP(self, instr: Instruction): self.stack.pop() - def END_FOR(self, instr: Instruction): - self.POP_TOP(instr) - self.POP_TOP(instr) - def PUSH_NULL(self, instr: Instruction): self.stack.push(NullVariable()) @@ -1690,8 +1686,9 @@ def FOR_ITER(self, instr): self._inline_call_for_loop(iterator, instr) self._lasti = self.indexof(instr.jump_to) - next_instr = self._instructions[self._lasti] - self._lasti += int(next_instr.opname == 'END_FOR') + if sys.version_info >= (3, 12): + assert self._instructions[self._lasti].opname == "END_FOR" + self._lasti += 1 except BreakGraphError as e: log(3, f"[BreakGraph] FOR_ITER sim for loop failed for: {e}\n") if backup_iter_idx: @@ -2067,10 +2064,8 @@ def create_after_loop_fn(): resume_fn_end_idx = loop_body_end_idx # skip resume END_FOR in python3.12 - if ( - sys.version_info >= (3, 12) - and origin_instrs[loop_body_end_idx].opname == "END_FOR" - ): + if sys.version_info >= (3, 12): + assert origin_instrs[loop_body_end_idx].opname == "END_FOR" resume_fn_end_idx += 1 pycode_gen.set_function_inputs( diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py index 5c723d39b3cca0..98cb2da36d02a2 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_inline_executor.py @@ -17,6 +17,7 @@ import contextlib import inspect import re +import sys from typing import TYPE_CHECKING from ...profiler import event_register @@ -316,8 +317,9 @@ def FOR_ITER(self, instr: Instruction): self.stack.pop() assert isinstance(instr.jump_to, Instruction) self._lasti = self.indexof(instr.jump_to) - next_instr = self._instructions[self._lasti] - self._lasti += int(next_instr.opname == 'END_FOR') + if sys.version_info >= (3, 12): + assert self._instructions[self._lasti].opname == "END_FOR" + self._lasti += 1 else: self._graph.remove_global_guarded_variable(iterator) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py index 16e3f25397b7fd..568e8311882818 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any from ...utils import InnerError -from .opcode_info import ABS_JUMP, ALL_JUMP, REL_BWD_JUMP, REL_JUMP +from .opcode_info import ( + ABS_JUMP, + ALL_JUMP, + PYOPCODE_CACHE_SIZE, + REL_BWD_JUMP, + REL_JUMP, +) if TYPE_CHECKING: import types @@ -239,6 +245,9 @@ def relocate_jump_target(instructions: list[Instruction]) -> None: if instr.opname in ABS_JUMP: new_arg = jump_target else: # instr.opname in REL_JUMP + if sys.version_info >= (3, 12): + cache = PYOPCODE_CACHE_SIZE.get(instr.opname, 0) + jump_target -= 2 * cache new_arg = jump_target - instr.offset - 2 if instr.opname in REL_BWD_JUMP: new_arg = -new_arg @@ -336,19 +345,6 @@ def modify_vars(instructions: list[Instruction], code_options): instrs.argval in namemap ), f"`{instrs.argval}` not in {namemap}" instrs.arg = namemap.index(instrs.argval) - elif instrs.opname == "FOR_ITER": - if sys.version_info >= (3, 12): - assert instrs.jump_to is not None - assert instrs.arg is not None - assert instrs.offset is not None - instrs.arg -= 1 - # END_FOR offset = (((FOR_ITER arg + 2) << 1) + FOR_ITER offset) - if instrs.jump_to.offset != ( - ((instrs.arg + 2) << 1) + instrs.offset - ): - raise InnerError( - 'FOR_ITER jump_to offset is not equal to "(((FOR_ITER arg + 2) << 1) + FOR_ITER offset)" ' - ) def calc_offset_from_bytecode_offset( diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py index 2dc69b7565672a..d310f849930132 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_info.py @@ -45,7 +45,7 @@ class PopJumpCond(Enum): NOT_NONE = "NOT_NONE" -def get_pyopcode_cache_size() -> dict[str, int]: +def _get_pyopcode_cache_size() -> dict[str, int]: if sys.version_info >= (3, 11) and sys.version_info < (3, 12): # Cache for some opcodes, it's for Python 3.11+ # https://github.com/python/cpython/blob/3.11/Include/internal/pycore_opcode.h#L41-L53 @@ -87,4 +87,4 @@ def get_pyopcode_cache_size() -> dict[str, int]: return {} -PYOPCODE_CACHE_SIZE = get_pyopcode_cache_size() +PYOPCODE_CACHE_SIZE = _get_pyopcode_cache_size() From e40f30ebe7d67f27972e0df1f429c96272ec73da Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sun, 3 Mar 2024 22:50:50 +0800 Subject: [PATCH 8/8] simplify if --- .../instruction_utils/instruction_pass.py | 4 ++-- .../instruction_utils/instruction_utils.py | 12 +++--------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py index 586665c20e55b0..e790f720ee3f82 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_pass.py @@ -78,7 +78,7 @@ def find_loaded_once_local_vars(instrs: list[Instruction], code_options): """ loaded_vars = {} for instr in instrs: - if instr.opname == "LOAD_FAST" or instr.opname == "LOAD_FAST_CHECK": + if instr.opname in ["LOAD_FAST", "LOAD_FAST_CHECK"]: if instr.argval in loaded_vars: loaded_vars[instr.argval] += 1 else: @@ -95,7 +95,7 @@ def find_related_local_opcodes(instrs: list[Instruction], code_options): stack = [] opcode_pairs = [] for instr in instrs: - if instr.opname == "LOAD_FAST" or instr.opname == "LOAD_FAST_CHECK": + if instr.opname in ["LOAD_FAST", "LOAD_FAST_CHECK"]: stack.append(instr) elif instr.opname == "STORE_FAST": if len(stack) > 0 and stack[-1] is not None: diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py index 568e8311882818..c30e21f8fb096d 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -245,10 +245,8 @@ def relocate_jump_target(instructions: list[Instruction]) -> None: if instr.opname in ABS_JUMP: new_arg = jump_target else: # instr.opname in REL_JUMP - if sys.version_info >= (3, 12): - cache = PYOPCODE_CACHE_SIZE.get(instr.opname, 0) - jump_target -= 2 * cache - new_arg = jump_target - instr.offset - 2 + cache_size = PYOPCODE_CACHE_SIZE.get(instr.opname, 0) + new_arg = jump_target - (2 * cache_size) - instr.offset - 2 if instr.opname in REL_BWD_JUMP: new_arg = -new_arg @@ -329,11 +327,7 @@ def modify_vars(instructions: list[Instruction], code_options): co_varnames = code_options['co_varnames'] co_freevars = code_options['co_freevars'] for instrs in instructions: - if ( - instrs.opname == 'LOAD_FAST' - or instrs.opname == 'LOAD_FAST_CHECK' - or instrs.opname == 'STORE_FAST' - ): + if instrs.opname in ['LOAD_FAST', 'LOAD_FAST_CHECK', 'STORE_FAST']: assert ( instrs.argval in co_varnames ), f"`{instrs.argval}` not in {co_varnames}"