
Commit

Fix bugs that cause wrong answers (apache#29)
heheda12345 authored Feb 29, 2024
2 parents 753d0f4 + 119344f commit eaccf13
Showing 6 changed files with 124 additions and 14 deletions.
63 changes: 49 additions & 14 deletions frontend/guard_tracker.py
@@ -1,4 +1,4 @@
from types import FrameType, MappingProxyType
from types import FrameType, MappingProxyType, ModuleType
from typing import Dict, Any, Callable, List, Optional, cast, Union
import inspect
import logging
@@ -31,7 +31,6 @@
from .variables.tuple_ import TupleVar
from .variables.base import Variable
from .control_flow import ControlFlowInfo, LoopModule, ForLoopInfo, LoopPosMap, if_stmt, IfStmtInfo
from types import ModuleType

MAKE_VAR_FN_TYPE = Callable[[
Any, bool, vs.HelperFunctions, Optional[FxGraph], Optional[list[StorePos]]
@@ -81,6 +80,8 @@ class State:
inplace_update_objs: list[Any]
guarded_pcs: list[int]
initial_args: list[Any]
varargs: Optional[Any]
varkw: Optional[Any]
calling_func: Optional[Callable[..., Any]]
callee_returns: Any
can_guard: bool
@@ -116,6 +117,8 @@ def get_mark_written_fn(state: 'State') -> Callable[[], None]:
self.inplace_update_objs = []
self.guarded_pcs = []
self.initial_args = []
self.varargs = None
self.varkw = None
self.calling_func = None
self.can_guard = True
self.frame_id = -1
@@ -351,9 +354,22 @@ def from_frame(cls, frame: FrameType, frame_id: int, read_stack: bool,
state.fx_graph,
[StoreInLocal(f"__stack__{i}")])
state.objects.add(var, value)
f_code = frame.f_code
# state.written may be assigned inside make_var_from_value
for var_name in frame.f_code.co_varnames[:frame.f_code.co_argcount]:
nargs = f_code.co_argcount + f_code.co_kwonlyargcount
for var_name in frame.f_code.co_varnames[:nargs]:
state.initial_args.append(frame.f_locals[var_name])
CO_VARARGS = 0x4
if f_code.co_flags & CO_VARARGS:
var_name = f_code.co_varnames[nargs]
nargs += 1
state.varargs = frame.f_locals[var_name]
CO_VARKEYWORDS = 0x8
if f_code.co_flags & CO_VARKEYWORDS:
var_name = f_code.co_varnames[nargs]
nargs += 1
state.varkw = frame.f_locals[var_name]

state.written = False
state.frame_id = frame_id
state.frame_cf_info = frame_cf_info
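The varargs/varkw capture above relies on CPython's code-object layout: co_varnames lists the positional and keyword-only parameter names first, then the *args name when CO_VARARGS (0x4) is set in co_flags, then the **kwargs name when CO_VARKEYWORDS (0x8) is set. A standalone check of that layout (probe is an illustrative function, not from the repo):

def probe(a, b, *args, c=0, **kw):
    pass

code = probe.__code__
nargs = code.co_argcount + code.co_kwonlyargcount  # 2 positional + 1 kw-only = 3
assert code.co_flags & 0x4  # CO_VARARGS
assert code.co_flags & 0x8  # CO_VARKEYWORDS
assert code.co_varnames[:nargs] == ('a', 'b', 'c')
assert code.co_varnames[nargs] == 'args'  # the *args slot
assert code.co_varnames[nargs + 1] == 'kw'  # the **kwargs slot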
@@ -1344,6 +1360,7 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:

self.state.inplace_update_objs.clear()
self.state.partial_var.clear()
print("clear partial var")
self.state.written = False
self.state.unmark_calling_func()
# print('process last instruction done')
@@ -1399,7 +1416,7 @@ def is_genexpr_func(self, func: Callable[..., Any]) -> bool:

def is_builtin_func(self, func: Callable[..., Any]) -> bool:
return func in (dict, tuple, set, list, hasattr, slice, range, len,
super, type, map, filter, enumerate)
type)

def get_live_objs(self, pc: int = -1) -> list[tuple[str, Any]]:
if pc == -1:
@@ -1560,20 +1577,27 @@ def set_if_inplace_return() -> None:
(set, list, dict, collections.OrderedDict,
MappingProxyType)) and get_root_module(func) != 'torch':
set_if_inplace_return()
if len(args) > 0 and isinstance(
args[0], list) and func in (list.append, list.extend,
list.clear, list.pop, list.remove,
list.reverse, list.sort):
self.state.add_inplace_update_obj(args[0])
return
elif self.has_arg_of_type(args, kwargs, np.generic):
return
elif self.is_genexpr_func(func):
return
elif self.is_builtin_func(func):
self.state.set_partial_var({
-1: [
PartialVar(node=None,
need_guard_check=False,
extract_code_at_start=[])
]
})
return
elif func in (super, map, filter, enumerate):
# TODO: add map and set correct partial var
# self.state.set_partial_var({
# -1: [
# PartialVar(node=None,
# need_guard_check=False,
# extract_code_at_start=[])
# ]
# })
return
elif is_graph_func(func):
return
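super, map, filter, and enumerate move out of is_builtin_func into their own branch that returns without setting a partial var (the commented-out block and TODO mark this as unfinished). A plausible reading, though the diff does not say so explicitly: unlike dict or tuple, these calls produce lazy objects whose contents only materialize at later instructions, so the blanket guard-free partial var used for the other builtins would mismodel them:

x = tuple([1, 2])  # a concrete value, fully known at the call site
it = enumerate([1, 2])  # a lazy iterator; items surface only when consumed
assert next(it) == (0, 1)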
@@ -1586,9 +1610,10 @@ def set_if_inplace_return() -> None:
raise NotImplementedError
elif func == getattr:
if get_method_defined_class(type(args[0]), '__getattr__') in (
torch.nn.Module, object) and get_method_defined_class(
torch.nn.Module, object, None) and get_method_defined_class(
type(args[0]),
'__getattribute__') in (torch.nn.Module, object):
'__getattribute__') in (torch.nn.Module, object,
ModuleType):
arg_obj = self.state.objects.get(args[0])

self.state.set_partial_var({
@@ -1752,7 +1777,13 @@ def BUILD_STRING(self, _inst: Instruction) -> None:
pass

def LOAD_CONST(self, _inst: Instruction) -> None:
pass
self.state.set_partial_var({
-1: [
PartialVar(node=None,
need_guard_check=False,
extract_code_at_start=[])
]
})

def SETUP_FINALLY(self, _inst: Instruction) -> None:
pass
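LOAD_CONST was previously a no-op for the tracker; now every constant pushed onto the stack gets a guard-free partial var, which is presumably what lets constant call arguments, such as the alpha=1.0 in the new test_call_function_ex.py tests, flow through as tracked values. Such constants appear as LOAD_CONST in the bytecode:

import dis

dis.dis("torch.add(a, b, alpha=1.0)")  # emits a LOAD_CONST for 1.0, among other instructions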
@@ -1842,6 +1873,10 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
if inst.argval in obj_var.modified_attrs:
return
need_guard_check = obj_var.need_guard_check
if obj == self.state.varargs and inst.argval in dir(tuple):
need_guard_check = False
if obj == self.state.varkw and inst.argval in dir(dict):
need_guard_check = False
if config.get_config('dynshape') and isinstance(
obj, torch.Tensor) and inst.argval == 'shape':
node: Optional[torch.fx.Node] = self.state.fx_graph.create_node(
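The LOAD_ATTR change skips guard checks for attribute lookups on the packed *args tuple and **kwargs dict, presumably because those containers are rebuilt on every call and have no stable store position in the caller frame to check against. The attributes in question are just the ordinary tuple and dict methods:

def f(*args, **kw):
    n = args.count(0)  # 'count' is in dir(tuple): guard check skipped
    v = kw.get('x')  # 'get' is in dir(dict): guard check skipped
    return n, v

assert f(0, x=1) == (1, 1)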
1 change: 1 addition & 0 deletions frontend/pycode_generator.py
@@ -164,6 +164,7 @@ def get_code(self) -> str:
writer.write(self.writer.get_code())
if len(self.checks) == 0:
writer.wl(f"ok = True")
writer.wl(f"missed_check = []")
else:
writer.wl(f"ok = True")
writer.wl(f"missed_check = []")
6 changes: 6 additions & 0 deletions frontend/variables/base.py
@@ -32,6 +32,7 @@ class Variable:
obj: Any
modified_attrs: dict[str, 'Variable']
prev: Optional['Variable'] = None
succ: Optional['Variable'] = None

def __init__(self, need_guard_check: bool, obj: Any,
extract_code_at_start: list[StorePos]) -> None:
@@ -79,6 +80,9 @@ def make_guard_inner(self, codegen: "GuardFnCodegen",
def make_output(self, name_in_graph_fn: str, store_pos: StorePos,
codegen: "GraphFnCodegen", in_return: bool,
idx: int) -> None:
if self.succ is not None:
return self.succ.make_output(name_in_graph_fn, store_pos, codegen,
in_return, idx)
if idx in codegen.id2name:
codegen.output(name_in_graph_fn, store_pos, codegen.id2name[idx],
in_return, 0)
@@ -108,6 +112,8 @@ def add_subvars_to_table(self, table: 'ObjectTable') -> None:

def set_prev(self, prev: Optional['Variable']) -> None:
self.prev = prev
if prev is not None:
prev.succ = self

def get_subvars_with_idx(self) -> Iterable[Tuple["Variable", int]]:
return []
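set_prev now links versions in both directions, and make_output forwards through succ, so a stale handle to an older version of an object emits the newest version's output. A toy sketch of the linking (Node and newest are illustrative names, not the repo's API):

class Node:
    def __init__(self):
        self.prev = None
        self.succ = None

    def set_prev(self, prev):
        self.prev = prev
        if prev is not None:
            prev.succ = self  # the new back-link added by this commit

    def newest(self):
        # mirrors make_output's forwarding: follow succ to the latest version
        return self if self.succ is None else self.succ.newest()

v1, v2, v3 = Node(), Node(), Node()
v2.set_prev(v1)
v3.set_prev(v2)
assert v1.newest() is v3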
39 changes: 39 additions & 0 deletions test/test_call_function_ex.py
@@ -57,3 +57,42 @@ def test_closure_call(caplog):
expect_result = model(a)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect_result, a)
run_and_check(compiled, [HIT], 1, caplog, expect_result, a)


def inner_call_ex(a, b, **kwargs):
return torch.add(a, b, **kwargs)


def outer_call_ex(a, b):
return inner_call_ex(a, b, alpha=1.0)


def test_call_ex(caplog):
reset()
with torch.no_grad():
a = torch.rand((2, 2))
b = torch.rand((2, 2))
expect = outer_call_ex(a, b)
compiled = compile(outer_call_ex)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
run_and_check(compiled, [HIT], 1, caplog, expect, a, b)


def inner_call_ex_with_update(a, b, **kwargs):
kwargs.update(alpha=1.0)
return torch.add(a, b, **kwargs)


def outer_call_ex_with_update(a, b):
return inner_call_ex_with_update(a, b, alpha=2.0)


def test_call_ex_with_update(caplog):
reset()
with torch.no_grad():
a = torch.rand((2, 2))
b = torch.rand((2, 2))
expect = outer_call_ex_with_update(a, b)
compiled = compile(outer_call_ex_with_update)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
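The second test is the sharper one: the caller passes alpha=2.0, but the inner function overwrites it in its packed **kwargs before calling torch.add, so a correct trace has to apply the update rather than reuse the caller's value:

kwargs = dict(alpha=2.0)  # what outer_call_ex_with_update passes
kwargs.update(alpha=1.0)  # what inner_call_ex_with_update does
assert kwargs == {'alpha': 1.0}  # torch.add must see alpha=1.0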
14 changes: 14 additions & 0 deletions test/test_list.py
@@ -190,3 +190,17 @@ def test_list_comp_with_wrapper(caplog):
caplog, expect, a, b)
run_and_check(compiled_list_comp_with_wrapper, [HIT], 1, caplog, expect, a,
b)


def list_inplace():
a = [1]
a.append(2)
return (3, a)


def test_list_inplace(caplog):
reset()
compiled = compile(list_inplace)
expect = list_inplace()
run_and_check(compiled, [MISS], 1, caplog, expect)
run_and_check(compiled, [HIT], 1, caplog, expect)
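This exercises the new in-place branch in call_function: the append mutates a after its creation has been traced, so a trace that missed the mutation would cache the wrong value:

def list_inplace_demo():  # same shape as the test above
    a = [1]
    a.append(2)
    return (3, a)

assert list_inplace_demo() == (3, [1, 2])  # missing the append would give (3, [1])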
15 changes: 15 additions & 0 deletions test/test_tensor.py
@@ -375,3 +375,18 @@ def test_get_device_states(caplog):
# compiled = compile(tuple_view1)
# run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a)
# run_and_check(compiled, [HIT], 1, caplog, expect, a)


def run_getattr_relu(x):
func = getattr(torch.nn.functional, 'relu')
return func(x)


def test_run_getattr_relu(caplog):
reset()
with torch.no_grad():
inp = torch.rand((2, 2))
expect = run_getattr_relu(inp)
compiled = compile(run_getattr_relu)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, inp)
run_and_check(compiled, [HIT], 1, caplog, expect, inp)
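This test pairs with the getattr change in guard_tracker.py: torch.nn.functional is a module, so its attribute lookup is resolved by ModuleType's __getattribute__ rather than object's, which is exactly the class the handler now accepts (the added None likewise covers types that define no __getattr__ at all):

import types
import torch

assert type(torch.nn.functional) is types.ModuleType  # so getattr on it is now traceable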
