Skip to content

Commit

Permalink
Revert "[dynamo 3.11] changes to with contexts (pytorch#94101)"
Browse files Browse the repository at this point in the history
This reverts commit 1123ab8.
  • Loading branch information
pruthvistony committed May 2, 2023
1 parent 3881aea commit a4cadab
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 47 deletions.
17 changes: 9 additions & 8 deletions torch/_dynamo/resume_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def __call__(self, code_options, cleanup):
]

else:
# NOTE: copying over for now since more changes are anticipated
with_except_start = create_instruction("WITH_EXCEPT_START")
pop_top_after_with_except_start = create_instruction("POP_TOP")

cleanup_complete_jump_target = create_instruction("NOP")

def create_load_none():
Expand All @@ -107,6 +110,7 @@ def create_load_none():

cleanup[:] = (
[
create_instruction("POP_BLOCK"),
create_load_none(),
create_load_none(),
create_load_none(),
Expand All @@ -117,27 +121,24 @@ def create_load_none():
create_instruction(
"JUMP_FORWARD", target=cleanup_complete_jump_target
),
create_instruction("PUSH_EXC_INFO"),
create_instruction("WITH_EXCEPT_START"),
with_except_start,
create_instruction(
"POP_JUMP_FORWARD_IF_TRUE",
target=pop_top_after_with_except_start,
),
create_instruction("RERAISE", 2),
create_instruction("COPY", 3),
create_instruction("POP_EXCEPT"),
create_instruction("RERAISE", 1),
create_instruction("RERAISE"),
pop_top_after_with_except_start,
create_instruction("POP_EXCEPT"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_EXCEPT"),
create_instruction("POP_TOP"),
cleanup_complete_jump_target,
]
+ cleanup
)

return create_call_function(0, False) + [
create_instruction("BEFORE_WITH"),
create_instruction("SETUP_WITH", target=with_except_start),
create_instruction("POP_TOP"),
]

Expand Down
44 changes: 5 additions & 39 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def _step_logger():

@dataclasses.dataclass
class BlockStackEntry:
id: int
target: Instruction
stack_index: Optional[int] = None
with_context: ContextWrappingVariable = None
Expand Down Expand Up @@ -879,11 +878,11 @@ def jump(self, inst):

def SETUP_LOOP(self, inst):
# only exists in python<=3.7
self.block_stack.append(BlockStackEntry(0, inst.target))
self.block_stack.append(BlockStackEntry(inst.target))

def SETUP_EXCEPT(self, inst):
# only exists in python<=3.7
self.block_stack.append(BlockStackEntry(0, inst.target))
self.block_stack.append(BlockStackEntry(inst.target))

def POP_BLOCK(self, inst):
self.block_stack.pop()
Expand All @@ -895,12 +894,10 @@ def SETUP_WITH(self, inst):
self.output.guards.update(ctx.guards)

if isinstance(self, InstructionTranslator):
self.block_stack.append(
BlockStackEntry(0, inst.target, len(self.stack), ctx)
)
self.block_stack.append(BlockStackEntry(inst.target, len(self.stack), ctx))
else:
# can't restore this while inlining
self.block_stack.append(BlockStackEntry(0, inst.target))
self.block_stack.append(BlockStackEntry(inst.target))
self.push(
WithExitFunctionVariable(
ctx,
Expand All @@ -911,7 +908,7 @@ def SETUP_WITH(self, inst):
self.push(ctx.enter(self))

def SETUP_FINALLY(self, inst):
self.block_stack.append(BlockStackEntry(0, inst.target))
self.block_stack.append(BlockStackEntry(inst.target))

def BEGIN_FINALLY(self, inst):
self.push(None)
Expand Down Expand Up @@ -1572,13 +1569,6 @@ def CALL(self, inst):
kwargs = {}
self.call_function(fn, args, kwargs)
self.kw_names = None
# 3.11 removed POP_BLOCK, so we manually pop the block stack here
if (
isinstance(fn, WithExitFunctionVariable)
and len(self.block_stack) > 0
and id(fn) == self.block_stack[-1].id
):
self.block_stack.pop()

def COPY(self, inst):
self.push(self.stack[-inst.arg])
Expand All @@ -1602,30 +1592,6 @@ def SWAP(self, inst):
def CACHE(self, inst):
pass

def BEFORE_WITH(self, inst):
ctx = self.pop()
if not isinstance(ctx, ContextWrappingVariable):
unimplemented(f"BEFORE_WITH {ctx}")
self.output.guards.update(ctx.guards)

exit = WithExitFunctionVariable(
ctx,
inst.target,
**VariableTracker.propagate(ctx),
)
# 3.11 no longer uses a block stack, but we still keep track of one
# so that we know which contexts are currently active.
if isinstance(self, InstructionTranslator):
self.block_stack.append(
BlockStackEntry(id(exit), inst.target, self.real_stack_len(), ctx)
)
else:
# can't restore this while inlining
self.block_stack.append(BlockStackEntry(id(exit), inst.target))

self.push(exit)
self.push(ctx.enter(self))

def copy_graphstate(self) -> InstructionTranslatorGraphState:
"""Create a checkpoint of the current state by copying everything"""
return InstructionTranslatorGraphState(
Expand Down

0 comments on commit a4cadab

Please sign in to comment.