fix some bugs (apache#30)
heheda12345 authored Mar 1, 2024
2 parents eaccf13 + 7fea895 commit 38df973
Showing 11 changed files with 180 additions and 21 deletions.
4 changes: 4 additions & 0 deletions frontend/c_api.pyi
@@ -112,3 +112,7 @@ def set_cell(cell: CellType, value: Any) -> None:

def set_local(frame: FrameType, idx: int, value: Any) -> None:
    pass


def parse_type_obj(obj: Any) -> str:
    pass
1 change: 1 addition & 0 deletions frontend/csrc/csrc.h
@@ -58,5 +58,6 @@ PyObject *parse_mapproxyobject(PyObject *self, PyObject *args);
PyObject *parse_mapobject(PyObject *self, PyObject *args);
PyObject *parse_cell(PyObject *self, PyObject *args);
PyObject *set_cell(PyObject *self, PyObject *args);
PyObject *parse_type_obj(PyObject *self, PyObject *args);

} // namespace frontend_csrc
1 change: 1 addition & 0 deletions frontend/csrc/frame_evaluation.cpp
@@ -696,6 +696,7 @@ static PyMethodDef _methods[] = {
{"parse_mapobject", frontend_csrc::parse_mapobject, METH_VARARGS, NULL},
{"parse_cell", frontend_csrc::parse_cell, METH_VARARGS, NULL},
{"set_cell", frontend_csrc::set_cell, METH_VARARGS, NULL},
{"parse_type_obj", frontend_csrc::parse_type_obj, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL},
};

11 changes: 11 additions & 0 deletions frontend/csrc/parse_types.cpp
@@ -110,4 +110,15 @@ PyObject *set_cell(PyObject *self, PyObject *args) {
return Py_None;
}

PyObject *parse_type_obj(PyObject *self, PyObject *args) {
    PyObject *obj;
    if (!PyArg_ParseTuple(args, "O", &obj)) {
        return NULL;
    }
    if (PyType_Check(obj)) {
        return PyUnicode_FromString(((PyTypeObject *)obj)->tp_name);
    }
    PyErr_SetString(PyExc_TypeError, "Expected type object");
    return NULL;
}
} // namespace frontend_csrc
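
For context, the new `parse_type_obj` binding simply exposes a type object's C-level `tp_name`. A rough sketch of the expected behavior from Python, assuming the built extension is importable as `frontend.c_api` (which is what the new import in `frontend/utils.py` suggests):

```python
# Illustrative only: expected behavior of the new parse_type_obj binding,
# assuming the compiled extension is importable as frontend.c_api.
import itertools

from frontend.c_api import parse_type_obj

# tp_name of extension types usually carries the module prefix ...
assert parse_type_obj(itertools.product) == "itertools.product"
# ... while built-in types report a bare name.
assert parse_type_obj(int) == "int"

# Non-type arguments raise TypeError (see the PyType_Check branch above).
try:
    parse_type_obj(42)
except TypeError:
    pass
```
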
74 changes: 55 additions & 19 deletions frontend/guard_tracker.py
@@ -145,10 +145,11 @@ def add_submodule(self, module: torch.nn.Module) -> None:
self.update_subpath(module, new_module_name)
# self.written = True # not mark as written as graph break may happen

def add_subparam(self, param: torch.nn.Parameter) -> None:
def add_subparam(self, param: torch.nn.Parameter) -> str:
    new_param_name = "external_param__" + str(len(self.subparam_paths))
    self.root.register_parameter(new_param_name, param)
    self.subparam_paths[param] = new_param_name
    return new_param_name

def as_node_args_kwargs(
self, args: list[Any], kwargs: dict[str, Any]
@@ -172,6 +173,11 @@ def as_fx_node(arg: Any) -> NodeArgs:
if isinstance(arg, slice):
    return slice(as_fx_node(arg.start), as_fx_node(arg.stop),
                 as_fx_node(arg.step))
if isinstance(arg, np.ndarray):
    param_name = self.add_subparam(
        torch.nn.Parameter(torch.tensor(arg), requires_grad=False))
    return self.fx_graph.create_node("get_attr", param_name, (), {})

var = self.objects.get(arg,
                       allow_unexist_const=True,
                       fx_graph=self.fx_graph)
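
The new `np.ndarray` branch bakes a constant array into the FX graph by registering it as a non-trainable parameter and emitting a `get_attr` node. A standalone sketch of the same idea on a bare `torch.fx` graph (the module and parameter names here are illustrative, not the tracker's own):

```python
# Minimal sketch of the ndarray-as-constant idea on a bare torch.fx graph.
# The root module and the name "external_param__0" are made up for illustration.
import numpy as np
import torch
import torch.fx

root = torch.nn.Module()
graph = torch.fx.Graph()

arr = np.array([[1.0, 2.0], [3.0, 4.0]])
param = torch.nn.Parameter(torch.tensor(arr), requires_grad=False)
root.register_parameter("external_param__0", param)

x = graph.placeholder("x")
# The constant shows up in the traced graph as a get_attr node, so later
# nodes can consume it like any other tensor.
const_node = graph.create_node("get_attr", "external_param__0", (), {})
out = graph.create_node("call_function", torch.add, (x, const_node), {})
graph.output(out)

gm = torch.fx.GraphModule(root, graph)
print(gm(torch.ones(2, 2)))  # ones plus the baked-in constant
```
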
@@ -192,6 +198,9 @@ def as_fx_node(arg: Any) -> NodeArgs:
else:
    # TODO: record all operation in SymInt or SymFloat
    pass

if f"{type(arg).__module__}.{type(arg).__qualname__}" == "torch.tensortype":  # torch.LongTensor
    return f"torch.{arg.__name__}"
return var.as_fx_node()
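
The `torch.tensortype` check catches legacy tensor-type objects such as `torch.LongTensor`, whose metaclass is `torch.tensortype` rather than `type`, and lowers them to a string constant so calls like `x.type(torch.LongTensor)` (see `test_run_type_tensor` below) can be traced. A quick illustration of the property relied on, as a sketch rather than tracker code:

```python
# Why the special case exists: torch.LongTensor is not a regular class;
# its metaclass is torch.tensortype, so type-based dispatch would miss it.
import torch

arg = torch.LongTensor
tag = f"{type(arg).__module__}.{type(arg).__qualname__}"
print(tag)                       # "torch.tensortype"
print(f"torch.{arg.__name__}")   # "torch.LongTensor", usable as a graph constant
```
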

if isinstance(args, torch.Tensor):
@@ -225,6 +234,19 @@ def record_function(self,
add_partial_var: bool = True,
inplace_ref: Any = None,
force_new_value: bool = False) -> None:
if hasattr(func, '__self__') and isinstance(
        func.__self__, torch.autograd.grad_mode.no_grad):
    if func.__name__ == '__enter__':
        target_state = False
    elif func.__name__ == '__exit__':
        target_state = func.__self__.prev
    else:
        raise ValueError(func)
    args = [
        target_state,
    ]
    func = torch._C._set_grad_enabled
    kwargs = {}
pargs, pkwargs = self.as_node_args_kwargs(args, kwargs)
if func in fx_graph_inplace_functions:
    scalar = None
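
The special case above canonicalizes `no_grad.__enter__` / `__exit__` calls into `torch._C._set_grad_enabled`, the primitive that `torch.no_grad` itself wraps. A quick sketch of the mapping being assumed (plain PyTorch, not tracker code):

```python
# Sketch: what the no_grad special case maps the context-manager protocol to.
import torch

mgr = torch.no_grad()

# __enter__ records the previous grad mode and disables gradients ...
mgr.__enter__()                      # roughly: torch._C._set_grad_enabled(False)
assert torch.is_grad_enabled() is False

# ... and __exit__ restores the saved state (mgr.prev).
mgr.__exit__(None, None, None)       # roughly: torch._C._set_grad_enabled(mgr.prev)
assert torch.is_grad_enabled() is True
```
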
@@ -268,6 +290,8 @@ def record_function(self,
func = func_dict[func]
if func in math2torch:
    func = math2torch[func]
if func == torch.from_numpy:
    func = torch.tensor
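
Rewriting `torch.from_numpy` to `torch.tensor` keeps numpy arrays out of the traced graph. The two differ only in aliasing: `from_numpy` shares the array's buffer while `torch.tensor` copies, which should not matter when the array is treated as a constant. A small sketch of that difference:

```python
# Sketch: for value purposes the rewrite is equivalent; the difference is
# only aliasing (from_numpy shares the buffer, torch.tensor copies it).
import numpy as np
import torch

arr = np.array([1.0, 2.0, 3.0])

a = torch.from_numpy(arr)   # shares memory with arr
b = torch.tensor(arr)       # copies arr

assert torch.equal(a, b)

arr[0] = 100.0
assert a[0] == 100.0        # the view follows the numpy buffer
assert b[0] == 1.0          # the copy does not
```
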

self.written = True
scalar2tensor: dict[Callable[..., Any], Callable[..., Any]] = {
@@ -1360,7 +1384,6 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:

self.state.inplace_update_objs.clear()
self.state.partial_var.clear()
print("clear partial var")
self.state.written = False
self.state.unmark_calling_func()
# print('process last instruction done')
@@ -1418,6 +1441,15 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool:
return func in (dict, tuple, set, list, hasattr, slice, range, len,
type)

def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
    if (hasattr(func, '__module__') and func.__module__ is not None and
            'numpy' in func.__module__ and 'random' not in func.__module__):
        return True
    if type(func) == np.ufunc:
        return True
    return False
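
The classifier treats numpy functions (ufuncs included) as producers of compile-time constants, while `numpy.random` calls stay dynamic. A standalone sketch mirroring the checks, not the tracker itself:

```python
# Sketch of the classification rule: numpy funcs and ufuncs are constant
# producers; numpy.random and torch callables are not.
import numpy as np
import torch


def is_numpy_constant_func(func) -> bool:
    if (hasattr(func, '__module__') and func.__module__ is not None and
            'numpy' in func.__module__ and 'random' not in func.__module__):
        return True
    if type(func) == np.ufunc:
        return True
    return False


assert is_numpy_constant_func(np.floor)            # ufunc
assert is_numpy_constant_func(np.mean)             # plain numpy function
assert not is_numpy_constant_func(np.random.rand)  # random stays dynamic
assert not is_numpy_constant_func(torch.floor)
```
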

def get_live_objs(self, pc: int = -1) -> list[tuple[str, Any]]:
if pc == -1:
pc = self.frame.f_lasti // 2
@@ -1603,6 +1635,8 @@ def set_if_inplace_return() -> None:
return
elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList):
return
elif self.is_numpy_constant_func(func):
return
elif self.has_unknown_arg(args, kwargs):
print(
f"func is {func}, {is_user_defined_func(func)}, args: {args}, kwargs:{kwargs}"
@@ -1789,7 +1823,9 @@ def SETUP_FINALLY(self, _inst: Instruction) -> None:
    pass

def SETUP_WITH(self, _inst: Instruction) -> None:
    pass
    mgr = get_value_stack_from_top(self.frame, 0)
    if type(mgr) == torch.autograd.grad_mode.no_grad:
        self.call_function(mgr.__enter__, [], {})

# def WITH_EXCEPT_START(self, _inst: Instruction) -> None:
#     pass
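
At the bytecode level, this is what makes `with torch.no_grad():` blocks traceable: the handler peeks the context manager off the value stack and routes `__enter__` through `call_function`, which the `record_function` special case above lowers to `_set_grad_enabled`. A small illustration of the opcode involved (note that `SETUP_WITH` exists on Python 3.10 and earlier; newer interpreters emit `BEFORE_WITH` instead):

```python
# Sketch: the bytecode the new SETUP_WITH handler reacts to.
import dis
import torch


def run_no_grad(x):
    with torch.no_grad():
        y = x * 2
    return y


print([ins.opname for ins in dis.get_instructions(run_no_grad)
       if 'WITH' in ins.opname])
# e.g. ['SETUP_WITH'] on 3.10: the handler reads the no_grad manager from
# the value stack and traces its __enter__ via call_function.
```
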
@@ -1873,9 +1909,9 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
if inst.argval in obj_var.modified_attrs:
    return
need_guard_check = obj_var.need_guard_check
if obj == self.state.varargs and inst.argval in dir(tuple):
if id(obj) == id(self.state.varargs) and inst.argval in dir(tuple):
    need_guard_check = False
if obj == self.state.varkw and inst.argval in dir(dict):
if id(obj) == id(self.state.varkw) and inst.argval in dir(dict):
    need_guard_check = False
if config.get_config('dynshape') and isinstance(
        obj, torch.Tensor) and inst.argval == 'shape':
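
Comparing with `id(...)` instead of `==` matters because `obj` can be an arbitrary user value: `==` on containers of tensors ends up in `Tensor.__eq__` and can raise, whereas identity comparison never does. A minimal illustration of the failure mode, unrelated to the tracker's own state objects:

```python
# Why identity comparison is used: `==` may invoke Tensor.__eq__ and blow up.
import torch

varargs = (torch.ones(2), torch.zeros(3))   # stand-in for self.state.varargs
obj = (torch.ones(2), torch.zeros(2))

try:
    same = obj == varargs   # tuple == compares items, then bool()s the result
except RuntimeError as e:
    # "Boolean value of Tensor with more than one element is ambiguous"
    print("== raised:", e)

print(id(obj) == id(varargs))   # False, and never raises
```
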
@@ -1957,7 +1993,8 @@ def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
args = [func.__self__] + list(args)
# print(f"function kw: {func}, type: {type(func)},args:{args}, kwargs:{kwargs}")
for i, obj in enumerate(itertools.chain(args, kwargs.values())):
self.state.fetch_function_parameters(obj)
self.call_function(func, args, kwargs)

def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
@@ -1973,6 +2010,9 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
'__self__') and func.__self__ is not None and not isinstance(
func.__self__, ModuleType):
args = [func.__self__] + list(args)
if not isinstance(args, torch.Tensor): # call(*x)
for i, obj in enumerate(itertools.chain(args, kwargs.values())):
self.state.fetch_function_parameters(obj)
self.call_function(func, args, kwargs)
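
Both `CALL_FUNCTION_KW` and `CALL_FUNCTION_EX` now walk keyword values as well as positional arguments before dispatch, which is what the new `test_caller_kw` below exercises. A sketch of the iteration pattern, with the hypothetical `fetch` standing in for the tracker's `fetch_function_parameters`:

```python
# Sketch of the argument walk: itertools.chain covers positional args and
# keyword values in one pass. fetch() is a stand-in for
# self.state.fetch_function_parameters, which lives in the tracker.
import itertools


def fetch(obj):
    print("registering", obj)


def walk_call_args(args, kwargs):
    for i, obj in enumerate(itertools.chain(args, kwargs.values())):
        fetch(obj)


walk_call_args(((1, 2),), {"b": 3})   # mirrors callee_kw((a, 2), b=b)
```
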

def STORE_FAST(self, inst: Instruction) -> None:
@@ -2076,19 +2116,15 @@ def IMPORT_FROM(self, inst: Instruction) -> None:
pass

def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
    seq = get_value_stack_from_top(self.frame, 0)
    if isinstance(seq, (tuple, list)):
        self.state.set_partial_var({
            -1: [
                PartialVar(node=None,
                           need_guard_check=False,
                           extract_code_at_start=[],
                           make_var_fn=vs.make_var_from_value)
                for _ in range(len(seq))
            ]
        })
    else:
        raise NotImplementedError
    self.state.set_partial_var({
        -1: [
            PartialVar(node=None,
                       need_guard_check=False,
                       extract_code_at_start=[],
                       make_var_fn=vs.make_var_from_value)
            for _ in range(inst.argval)
        ]
    })
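
Using `inst.argval` instead of `len(seq)` means `UNPACK_SEQUENCE` no longer requires the unpacked value to be a tuple or list: the opcode argument already carries the expected element count, so values without `len()` (for example generator expressions, as in the commented-out `unpack_list` test in `test_list.py`) can be handled too. A short illustration:

```python
# Why inst.argval is the better source of truth: the unpacked object may not
# support len() at all, but the bytecode always knows the target count.
import dis


def unpack_gen(a, b):
    x, y = (v + 1 for v in (a, b))   # generator: len() would raise TypeError
    return x + y


for ins in dis.get_instructions(unpack_gen):
    if ins.opname == "UNPACK_SEQUENCE":
        print(ins.opname, ins.argval)   # UNPACK_SEQUENCE 2

gen = (v + 1 for v in (1, 2))
try:
    len(gen)
except TypeError:
    print("len() is not available on a generator")
```
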

def UNPACK_EX(self, inst: Instruction) -> None:
seq = get_value_stack_from_top(self.frame, 0)
11 changes: 10 additions & 1 deletion frontend/utils.py
@@ -12,6 +12,7 @@
import torch._C
import collections
from .config import get_config, set_config
from .c_api import parse_type_obj

if TYPE_CHECKING:
from .instruction import Instruction
@@ -202,6 +203,12 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
assert hasattr(func, '__self__')
return is_user_defined_func(func.__self__)

if inspect.isclass(func):
    tp_name = parse_type_obj(func)
    module = tp_name.split(".")[0]
    if module in ("itertools",):
        return False

if func is super:
return False

@@ -393,7 +400,7 @@ def enable_dyn_shape() -> Iterator[None]:


def is_high_order_func(func: Callable[..., Any]) -> bool:
return func in high_order_func_list
return func in high_order_func_list or isinstance(func, Generator)


def is_high_order_func_with_udf(func: Callable[..., Any], args: List[Any],
@@ -431,5 +438,7 @@ def call_user_defined_iterator(x: Any) -> bool:
    return len(args) >= 1 and is_user_defined_iter(args[0])
elif func == enumerate:
    return len(args) >= 1 and is_user_defined_iter(args[0])
elif isinstance(func, Generator):
    return True
else:
    raise NotImplementedError
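
In `frontend/utils.py`, generator objects now take the high-order-function path. Assuming `Generator` here refers to the ABC from `collections.abc` (or the equivalent `typing` alias), the check recognizes generator functions' return values and generator expressions, but not ordinary iterators:

```python
# Sketch: generator objects (including generator expressions) satisfy the
# Generator ABC; plain iterators and lists do not.
from collections.abc import Generator


def gen_fn():
    yield 1


print(isinstance(gen_fn(), Generator))                 # True
print(isinstance((x + 1 for x in [1, 2]), Generator))  # True
print(isinstance(iter([1, 2]), Generator))             # False: list_iterator
print(isinstance([1, 2], Generator))                   # False
```
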
19 changes: 19 additions & 0 deletions test/test_call_function_ex.py
@@ -96,3 +96,22 @@ def test_call_ex_with_update(caplog):
compiled = compile(outer_call_ex_with_update)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
run_and_check(compiled, [HIT], 1, caplog, expect, a, b)


def callee_kw(a, b):
    return a[0] + b


def caller_kw(a, b):
    return callee_kw((a, 2), b=b)


def test_caller_kw(caplog):
    reset()
    with torch.no_grad():
        a = 1
        b = 3
        expect = caller_kw(a, b)
        compiled = compile(caller_kw)
        run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
        run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
19 changes: 18 additions & 1 deletion test/test_list.py
@@ -1,5 +1,5 @@
from frontend.compile import compile, reset
from common.checker import run_and_check, HIT, MISS, assert_equal
from common.checker import run_and_check, HIT, MISS, ALL_MISS, assert_equal
import torch
import numpy as np

@@ -204,3 +204,20 @@ def test_list_inplace(caplog):
expect = list_inplace()
run_and_check(compiled, [MISS], 1, caplog, expect)
run_and_check(compiled, [HIT], 1, caplog, expect)


# def unpack_list(a, b):
# a, b = (y + 1 for y in [a,b])
# return a + b

# def test_unpack_list(caplog):
# reset()
# compiled = compile(unpack_list)
# expect = unpack_list(1, 2)
# run_and_check(compiled, [ALL_MISS], 1, caplog, expect, 1,2)
# run_and_check(compiled, [HIT], 1, caplog, expect, 1, 2)
# a = torch.rand((2,2))
# b = torch.rand((2,2))
# expect = unpack_list(a, b)
# run_and_check(compiled, [ALL_MISS], 2, caplog, expect, a, b)
# run_and_check(compiled, [HIT], 2, caplog, expect, a, b)
16 changes: 16 additions & 0 deletions test/test_numpy.py
@@ -30,3 +30,19 @@ def test_numpy_to_int(caplog):
result = numpy_to_int(10)
run_and_check(compiled_numpy_to_int, [MISS], 1, caplog, result, 10)
run_and_check(compiled_numpy_to_int, [HIT], 1, caplog, result, 10)


def numpy_to_torch(x):
    y = np.floor((x - 1) / 2)
    return torch.tensor(y)


def test_numpy_to_torch(caplog):
    from frontend.utils import SetConfig
    with SetConfig({"backend": "eager"}):
        reset()
        compiled = compile(numpy_to_torch)
        a = np.array([1, 2.0, 3.33])
        result = numpy_to_torch(a)
        run_and_check(compiled, [MISS], 1, caplog, result, a)
        run_and_check(compiled, [HIT], 1, caplog, result, a)
15 changes: 15 additions & 0 deletions test/test_scalar.py
@@ -225,3 +225,18 @@ def test_dynamic_scalar_from_tensor(caplog):
bb = torch.tensor(5.0)
expect = dynamic_scalar_from_tensor(aa, bb, c)
run_and_check(compiled, [HIT], 1, caplog, expect, aa, bb, c)


def itertools_product(a, b):
    import itertools
    return list(itertools.product(a, b))


def test_itertools_product(caplog):
    reset()
    a = [1, 2]
    b = [3, 4]
    expect = itertools_product(a, b)
    compiled = compile(itertools_product)
    run_and_check(compiled, [MISS], 1, caplog, expect, a, b)
    run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
30 changes: 30 additions & 0 deletions test/test_tensor.py
@@ -390,3 +390,33 @@ def test_run_getattr_relu(caplog):
compiled = compile(run_getattr_relu)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect, inp)
run_and_check(compiled, [HIT], 1, caplog, expect, inp)


def run_type_tensor(x):
    return x.type(torch.LongTensor)


def test_run_type_tensor(caplog):
    reset()
    with torch.no_grad():
        inp = torch.rand((2, 2))
        expect = run_type_tensor(inp)
        compiled = compile(run_type_tensor)
        run_and_check(compiled, [MISS], 1, caplog, expect, inp)
        run_and_check(compiled, [HIT], 1, caplog, expect, inp)


def run_no_grad(x):
    with torch.no_grad():
        y = x * 2
    return y


def test_no_grad(caplog):
    reset()
    with torch.no_grad():
        inp = torch.rand((2, 2))
        expect = run_no_grad(inp)
        compiled = compile(run_no_grad)
        run_and_check(compiled, [MISS], 1, caplog, expect, inp)
        run_and_check(compiled, [HIT], 1, caplog, expect, inp)
