Skip to content

Commit

Permalink
[dynamo 3.11] changes to with contexts (pytorch#94101)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamwen42 authored and jhavukainen committed Mar 15, 2024
1 parent 3b733a7 commit 85cdf47
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 14 deletions.
17 changes: 8 additions & 9 deletions torch/_dynamo/resume_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,7 @@ def __call__(self, code_options, cleanup):
]

else:
# NOTE: copying over for now since more changes are anticipated
with_except_start = create_instruction("WITH_EXCEPT_START")
pop_top_after_with_except_start = create_instruction("POP_TOP")

cleanup_complete_jump_target = create_instruction("NOP")

def create_load_none():
Expand All @@ -110,7 +107,6 @@ def create_load_none():

cleanup[:] = (
[
create_instruction("POP_BLOCK"),
create_load_none(),
create_load_none(),
create_load_none(),
Expand All @@ -121,24 +117,27 @@ def create_load_none():
create_instruction(
"JUMP_FORWARD", target=cleanup_complete_jump_target
),
with_except_start,
create_instruction("PUSH_EXC_INFO"),
create_instruction("WITH_EXCEPT_START"),
create_instruction(
"POP_JUMP_FORWARD_IF_TRUE",
target=pop_top_after_with_except_start,
),
create_instruction("RERAISE"),
create_instruction("RERAISE", 2),
create_instruction("COPY", 3),
create_instruction("POP_EXCEPT"),
create_instruction("RERAISE", 1),
pop_top_after_with_except_start,
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
create_instruction("POP_EXCEPT"),
create_instruction("POP_TOP"),
create_instruction("POP_TOP"),
cleanup_complete_jump_target,
]
+ cleanup
)

return create_call_function(0, False) + [
create_instruction("SETUP_WITH", target=with_except_start),
create_instruction("BEFORE_WITH"),
create_instruction("POP_TOP"),
]

Expand Down
44 changes: 39 additions & 5 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def _step_logger():

@dataclasses.dataclass
class BlockStackEntry:
id: int
target: Instruction
stack_index: Optional[int] = None
with_context: ContextWrappingVariable = None
Expand Down Expand Up @@ -878,11 +879,11 @@ def jump(self, inst):

def SETUP_LOOP(self, inst):
# only exists in python<=3.7
self.block_stack.append(BlockStackEntry(inst.target))
self.block_stack.append(BlockStackEntry(0, inst.target))

def SETUP_EXCEPT(self, inst):
# only exists in python<=3.7
self.block_stack.append(BlockStackEntry(inst.target))
self.block_stack.append(BlockStackEntry(0, inst.target))

def POP_BLOCK(self, inst):
self.block_stack.pop()
Expand All @@ -894,10 +895,12 @@ def SETUP_WITH(self, inst):
self.output.guards.update(ctx.guards)

if isinstance(self, InstructionTranslator):
self.block_stack.append(BlockStackEntry(inst.target, len(self.stack), ctx))
self.block_stack.append(
BlockStackEntry(0, inst.target, len(self.stack), ctx)
)
else:
# can't restore this while inlining
self.block_stack.append(BlockStackEntry(inst.target))
self.block_stack.append(BlockStackEntry(0, inst.target))
self.push(
WithExitFunctionVariable(
ctx,
Expand All @@ -908,7 +911,7 @@ def SETUP_WITH(self, inst):
self.push(ctx.enter(self))

def SETUP_FINALLY(self, inst):
self.block_stack.append(BlockStackEntry(inst.target))
self.block_stack.append(BlockStackEntry(0, inst.target))

def BEGIN_FINALLY(self, inst):
self.push(None)
Expand Down Expand Up @@ -1569,6 +1572,13 @@ def CALL(self, inst):
kwargs = {}
self.call_function(fn, args, kwargs)
self.kw_names = None
# 3.11 removed POP_BLOCK, so we manually pop the block stack here
if (
isinstance(fn, WithExitFunctionVariable)
and len(self.block_stack) > 0
and id(fn) == self.block_stack[-1].id
):
self.block_stack.pop()

def COPY(self, inst):
self.push(self.stack[-inst.arg])
Expand All @@ -1592,6 +1602,30 @@ def SWAP(self, inst):
def CACHE(self, inst):
pass

def BEFORE_WITH(self, inst):
ctx = self.pop()
if not isinstance(ctx, ContextWrappingVariable):
unimplemented(f"BEFORE_WITH {ctx}")
self.output.guards.update(ctx.guards)

exit = WithExitFunctionVariable(
ctx,
inst.target,
**VariableTracker.propagate(ctx),
)
# 3.11 no longer uses a block stack, but we still keep track of one
# so that we know which contexts are currently active.
if isinstance(self, InstructionTranslator):
self.block_stack.append(
BlockStackEntry(id(exit), inst.target, self.real_stack_len(), ctx)
)
else:
# can't restore this while inlining
self.block_stack.append(BlockStackEntry(id(exit), inst.target))

self.push(exit)
self.push(ctx.enter(self))

def copy_graphstate(self) -> InstructionTranslatorGraphState:
"""Create a checkpoint of the current state by copying everything"""
return InstructionTranslatorGraphState(
Expand Down

0 comments on commit 85cdf47

Please sign in to comment.