Skip to content

Commit

Permalink
fetch_param in call_ex and call_kw
Browse files Browse the repository at this point in the history
  • Loading branch information
heheda12345 committed Mar 1, 2024
1 parent 7a144ef commit 7fea895
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 17 deletions.
29 changes: 14 additions & 15 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,7 +1993,8 @@ def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
args = [func.__self__] + list(args)
# print(f"function kw: {func}, type: {type(func)},args:{args}, kwargs:{kwargs}")
for i, obj in enumerate(itertools.chain(args, kwargs.values())):
self.state.fetch_function_parameters(obj)
self.call_function(func, args, kwargs)

def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
Expand All @@ -2009,6 +2010,9 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
args = [func.__self__] + list(args)
if not isinstance(args, torch.Tensor): # call(*x)
for i, obj in enumerate(itertools.chain(args, kwargs.values())):
self.state.fetch_function_parameters(obj)
self.call_function(func, args, kwargs)

def STORE_FAST(self, inst: Instruction) -> None:
Expand Down Expand Up @@ -2112,20 +2116,15 @@ def IMPORT_FROM(self, inst: Instruction) -> None:
pass

def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
seq = get_value_stack_from_top(self.frame, 0)
if isinstance(seq, (tuple, list)):
self.state.set_partial_var({
-1: [
PartialVar(node=None,
need_guard_check=False,
extract_code_at_start=[],
make_var_fn=vs.make_var_from_value)
for _ in range(len(seq))
]
})
else:
print(type(seq))
raise NotImplementedError
self.state.set_partial_var({
-1: [
PartialVar(node=None,
need_guard_check=False,
extract_code_at_start=[],
make_var_fn=vs.make_var_from_value)
for _ in range(inst.argval)
]
})

def UNPACK_EX(self, inst: Instruction) -> None:
seq = get_value_stack_from_top(self.frame, 0)
Expand Down
4 changes: 3 additions & 1 deletion frontend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def enable_dyn_shape() -> Iterator[None]:


def is_high_order_func(func: Callable[..., Any]) -> bool:
return func in high_order_func_list
return func in high_order_func_list or isinstance(func, Generator)


def is_high_order_func_with_udf(func: Callable[..., Any], args: List[Any],
Expand Down Expand Up @@ -438,5 +438,7 @@ def call_user_defined_iterator(x: Any) -> bool:
return len(args) >= 1 and is_user_defined_iter(args[0])
elif func == enumerate:
return len(args) >= 1 and is_user_defined_iter(args[0])
elif isinstance(func, Generator):
return True
else:
raise NotImplementedError
19 changes: 19 additions & 0 deletions test/test_call_function_ex.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,22 @@ def test_call_ex_with_update(caplog):
compiled = compile(outer_call_ex_with_update)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
run_and_check(compiled, [HIT], 1, caplog, expect, a, b)


def callee_kw(a, b):
return a[0] + b


def caller_kw(a, b):
return callee_kw((a, 2), b=b)


def test_caller_kw(caplog):
reset()
with torch.no_grad():
a = 1
b = 3
expect = caller_kw(a, b)
compiled = compile(caller_kw)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
19 changes: 18 additions & 1 deletion test/test_list.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from frontend.compile import compile, reset
from common.checker import run_and_check, HIT, MISS, assert_equal
from common.checker import run_and_check, HIT, MISS, ALL_MISS, assert_equal
import torch
import numpy as np

Expand Down Expand Up @@ -204,3 +204,20 @@ def test_list_inplace(caplog):
expect = list_inplace()
run_and_check(compiled, [MISS], 1, caplog, expect)
run_and_check(compiled, [HIT], 1, caplog, expect)


# def unpack_list(a, b):
# a, b = (y + 1 for y in [a,b])
# return a + b

# def test_unpack_list(caplog):
# reset()
# compiled = compile(unpack_list)
# expect = unpack_list(1, 2)
# run_and_check(compiled, [ALL_MISS], 1, caplog, expect, 1,2)
# run_and_check(compiled, [HIT], 1, caplog, expect, 1, 2)
# a = torch.rand((2,2))
# b = torch.rand((2,2))
# expect = unpack_list(a, b)
# run_and_check(compiled, [ALL_MISS], 2, caplog, expect, a, b)
# run_and_check(compiled, [HIT], 2, caplog, expect, a, b)

0 comments on commit 7fea895

Please sign in to comment.