Skip to content

Commit

Permalink
[dynamo 3.11] changes to MAKE_FUNCTION and MATCH_KEYS (pytorch#94100)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamwen42 authored and jhavukainen committed Mar 15, 2024
1 parent f7d6555 commit 3b733a7
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
3 changes: 2 additions & 1 deletion torch/_dynamo/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ def make_function_with_closure(
)
output.append(create_instruction("BUILD_TUPLE", len(freevars)))
output.append(self.create_load_const(code))
output.append(self.create_load_const(fn_name))
if sys.version_info < (3, 11):
output.append(self.create_load_const(fn_name))
output.append(create_instruction("MAKE_FUNCTION", 0x08))
output.extend(self.rot_n(num_on_stack + 1))
self.clear_tos()
Expand Down
14 changes: 11 additions & 3 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -1251,8 +1251,14 @@ def LIST_APPEND(self, inst):
def MAKE_FUNCTION(self, inst):
flags = inst.arg
old_stack = list(self.stack)
fn_name = self.pop()
if sys.version_info < (3, 11):
fn_name = self.pop()
code = self.pop()
if sys.version_info >= (3, 11):
# MAKE_FUNCTION behavior actually changed in 3.11, see
# https://github.com/python/cpython/pull/93189/
assert hasattr(code.value, "co_qualname")
fn_name = ConstantVariable(value=code.value.co_qualname)
defaults = None
closure = None
annotations = None
Expand Down Expand Up @@ -1470,10 +1476,12 @@ def MATCH_KEYS(self, inst):
match_obj = tos1.items
if all(key in match_obj for key in keys):
self.push(TupleVariable([match_obj[key] for key in keys]))
self.push(ConstantVariable(True))
if sys.version_info < (3, 11):
self.push(ConstantVariable(True))
else:
self.push(ConstantVariable(None))
self.push(ConstantVariable(False))
if sys.version_info < (3, 11):
self.push(ConstantVariable(False))

UNARY_POSITIVE = stack_op(operator.pos)
UNARY_NEGATIVE = stack_op(operator.neg)
Expand Down
4 changes: 3 additions & 1 deletion torch/_dynamo/variables/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import inspect
import itertools
import sys
import types
from typing import Dict, List

Expand Down Expand Up @@ -472,5 +473,6 @@ def reconstruct(self, codegen):
flags |= 0x08
codegen(self.closure)
codegen(self.code)
codegen(self.fn_name)
if sys.version_info < (3, 11):
codegen(self.fn_name)
return [create_instruction("MAKE_FUNCTION", flags)]

0 comments on commit 3b733a7

Please sign in to comment.