diff --git a/frontend/c_api.pyi b/frontend/c_api.pyi
index 6e5dbe973959..f23ab8c63ab1 100644
--- a/frontend/c_api.pyi
+++ b/frontend/c_api.pyi
@@ -112,3 +112,7 @@ def set_cell(cell: CellType, value: Any) -> None:
 
 def set_local(frame: FrameType, idx: int, value: Any) -> None:
     pass
+
+
+def parse_type_obj(obj: Any) -> str:
+    pass
\ No newline at end of file
diff --git a/frontend/csrc/csrc.h b/frontend/csrc/csrc.h
index 5da58fbcc0f2..4c440002041b 100644
--- a/frontend/csrc/csrc.h
+++ b/frontend/csrc/csrc.h
@@ -58,5 +58,6 @@ PyObject *parse_mapproxyobject(PyObject *self, PyObject *args);
 PyObject *parse_mapobject(PyObject *self, PyObject *args);
 PyObject *parse_cell(PyObject *self, PyObject *args);
 PyObject *set_cell(PyObject *self, PyObject *args);
+PyObject *parse_type_obj(PyObject *self, PyObject *args);
 
 } // namespace frontend_csrc
diff --git a/frontend/csrc/frame_evaluation.cpp b/frontend/csrc/frame_evaluation.cpp
index 05557eaae48a..9d3d7e93e3e5 100644
--- a/frontend/csrc/frame_evaluation.cpp
+++ b/frontend/csrc/frame_evaluation.cpp
@@ -696,6 +696,7 @@ static PyMethodDef _methods[] = {
     {"parse_mapobject", frontend_csrc::parse_mapobject, METH_VARARGS, NULL},
     {"parse_cell", frontend_csrc::parse_cell, METH_VARARGS, NULL},
     {"set_cell", frontend_csrc::set_cell, METH_VARARGS, NULL},
+    {"parse_type_obj", frontend_csrc::parse_type_obj, METH_VARARGS, NULL},
     {NULL, NULL, 0, NULL},
 };
 
diff --git a/frontend/csrc/parse_types.cpp b/frontend/csrc/parse_types.cpp
index 8739d4c6b555..5c5708f30c83 100644
--- a/frontend/csrc/parse_types.cpp
+++ b/frontend/csrc/parse_types.cpp
@@ -110,4 +110,15 @@ PyObject *set_cell(PyObject *self, PyObject *args) {
     return Py_None;
 }
 
+PyObject *parse_type_obj(PyObject *self, PyObject *args) {
+    PyObject *obj;
+    if (!PyArg_ParseTuple(args, "O", &obj)) {
+        return NULL;
+    }
+    if (PyType_Check(obj)) {
+        return PyUnicode_FromString(((PyTypeObject *)obj)->tp_name);
+    }
+    PyErr_SetString(PyExc_TypeError, "Expected type object");
+    return NULL;
+}
 } // namespace frontend_csrc
\ No newline at end of file
diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py
index f692484fc259..cebe0f88c16c 100644
--- a/frontend/guard_tracker.py
+++ b/frontend/guard_tracker.py
@@ -145,10 +145,11 @@ def add_submodule(self, module: torch.nn.Module) -> None:
         self.update_subpath(module, new_module_name)
         # self.written = True # not mark as written as graph break may happen
 
-    def add_subparam(self, param: torch.nn.Parameter) -> None:
+    def add_subparam(self, param: torch.nn.Parameter) -> str:
         new_param_name = "external_param__" + str(len(self.subparam_paths))
         self.root.register_parameter(new_param_name, param)
         self.subparam_paths[param] = new_param_name
+        return new_param_name
 
     def as_node_args_kwargs(
             self, args: list[Any], kwargs: dict[str, Any]
@@ -172,6 +173,11 @@ def as_fx_node(arg: Any) -> NodeArgs:
             if isinstance(arg, slice):
                 return slice(as_fx_node(arg.start), as_fx_node(arg.stop),
                              as_fx_node(arg.step))
+            if isinstance(arg, np.ndarray):
+                param_name = self.add_subparam(
+                    torch.nn.Parameter(torch.tensor(arg), requires_grad=False))
+                return self.fx_graph.create_node("get_attr", param_name, (), {})
+
             var = self.objects.get(arg,
                                    allow_unexist_const=True,
                                    fx_graph=self.fx_graph)
@@ -192,6 +198,9 @@ def as_fx_node(arg: Any) -> NodeArgs:
                 else:
                     # TODO: record all operation in SymInt or SymFloat
                     pass
+
+            if f"{type(arg).__module__}.{type(arg).__qualname__}" == "torch.tensortype":  # torch.LongTensor
+                return f"torch.{arg.__name__}"
             return var.as_fx_node()
 
         if isinstance(args, torch.Tensor):
@@ -225,6 +234,19 @@ def record_function(self,
                         add_partial_var: bool = True,
                         inplace_ref: Any = None,
                         force_new_value: bool = False) -> None:
+        if hasattr(func, '__self__') and isinstance(
+                func.__self__, torch.autograd.grad_mode.no_grad):
+            if func.__name__ == '__enter__':
+                target_state = False
+            elif func.__name__ == '__exit__':
+                target_state = func.__self__.prev
+            else:
+                raise ValueError(func)
+            args = [
+                target_state,
+            ]
+            func = torch._C._set_grad_enabled
+            kwargs = {}
         pargs, pkwargs = self.as_node_args_kwargs(args, kwargs)
         if func in fx_graph_inplace_functions:
             scalar = None
@@ -268,6 +290,8 @@ def record_function(self,
             func = func_dict[func]
         if func in math2torch:
             func = math2torch[func]
+        if func == torch.from_numpy:
+            func = torch.tensor
         self.written = True
 
         scalar2tensor: dict[Callable[..., Any], Callable[..., Any]] = {
@@ -1360,7 +1384,6 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:
 
         self.state.inplace_update_objs.clear()
         self.state.partial_var.clear()
-        print("clear partial var")
         self.state.written = False
         self.state.unmark_calling_func()
         # print('process last instruction done')
@@ -1418,6 +1441,14 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool:
         return func in (dict, tuple, set, list, hasattr, slice, range, len,
                         type)
 
+    def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
+        module = getattr(func, '__module__', None)
+        if module and 'numpy' in module and 'random' not in module:
+            return True
+        if isinstance(func, np.ufunc):
+            return True
+        return False
+
     def get_live_objs(self, pc: int = -1) -> list[tuple[str, Any]]:
         if pc == -1:
             pc = self.frame.f_lasti // 2
@@ -1603,6 +1634,8 @@ def set_if_inplace_return() -> None:
             return
         elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList):
             return
+        elif self.is_numpy_constant_func(func):
+            return
         elif self.has_unknown_arg(args, kwargs):
             print(
                 f"func is {func}, {is_user_defined_func(func)}, args: {args}, kwargs:{kwargs}"
@@ -1789,7 +1822,9 @@ def SETUP_FINALLY(self, _inst: Instruction) -> None:
         pass
 
     def SETUP_WITH(self, _inst: Instruction) -> None:
-        pass
+        mgr = get_value_stack_from_top(self.frame, 0)
+        if type(mgr) == torch.autograd.grad_mode.no_grad:
+            self.call_function(mgr.__enter__, [], {})
 
     # def WITH_EXCEPT_START(self, _inst: Instruction) -> None:
     #     pass
@@ -1873,9 +1908,9 @@ def LOAD_ATTR(self, inst: Instruction) -> None:
         if inst.argval in obj_var.modified_attrs:
             return
         need_guard_check = obj_var.need_guard_check
-        if obj == self.state.varargs and inst.argval in dir(tuple):
+        if obj is self.state.varargs and inst.argval in dir(tuple):
             need_guard_check = False
-        if obj == self.state.varkw and inst.argval in dir(dict):
+        if obj is self.state.varkw and inst.argval in dir(dict):
             need_guard_check = False
         if config.get_config('dynshape') and isinstance(
                 obj, torch.Tensor) and inst.argval == 'shape':
@@ -1957,7 +1992,8 @@ def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
                 '__self__') and func.__self__ is not None and not isinstance(
                     func.__self__, ModuleType):
             args = [func.__self__] + list(args)
-        # print(f"function kw: {func}, type: {type(func)},args:{args}, kwargs:{kwargs}")
+        for obj in itertools.chain(args, kwargs.values()):
+            self.state.fetch_function_parameters(obj)
         self.call_function(func, args, kwargs)
 
     def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
@@ -1973,6 +2009,9 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
                 '__self__') and func.__self__ is not None and not isinstance(
                     func.__self__, ModuleType):
             args = [func.__self__] + list(args)
+        if not isinstance(args, torch.Tensor):  # call(*x) where x itself may be a tensor
+            for obj in itertools.chain(args, kwargs.values()):
+                self.state.fetch_function_parameters(obj)
         self.call_function(func, args, kwargs)
 
     def STORE_FAST(self, inst: Instruction) -> None:
@@ -2076,19 +2115,15 @@ def IMPORT_FROM(self, inst: Instruction) -> None:
         pass
 
     def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
-        seq = get_value_stack_from_top(self.frame, 0)
-        if isinstance(seq, (tuple, list)):
-            self.state.set_partial_var({
-                -1: [
-                    PartialVar(node=None,
-                               need_guard_check=False,
-                               extract_code_at_start=[],
-                               make_var_fn=vs.make_var_from_value)
-                    for _ in range(len(seq))
-                ]
-            })
-        else:
-            raise NotImplementedError
+        self.state.set_partial_var({
+            -1: [
+                PartialVar(node=None,
+                           need_guard_check=False,
+                           extract_code_at_start=[],
+                           make_var_fn=vs.make_var_from_value)
+                for _ in range(inst.argval)
+            ]
+        })
 
     def UNPACK_EX(self, inst: Instruction) -> None:
         seq = get_value_stack_from_top(self.frame, 0)
diff --git a/frontend/utils.py b/frontend/utils.py
index 6686ab4fcfcb..6ac86998fc79 100644
--- a/frontend/utils.py
+++ b/frontend/utils.py
@@ -12,6 +12,7 @@
 import torch._C
 import collections
 from .config import get_config, set_config
+from .c_api import parse_type_obj
 
 if TYPE_CHECKING:
     from .instruction import Instruction
@@ -202,6 +203,12 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
         assert hasattr(func, '__self__')
         return is_user_defined_func(func.__self__)
 
+    if inspect.isclass(func):
+        tp_name = parse_type_obj(func)
+        module = tp_name.split(".")[0]
+        if module in ("itertools",):
+            return False
+
     if func is super:
         return False
 
@@ -393,7 +400,7 @@ def enable_dyn_shape() -> Iterator[None]:
 
 
 def is_high_order_func(func: Callable[..., Any]) -> bool:
-    return func in high_order_func_list
+    return func in high_order_func_list or isinstance(func, Generator)
 
 
 def is_high_order_func_with_udf(func: Callable[..., Any], args: List[Any],
@@ -431,5 +438,7 @@ def call_user_defined_iterator(x: Any) -> bool:
         return len(args) >= 1 and is_user_defined_iter(args[0])
     elif func == enumerate:
         return len(args) >= 1 and is_user_defined_iter(args[0])
+    elif isinstance(func, Generator):
+        return True
     else:
         raise NotImplementedError
diff --git a/test/test_call_function_ex.py b/test/test_call_function_ex.py
index 6eab6c02f821..241d4b32e75f 100644
--- a/test/test_call_function_ex.py
+++ b/test/test_call_function_ex.py
@@ -96,3 +96,22 @@ def test_call_ex_with_update(caplog):
     compiled = compile(outer_call_ex_with_update)
     run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
     run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
+
+
+def callee_kw(a, b):
+    return a[0] + b
+
+
+def caller_kw(a, b):
+    return callee_kw((a, 2), b=b)
+
+
+def test_caller_kw(caplog):
+    reset()
+    with torch.no_grad():
+        a = 1
+        b = 3
+        expect = caller_kw(a, b)
+        compiled = compile(caller_kw)
+        run_and_check(compiled, [ALL_MISS], 1, caplog, expect, a, b)
+        run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
diff --git a/test/test_list.py b/test/test_list.py
index 2364cfdbbc78..b1a2d8e48ad2 100644
--- a/test/test_list.py
+++ b/test/test_list.py
@@ -1,5 +1,5 @@
 from frontend.compile import compile, reset
-from common.checker import run_and_check, HIT, MISS, assert_equal
+from common.checker import run_and_check, HIT, MISS, ALL_MISS, assert_equal
 
 import torch
 import numpy as np
@@ -204,3 +204,20 @@ def test_list_inplace(caplog):
     expect = list_inplace()
     run_and_check(compiled, [MISS], 1, caplog, expect)
     run_and_check(compiled, [HIT], 1, caplog, expect)
+
+
+# def unpack_list(a, b):
+#     a, b = (y + 1 for y in [a, b])
+#     return a + b
+
+# def test_unpack_list(caplog):
+#     reset()
+#     compiled = compile(unpack_list)
+#     expect = unpack_list(1, 2)
+#     run_and_check(compiled, [ALL_MISS], 1, caplog, expect, 1, 2)
+#     run_and_check(compiled, [HIT], 1, caplog, expect, 1, 2)
+#     a = torch.rand((2, 2))
+#     b = torch.rand((2, 2))
+#     expect = unpack_list(a, b)
+#     run_and_check(compiled, [ALL_MISS], 2, caplog, expect, a, b)
+#     run_and_check(compiled, [HIT], 2, caplog, expect, a, b)
diff --git a/test/test_numpy.py b/test/test_numpy.py
index bb787d715c71..217f1cb5e826 100644
--- a/test/test_numpy.py
+++ b/test/test_numpy.py
@@ -30,3 +30,19 @@ def test_numpy_to_int(caplog):
     result = numpy_to_int(10)
     run_and_check(compiled_numpy_to_int, [MISS], 1, caplog, result, 10)
     run_and_check(compiled_numpy_to_int, [HIT], 1, caplog, result, 10)
+
+
+def numpy_to_torch(x):
+    y = np.floor((x - 1) / 2)
+    return torch.tensor(y)
+
+
+def test_numpy_to_torch(caplog):
+    from frontend.utils import SetConfig
+    with SetConfig({"backend": "eager"}):
+        reset()
+        compiled = compile(numpy_to_torch)
+        a = np.array([1, 2.0, 3.33])
+        result = numpy_to_torch(a)
+        run_and_check(compiled, [MISS], 1, caplog, result, a)
+        run_and_check(compiled, [HIT], 1, caplog, result, a)
\ No newline at end of file
diff --git a/test/test_scalar.py b/test/test_scalar.py
index bd73263b5f4e..e6de11d4c4d6 100644
--- a/test/test_scalar.py
+++ b/test/test_scalar.py
@@ -225,3 +225,18 @@ def test_dynamic_scalar_from_tensor(caplog):
     bb = torch.tensor(5.0)
     expect = dynamic_scalar_from_tensor(aa, bb, c)
     run_and_check(compiled, [HIT], 1, caplog, expect, aa, bb, c)
+
+
+def itertools_product(a, b):
+    import itertools
+    return list(itertools.product(a, b))
+
+
+def test_itertools_product(caplog):
+    reset()
+    a = [1, 2]
+    b = [3, 4]
+    expect = itertools_product(a, b)
+    compiled = compile(itertools_product)
+    run_and_check(compiled, [MISS], 1, caplog, expect, a, b)
+    run_and_check(compiled, [HIT], 1, caplog, expect, a, b)
diff --git a/test/test_tensor.py b/test/test_tensor.py
index 1e9ef1030c4d..c91e28cf95bd 100644
--- a/test/test_tensor.py
+++ b/test/test_tensor.py
@@ -390,3 +390,33 @@ def test_run_getattr_relu(caplog):
     compiled = compile(run_getattr_relu)
     run_and_check(compiled, [ALL_MISS], 1, caplog, expect, inp)
     run_and_check(compiled, [HIT], 1, caplog, expect, inp)
+
+
+def run_type_tensor(x):
+    return x.type(torch.LongTensor)
+
+
+def test_run_type_tensor(caplog):
+    reset()
+    with torch.no_grad():
+        inp = torch.rand((2, 2))
+        expect = run_type_tensor(inp)
+        compiled = compile(run_type_tensor)
+        run_and_check(compiled, [MISS], 1, caplog, expect, inp)
+        run_and_check(compiled, [HIT], 1, caplog, expect, inp)
+
+
+def run_no_grad(x):
+    with torch.no_grad():
+        y = x * 2
+    return y
+
+
+def test_no_grad(caplog):
+    reset()
+    with torch.no_grad():
+        inp = torch.rand((2, 2))
+        expect = run_no_grad(inp)
+        compiled = compile(run_no_grad)
+        run_and_check(compiled, [MISS], 1, caplog, expect, inp)
+        run_and_check(compiled, [HIT], 1, caplog, expect, inp)
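
--
Reviewer note (illustration, not part of the patch): per the C implementation above, the new parse_type_obj binding returns the C-level tp_name of a type object and raises TypeError for anything else. For static C types such as the itertools iterators, tp_name is dotted (e.g. "itertools.product"), which is what lets is_user_defined_func() recover the defining module with a plain split. A minimal sketch of the intended usage, assuming the extension is built and exposed as frontend.c_api:

    import itertools
    from frontend.c_api import parse_type_obj

    tp_name = parse_type_obj(itertools.product)  # "itertools.product" (tp_name)
    module = tp_name.split(".")[0]               # "itertools" -> treated as non-user-defined
    parse_type_obj(itertools.product([]))        # raises TypeError: an instance, not a type object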