diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index d86cbecba213..7989a49a4826 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -21,4 +21,4 @@ set -u set -o pipefail # install libraries for python package on ubuntu -pip3 install six numpy pytest cython decorator scipy tornado typed_ast pytest pytest-xdist pytest-profiling mypy orderedset attrs requests Pillow packaging cloudpickle synr +pip3 install six numpy pytest cython decorator scipy tornado pytest pytest-xdist pytest-profiling mypy orderedset attrs requests Pillow packaging cloudpickle synr diff --git a/python/setup.py b/python/setup.py index 5333da0da239..ec98e94f80eb 100644 --- a/python/setup.py +++ b/python/setup.py @@ -183,7 +183,7 @@ def get_package_data_files(): "decorator", "attrs", "psutil", - "typed_ast", + "synr>=0.2.1", ], extras_require={ "test": ["pillow<7", "matplotlib"], diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 8ad39354e5cf..955266c4a3e0 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -70,5 +70,5 @@ def lookup_symbol(self, name): return symbols[name] return None - def report_error(self, message): - self.parser.report_error(message) + def report_error(self, message, span): + self.parser.report_error(message, span) diff --git a/python/tvm/script/diagnostics.py b/python/tvm/script/diagnostics.py new file mode 100644 index 000000000000..fc196f6b16ae --- /dev/null +++ b/python/tvm/script/diagnostics.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Bridge from synr's (the library used for parsing the python AST) + DiagnosticContext to TVM's diagnostics +""" +import tvm +from synr import DiagnosticContext, ast +from tvm.ir.diagnostics import DiagnosticContext as TVMCtx +from tvm.ir.diagnostics import get_renderer, DiagnosticLevel, Diagnostic + + +class TVMDiagnosticCtx(DiagnosticContext): + """TVM diagnostics for synr""" + + diag_ctx: TVMCtx + + def __init__(self) -> None: + self.diag_ctx = TVMCtx(tvm.IRModule(), get_renderer()) + self.source_name = None + + def to_tvm_span(self, src_name, ast_span: ast.Span) -> tvm.ir.Span: + return tvm.ir.Span( + src_name, + ast_span.start_line, + ast_span.end_line, + ast_span.start_column, + ast_span.end_column, + ) + + def add_source(self, name: str, source: str) -> None: + src_name = self.diag_ctx.module.source_map.add(name, source) + self.source_name = src_name + + def emit(self, _level, message, span): + span = self.to_tvm_span(self.source_name, span) + self.diag_ctx.emit(Diagnostic(DiagnosticLevel.ERROR, span, message)) + self.diag_ctx.render() # Raise exception on the first error we hit. TODO remove + + def render(self): + self.diag_ctx.render() diff --git a/python/tvm/script/meta_unparser.py b/python/tvm/script/meta_unparser.py index d56fbad3d1e3..b1472ccdc758 100644 --- a/python/tvm/script/meta_unparser.py +++ b/python/tvm/script/meta_unparser.py @@ -17,34 +17,29 @@ """Unparse meta AST node into a dict""" # pylint: disable=invalid-name -from typed_ast import ast3 as ast +from synr import Transformer -class MetaUnparser(ast.NodeVisitor): +class MetaUnparser(Transformer): """Python AST Visitor to unparse meta AST node into a dict""" - def visit_Dict(self, node): + def transform(self, node): + method = "transform_" + node.__class__.__name__ + visitor = getattr(self, method, None) + if visitor is None: + self.error(f"Unexpected node type {type(node)} when parsing __tvm_meta__", node.span) + return visitor(node) + + def transform_DictLiteral(self, node): keys = [self.visit(key) for key in node.keys] values = [self.visit(value) for value in node.values] return dict(zip(keys, values)) - def visit_Tuple(self, node): + def transform_Tuple(self, node): return tuple(self.visit(element) for element in node.elts) - def visit_List(self, node): + def transform_ArrayLiteral(self, node): return [self.visit(element) for element in node.elts] - def visit_keyword(self, node): - return node.arg, self.visit(node.value) - - def visit_NameConstant(self, node): - return node.value - - def visit_Constant(self, node): + def transform_Constant(self, node): return node.value - - def visit_Num(self, node): - return node.n - - def visit_Str(self, node): - return node.s diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 70aa3fe34387..6ce682778e5c 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -14,21 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""TVM Script Parser For TIR""" -# pylint: disable=invalid-name, missing-docstring, inconsistent-return-statements, no-else-return -# pylint: disable=unnecessary-comprehension, unused-argument -# pylint: disable=relative-beyond-top-level +"""TVM Script Parser For TIR + +We use [synr](https://synr.readthedocs.io) to get an AST that is stable over +different python versions. Synr also provides an error handling context that we +use for error reporting. +""" +# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return import json import operator import inspect -from typed_ast import ast3 as ast +from synr import ast, Transformer, to_ast import tvm from tvm import IRModule from tvm._ffi.base import TVMError from tvm.ir import GlobalVar -from tvm.tir import all as _all -from tvm.tir import expr as _expr from . import context_maintainer, ty from .meta_unparser import MetaUnparser @@ -37,31 +38,47 @@ from .special_stmt import SpecialStmt from .scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler from . import _ffi_api +from .diagnostics import TVMDiagnosticCtx class CallArgumentReader(object): - """A helper class which read required argument from passed arguments""" + """Helper class to read required arguments from passed arguments. + + When parsing a function call, we need to match the arguments provided in + the AST to the required arguments of the function. This class makes sure + all the positional arguments are filled and also fill keyword arguments + with thier default value if a different value was not provided. + """ - def __init__(self, func_name, args, kwargs, parser): + def __init__(self, func_name, args, kwargs, parser, node): self.func_name = func_name self.args = args self.kwargs = kwargs self.parser = parser + self.node = node def get_pos_only_arg(self, pos, name): """Get corresponding position only function argument from argument list""" if len(self.args) >= pos: arg = self.args[pos - 1] elif name not in self.kwargs: - self.parser.report_error(self.func_name + " misses argument " + name) + # If no positional argument was found in the AST, we see if it was + # defined by name instead. + # TODO(tkonolige): this error message is not quite correct. The + # number of required arguments is >= pos + self.parser.report_error( + f"{self.func_name} requires {pos} arguments, but only {len(self.args)} were given.", + self.node.span, + ) else: arg = self.kwargs[name] return arg def get_kwarg(self, pos, name, default): - """Get corresponding keyword function argument from argument list - If user doesn't provide the argument, set it to default value + """Get corresponding keyword function argument from argument list. + + If the user hasn't provided the argument, set it to the default value. """ if len(self.args) >= pos: arg = self.args[pos - 1] @@ -79,81 +96,76 @@ def get_varargs(self, pos): return [] -class TVMScriptParserError(RuntimeError): - """TVM script Parser Runtime Error""" +class TVMScriptParser(Transformer): + """Synr AST visitor pass which finally lowers to TIR. - -class TVMScriptParser(ast.NodeVisitor): - """Python AST visitor pass which finally lowers it to TIR - Notes for extension: - 1. To support new types of AST nodes. Add a function visit_xxx(). - 2. To support new functions + Notes for Extension + ------------------- + 1. To support a new type of AST node, add a function transform_xxx(). + 2. To support new functions, add the function to the appropriate registry: We divide allowed function calls in TVM script into 3 categories, - which is intrin, scope_handler and special_stmt. - 1) intrin functions ought to have return value. - User can also register intrin category function into parser. - 2) scope_handler functions have no return value and accepts parser and AST node - as its arguments, which is used in for scope and with scope. - 3) special_stmt functions have return value and accepts parser and AST node as its arguments - When visiting Call node, we check special_stmt registry at first. If no registered function - is found, we then check intrin. - When visiting With node, we check with_scope registry. - When visiting For node, we check for_scope registry. + intrin, scope_handler and special_stmt. + 1. intrin functions are low level functions like mod, load, and + constants. They correspond to a tir `IRNode`. They must have a + return value. The user can register intrin functions for the parser to + use. + 2. scope_handler functions have no return value. They take two + arguments: the parser and the AST node. scope_handler functions are + used in with and for statements. + 3. special_stmt functions handle cases that do not have a corresponding + tir `IRNode`. These functions take the parser and the AST node as + arguments and may return a value. + When visiting a Call node, we check the special_stmt registry first. If + no registered function is found, we then check the intrin registry. + When visiting With node, we check the with_scope registry. + When visiting For node, we check the for_scope registry. """ _binop_maker = { - ast.Add: tvm.tir.Add, - ast.Sub: tvm.tir.Sub, - ast.Mult: tvm.tir.Mul, - ast.Div: tvm.tir.Div, - ast.FloorDiv: tvm.tir.FloorDiv, - ast.Mod: tvm.tir.FloorMod, - ast.BitOr: operator.or_, - ast.BitAnd: operator.and_, - ast.BitXor: operator.xor, - ast.Gt: tvm.tir.GT, - ast.GtE: tvm.tir.GE, - ast.Lt: tvm.tir.LT, - ast.LtE: tvm.tir.LE, - ast.Eq: tvm.tir.EQ, - ast.NotEq: tvm.tir.NE, - ast.And: tvm.tir.And, - ast.Or: tvm.tir.Or, + ast.BuiltinOp.Add: tvm.tir.Add, + ast.BuiltinOp.Sub: tvm.tir.Sub, + ast.BuiltinOp.Mul: tvm.tir.Mul, + ast.BuiltinOp.Div: tvm.tir.Div, + ast.BuiltinOp.FloorDiv: tvm.tir.FloorDiv, + ast.BuiltinOp.Mod: tvm.tir.FloorMod, + ast.BuiltinOp.BitOr: operator.or_, + ast.BuiltinOp.BitAnd: operator.and_, + ast.BuiltinOp.BitXor: operator.xor, + ast.BuiltinOp.GT: tvm.tir.GT, + ast.BuiltinOp.GE: tvm.tir.GE, + ast.BuiltinOp.LT: tvm.tir.LT, + ast.BuiltinOp.LE: tvm.tir.LE, + ast.BuiltinOp.Eq: tvm.tir.EQ, + ast.BuiltinOp.NotEq: tvm.tir.NE, + ast.BuiltinOp.And: tvm.tir.And, + ast.BuiltinOp.Or: tvm.tir.Or, } - _unaryop_maker = {ast.USub: operator.neg, ast.Invert: operator.invert, ast.Not: tvm.tir.Not} + _unaryop_maker = { + ast.BuiltinOp.USub: operator.neg, + ast.BuiltinOp.Invert: operator.invert, + ast.BuiltinOp.Not: tvm.tir.Not, + } - def __init__(self, src, base_lienno): + def __init__(self, base_lienno): self.context = None - self.src = src.split("\n") self.base_lineno = base_lienno self.current_lineno = 0 self.current_col_offset = 0 self.meta = None - self.functions = {} - def init_function_parsing_env(self): """Initialize function parsing environment""" self.context = context_maintainer.ContextMaintainer(self) # scope emitter - @staticmethod - def is_meta(node): - """Judge whether an AST node is META""" - return ( - isinstance(node, ast.Assign) - and len(node.targets) == 1 - and isinstance(node.targets[0], ast.Name) - and node.targets[0].id == "__tvm_meta__" - ) - def init_meta(self, meta_dict): if meta_dict is not None: self.meta = tvm.ir.load_json(json.dumps(meta_dict)) - def visit(self, node): - """Override method in ast.NodeVisitor""" + def transform(self, node): + """Generic transformation for visiting the AST. Dispatches to + `transform_ClassName` for the appropriate ClassName.""" old_lineno, old_col_offset = self.current_lineno, self.current_col_offset if hasattr(node, "lineno"): @@ -161,72 +173,74 @@ def visit(self, node): if hasattr(node, "col_offset"): self.current_col_offset = node.col_offset - method = "visit_" + node.__class__.__name__ + method = "transform_" + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) - visit_res = visitor(node) + transform_res = visitor(node) self.current_lineno, self.current_col_offset = old_lineno, old_col_offset - return visit_res - - def wrap_line_col(self, message, lineno, col_offset): - """Wrap the message with line number and column offset""" - src_line = self.src[lineno - self.base_lineno] - leading_space = len(src_line) - len(src_line.lstrip(" ")) - col_offset = col_offset - leading_space - src_line = src_line[leading_space:] - return ( - "\n " - + src_line - + "\n " - + " " * col_offset - + "^\n" - + "ParserError in line " - + str(lineno) - + " : " - + message - ) + return transform_res + + def report_error(self, message, span): + """Report an error occuring at a location. + + This just dispatches to synr's DiagnosticContext. - def report_error(self, message, lineno=None, col_offset=None): - """Report an error occur in line lineno and column col_offset Parameters ---------- message : str Error message - lineno : int - Line number of error line - col_offset : int - Column offset of error line + span : synr.ast.Span + Location of the error """ + self.error(message, span) - if lineno is None: - lineno = self.current_lineno - if col_offset is None: - col_offset = self.current_col_offset - raise TVMScriptParserError(self.wrap_line_col(message, lineno, col_offset)) + def parse_body(self, parent): + """Parse remaining statements in this scope. - def parse_body(self): + Parameters + ---------- + parent : synr.ast.Node + Parent node of this scope. Errors will be reported here. + """ body = [] + stmt = parent while len(self.context.node_stack[-1]) > 0: - res = self.visit(self.context.node_stack[-1].pop()) + stmt = self.context.node_stack[-1].pop() + res = self.transform(stmt) if res is not None: body.append(res) - return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] + if len(body) == 0: + self.report_error( + "Expected another statement at the end of this block. Perhaps you " + "used a concise statement and forgot to include a body afterwards.", + stmt.span, + ) + else: + return tvm.tir.SeqStmt(body) if len(body) > 1 else body[0] def parse_arg_list(self, func, node_call): + """Match the arguments of a function call in the AST to the required + arguments of the function. This handles positional arguments, + positional arguments specified by name, keyword arguments, and varargs. + """ assert isinstance(node_call, ast.Call) # collect arguments - args = [self.visit(arg) for arg in node_call.args] - kw_args = [self.visit(keyword) for keyword in node_call.keywords] - kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} + args = [self.transform(arg) for arg in node_call.params] + kw_args = { + self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items() + } # get the name and parameter list of func if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)): func_name, param_list = func.signature() else: - print(func) - raise Exception("Internal Error") + self.report_error( + "Internal Error: function must be of type Intrin, ScopeHandler or SpecialStmt, " + f"but it is {type(func).__name__}", + node_call.span, + ) # check arguments and parameter list and get a list of arguments - reader = CallArgumentReader(func_name, args, kw_args, self) + reader = CallArgumentReader(func_name, args, kw_args, self, node_call) pos_only, kwargs, varargs = param_list internal_args = list() for i, arg_name in enumerate(pos_only): @@ -238,25 +252,26 @@ def parse_arg_list(self, func, node_call): internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1)) return internal_args - def parse_type(self, type_node): - """ Parse type """ + def parse_type(self, type_node, parent): + """Parse a type annotation. + + We require the parent object to the type so that we have a place to + report the error message if the type does not exist. + """ if type_node is None: - self.report_error("missing type annotation") - res_type = self.visit(type_node) + self.report_error("A type annotation is required", parent.span) + res_type = self.transform(type_node) return tvm.ir.TupleType([]) if res_type is None else res_type.evaluate() def generic_visit(self, node): - """Override method in ast.NodeVisitor. - To directly filter out invalidate type of stmt. - """ + """Fallback visitor if node type is not handled. Reports an error.""" - self.report_error(type(node).__name__ + " AST node is not supported now") + self.report_error(type(node).__name__ + " AST node is not supported", node.span) - def visit_Module(self, node): + def transform_Module(self, node): """Module visitor - AST abstract grammar: - Module(stmt* body, type_ignore* type_ignore) - By now we support two format of TVM script shown below. + + Right now, we only support two formats for TVM Script. Example ------- @@ -277,7 +292,7 @@ def A(...): import tvm - @tvm.script + @tvm.script.tir class MyMod(): def A(...): ... @@ -290,79 +305,103 @@ def B(...): # returns an IRModule mod = MyMod() """ + if len(node.funcs) == 1: + return self.transform(next(iter(node.funcs.values()))) + elif len(node.func) == 0: + self.report_error( + "You must supply at least one class or function definition", node.span + ) + else: + self.report_error( + "Only one-function, one-class or function-with-meta source code is allowed", + ast.Span.union([x.span for x in list(node.funcs.values())[1:]]), + ) - if len(node.body) == 1 and isinstance(node.body[0], (ast.ClassDef, ast.FunctionDef)): - # class or single function - return self.visit(node.body[0]) - elif len(node.body) == 2: - if isinstance(node.body[0], ast.Assign): - node.body[0], node.body[1] = node.body[1], node.body[0] - if isinstance(node.body[0], ast.FunctionDef) and TVMScriptParser.is_meta(node.body[1]): - # function with meta - self.init_meta(MetaUnparser().visit(node.body[1].value)) - return self.visit(node.body[0]) - self.report_error( - "Only one-function, one-class or function-with-meta source code is allowed" - ) + def transform_Class(self, node): + """Class definition visitor. - def visit_ClassDef(self, node): - """ClassDef visitor - AST abstract grammar: - ClassDef(identifier name, expr* bases, keyword* keywords, stmt* body, - expr* decorator_list) + A class can have multiple function definitions and a single + :code:`__tvm_meta__` statement. Each class corresponds to a single + :code:`IRModule`. + + Example + ------- + .. code-block:: python + + @tvm.script.tir + class MyClass: + __tvm_meta__ = {} + def A(): + tir.evaluate(0) """ + if len(node.assignments) == 1: + if not ( + isinstance(node.assignments[0].lhs, ast.Var) + and node.assignments[0].lhs.id.name == "__tvm_meta__" + ): + self.report_error( + "The only top level assignments allowed are `__tvm_meta__ = ...`", + node.assignments[0].lhs.span, + ) + self.init_meta( + MetaUnparser().do_transform(node.assignments[0].rhs, self._diagnostic_context) + ) + elif len(node.assignments) > 1: + self.report_error( + "Only a single top level `__tvm_meta__` is allowed", + ast.Span.union([x.span for x in node.assignments[1:]]), + ) + + return create_module( + {GlobalVar(name): self.transform(func) for name, func in node.funcs.items()} + ) - # parse meta - count = False - for body_element in node.body: - if isinstance(body_element, ast.FunctionDef): - pass - elif TVMScriptParser.is_meta(body_element) and not count: - count = True - self.init_meta(MetaUnparser().visit(body_element.value)) - else: - self.report_error("invalid class member") + def transform_Function(self, node): + """Function definition visitor. - # parse member functions - for body_element in node.body: - if isinstance(body_element, ast.FunctionDef): - self.visit(body_element) + Each function definition is translated to a single :code:`PrimFunc`. - return create_module(self.functions) + There are a couple restrictions on TVM Script functions: + 1. Function arguments must have their types specified. + 2. The body of the function can contain :code:`func_attr` to specify + attributes of the function (like it's name). + 3. The body of the function can also contain multiple :code:`buffer_bind`s, + which give shape and dtype information to arguments. + 4. Return statements are implicit. - def visit_FunctionDef(self, node): - """FunctionDef visitor - AST abstract grammar: - FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list, - expr? returns, string? type_comment) - arguments = (arg* posonlyargs, arg* args, arg? vararg, arg* kwonlyargs, - expr* kw_defaults, arg? kwarg, expr* defaults) - arg = (identifier arg, expr? annotation, string? type_comment) + Example + ------- + .. code-block:: python + + @tvm.script.tir + def my_function(x: ty.handle): # 1. Argument types + tir.func_attr({"global_symbol": "mmult"}) # 2. Function attributes + X_1 = tir.buffer_bind(x, [1024, 1024]) # 3. Buffer binding + tir.evaluate(0) # 4. This function returns 0 """ self.init_function_parsing_env() - self.context.new_scope(nodes=node.body) + self.context.new_scope(nodes=node.body.stmts) # add parameters of function - for arg in node.args.args: - arg_var = tvm.te.var(arg.arg, self.parse_type(arg.annotation)) - self.context.update_symbol(arg.arg, arg_var) + for arg in node.params: + arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg)) + self.context.update_symbol(arg.name, arg_var) self.context.func_params.append(arg_var) # fetch the body and return a tir.PrimFunc func = tvm.tir.PrimFunc( self.context.func_params, - self.parse_body(), - ret_type=self.parse_type(node.returns), + self.parse_body(node.body), + ret_type=self.parse_type(node.ret_type, node), buffer_map=self.context.func_buffer_map, attrs=tvm.ir.make_node("DictAttrs", **self.context.func_dict_attr), ) - self.functions[GlobalVar(node.name)] = func self.context.pop_scope() return func - def visit_Assign(self, node): + def transform_Assign(self, node): """Assign visitor AST abstract grammar: Assign(expr* targets, expr value, string? type_comment) @@ -378,79 +417,76 @@ def visit_Assign(self, node): 4.1 var = tir.allocate() """ - if not len(node.targets) == 1: - self.report_error("Only one-valued assignment is supported now") - - if isinstance(node.targets[0], ast.Name) and isinstance(node.value, ast.Call): + if isinstance(node.rhs, ast.Call): # Pattern 1 & Pattern 4 - func = self.visit(node.value.func) - arg_list = self.parse_arg_list(func, node.value) + func = self.transform(node.rhs.func_name) if isinstance(func, WithScopeHandler): if not func.concise_scope or not func.def_symbol: self.report_error( - "with scope handler " + func.signature()[0] + " is not suitable here" + "with scope handler " + func.signature()[0] + " is not suitable here", + node.rhs.span, ) # Pattern 4 func.enter_scope(node, self.context) - arg_list = self.parse_arg_list(func, node.value) - func.body = self.parse_body() + arg_list = self.parse_arg_list(func, node.rhs) + func.body = self.parse_body(node) return func.exit_scope(node, self.context, arg_list) elif isinstance(func, SpecialStmt): # Pattern 1 + arg_list = self.parse_arg_list(func, node.rhs) func.handle(node, self.context, arg_list) + return self.parse_body(node) else: - self.report_error("Unsupported Assign stmt") - elif isinstance(node.targets[0], ast.Subscript): - # Pattern 2 & Pattern 3 - symbol, indexes = self.visit(node.targets[0]) - rhs = self.visit(node.value) - if isinstance(symbol, tvm.tir.Buffer): - # Pattern 2 - return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes) - else: - if len(indexes) != 1: - self.report_error("Invalid Store stmt") - # Pattern 3 - return tvm.tir.Store( - symbol, tvm.runtime.convert(rhs), indexes[0], tvm.runtime.convert(True) - ) - else: - self.report_error("Unsupported Assign stmt") - - def visit_AnnAssign(self, node): - """AnnAssign visitor - AST abstract grammar: - AnnAssign(expr target, expr annotation, expr? value, int simple) - - Pattern corresponds to concise mode of with tir.let() - """ - - if isinstance(node.target, ast.Name): - value = self.visit(node.value) - var = tvm.te.var(node.target.id, self.parse_type(node.annotation)) - self.context.update_symbol(var.name, var) - body = self.parse_body() - self.context.remove_symbol(var.name) - return tvm.tir.LetStmt(var, value, body) + value = self.transform(node.rhs) + if not isinstance(node.lhs, ast.Var): + # This is a little confusing because it only is true when + # we have taken this branch. We might need to clarify what + # exectly is allowed in Assignments in tvmscript. + self.report_error( + "Left hand side of assignment must be an unqualified variable", + node.lhs.span, + ) + var = tvm.te.var(node.lhs.id.name, self.parse_type(node.ty, node.lhs)) + self.context.update_symbol(var.name, var) + body = self.parse_body(node) + self.context.remove_symbol(var.name) + return tvm.tir.LetStmt(var, value, body) + + self.report_error("Unsupported Assign stmt", node.span) + + def transform_SubscriptAssign(self, node): + """Visitor for statements of the form :code:`x[1] = 2`.""" + symbol = self.transform(node.params[0]) + indexes = self.transform(node.params[1]) + rhs = self.transform(node.params[2]) + if isinstance(symbol, tvm.tir.Buffer): + # BufferStore + return tvm.tir.BufferStore(symbol, tvm.runtime.convert(rhs), indexes) else: - self.report_error("Unsupported AnnAssign stmt") + if len(indexes) != 1: + self.report_error( + f"Store is only allowed with one index, but {len(indexes)} were provided.", + Span.union([x.span for x in indexes]), + ) + # Store + return tvm.tir.Store( + symbol, tvm.runtime.convert(rhs), indexes[0], tvm.runtime.convert(True) + ) - def visit_Assert(self, node): + def transform_Assert(self, node): """Assert visitor - AST abstract grammar: - Assert(expr test, expr? msg) - Pattern corresponds to concise mode of with tir.Assert() + Pattern corresponds to concise mode of :code:`with tir.Assert()`. """ - condition = self.visit(node.test) + condition = self.transform(node.condition) if node.msg is None: - self.report_error("Message of AssertStmt can't be None") - message = self.visit(node.msg) - body = self.parse_body() + self.report_error("Assert statements must have an error message.", node.span) + message = self.transform(node.msg) + body = self.parse_body(node) return tvm.tir.AssertStmt(condition, tvm.runtime.convert(message), body) - def visit_For(self, node): + def transform_For(self, node): """For visitor AST abstract grammar: For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) @@ -459,29 +495,29 @@ def visit_For(self, node): for name in tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll() """ - if not isinstance(node.iter, ast.Call): - self.report_error("The loop iter should be a Call") - func = self.visit(node.iter.func) + if not isinstance(node.rhs, ast.Call): + self.report_error("The loop iterator should be a function call.", node.rhs.span) + func = self.transform(node.rhs.func_name) if not isinstance(func, ForScopeHandler): - self.report_error("Only for scope handlers can be used in for stmt") + self.report_error( + "Only For scope handlers can be used in a for statement.", node.rhs.func_name.span + ) # prepare for new for scope old_lineno, old_col_offset = self.current_lineno, self.current_col_offset - self.current_lineno, self.current_col_offset = ( - self.base_lineno + node.iter.lineno - 1, - node.iter.col_offset, - ) - self.context.new_scope(nodes=node.body) + self.current_lineno = node.span.start_line + self.current_col_offset = node.span.start_column + self.context.new_scope(nodes=node.body.stmts) # for scope handler process the scope func.enter_scope(node, self.context) - func.body = self.parse_body() - arg_list = self.parse_arg_list(func, node.iter) + func.body = self.parse_body(node) + arg_list = self.parse_arg_list(func, node.rhs) res = func.exit_scope(node, self.context, arg_list) # exit the scope self.context.pop_scope() self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res - def visit_With(self, node): + def transform_With(self, node): """With visitor AST abstract grammar: With(withitem* items, stmt* body, string? type_comment) @@ -493,299 +529,281 @@ def visit_With(self, node): with tir.let()/tir.Assert()/tir.attr()//tir.realize() """ - if not len(node.items) == 1: - self.report_error("Only one with element is supported now") - if not isinstance(node.items[0].context_expr, ast.Call): - self.report_error("The context expression of with should be a Call") + if not isinstance(node.rhs, ast.Call): + self.report_error( + "The context expression of a `with` statement should be a function call.", + node.rhs.span, + ) - func_call = node.items[0].context_expr - func_node = func_call.func - func = self.visit(func_node) + func = self.transform(node.rhs.func_name) if not isinstance(func, WithScopeHandler): - self.report_error("Function not allowed in with scope") + self.report_error( + f"Function {func} cannot be used in a `with` statement.", node.rhs.func_name.span + ) # prepare for new block scope old_lineno, old_col_offset = self.current_lineno, self.current_col_offset - self.current_lineno, self.current_col_offset = ( - self.base_lineno + func_call.lineno - 1, - func_call.col_offset, - ) - self.context.new_scope(nodes=node.body) + self.current_lineno = node.body.span.start_line + self.current_col_offset = node.body.span.start_column + self.context.new_scope(nodes=node.body.stmts) # with scope handler process the scope func.enter_scope(node, self.context) - func.body = self.parse_body() - arg_list = self.parse_arg_list(func, func_call) + func.body = self.parse_body(node) + arg_list = self.parse_arg_list(func, node.rhs) res = func.exit_scope(node, self.context, arg_list) # exit the scope self.context.pop_scope() self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res - def visit_If(self, node): + def transform_If(self, node): """If visitor AST abstract grammar: If(expr test, stmt* body, stmt* orelse) """ - condition = self.visit(node.test) + condition = self.transform(node.condition) # then body - self.context.new_scope(nodes=node.body) - then_body = self.parse_body() + self.context.new_scope(nodes=node.true.stmts) + then_body = self.parse_body(node) self.context.pop_scope() # else body - if len(node.orelse) > 0: - self.context.new_scope(nodes=node.orelse) - else_body = self.parse_body() + if len(node.false.stmts) > 0: + self.context.new_scope(nodes=node.false.stmts) + else_body = self.parse_body(node) self.context.pop_scope() else: else_body = None return tvm.tir.IfThenElse(condition, then_body, else_body) - def visit_Call(self, node): + def transform_Call(self, node): """Call visitor - AST abstract grammar: - Call(expr func, expr* args, keyword* keywords) - keyword = (identifier? arg, expr value) - By now 3 patterns of Call is allowed - 1. Intrin representing PrimExpr/IterVar + 3 different Call patterns are allowed: + 1. Intrin representing a PrimExpr/IterVar 1.1 tir.int/uint/float8/16/32/64/floormod/floordiv/load/cast/ramp/broadcast/max 1.2 tir.range/reduce_axis/scan_axis/opaque_axis 2. tir.Op(dtype, ...) 3. other callable functions """ - func = self.visit(node.func) - if isinstance(func, Intrin) and not func.stmt: - # pattern 1 - arg_list = self.parse_arg_list(func, node) - return func.handle(arg_list) + if isinstance(node.func_name, ast.Op): + if node.func_name.name == ast.BuiltinOp.Subscript: + return self.transform_Subscript(node) + if node.func_name.name in self._binop_maker: + lhs = self.transform(node.params[0]) + rhs = self.transform(node.params[1]) + return self._binop_maker[node.func_name.name](lhs, rhs) + if node.func_name.name in self._unaryop_maker: + rhs = self.transform(node.params[0]) + return self._unaryop_maker[node.func_name.name](rhs) + self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span) else: - args = [self.visit(arg) for arg in node.args] - kw_args = [self.visit(keyword) for keyword in node.keywords] - kw_args = {kw_arg[0]: kw_arg[1] for kw_arg in kw_args} - if isinstance(func, tvm.tir.op.Op): - # pattern 2 - return tvm.tir.Call(kw_args["dtype"], func, args) - elif callable(func): - # pattern 3 - return func(*args, **kw_args) - - self.report_error("Unsupported function call") - - def visit_Expr(self, node): - """Expr visitor - AST abstract grammar: - Expr(expr value) - - Now only 3 types of Expr stmt is allowed: - 1. Intrin representing Stmt without body - tir.store()/tir.evaluate() - 2. with scope handlers with concise scoping without var def - tir.attr()/tir.assert()/tir.allocate()/tir.realize() - 3. special stmt without var def - tir.func_attr() + func = self.transform(node.func_name) + if isinstance(func, Intrin) and not func.stmt: + # pattern 1 + arg_list = self.parse_arg_list(func, node) + return func.handle(arg_list) + else: + args = [self.transform(arg) for arg in node.params] + kw_args = { + self.transform(k): self.transform(v) for k, v in node.keyword_params.items() + } + if isinstance(func, tvm.tir.op.Op): + # pattern 2 + return tvm.tir.Call(kw_args["dtype"], func, args) + elif callable(func): + # pattern 3 + return func(*args, **kw_args) + + self.report_error("Unsupported function call.", node.func_name.span) + + def transform_UnassignedCall(self, node): + """Visitor for statements that are function calls. + + This handles function calls that appear on thier own line like `tir.realize`. + + Examples + -------- + .. code-block:: python + + @tvm.script.tir + def f(): + A = tir.buffer_decl([10, 10]) + tir.realize(A[1:2, 1:2], "") # This is an UnassignedCall + A[1, 1] = 2 # This is also an UnassignedCall """ + # Only allowed builtin operator that can be a statement is x[1] = 3 i.e. subscript assign. + if isinstance(node.call.func_name, ast.Op): + if node.call.func_name.name != ast.BuiltinOp.SubscriptAssign: + self.report_error( + "Binary and unary operators are not allowed as a statement", node.span + ) + else: + return self.transform_SubscriptAssign(node.call) - if not isinstance(node.value, ast.Call): - self.report_error("Unsupported Expr stmt") + # handle a regular function call + func = self.transform(node.call.func_name) + arg_list = self.parse_arg_list(func, node.call) - func = self.visit(node.value.func) - arg_list = self.parse_arg_list(func, node.value) + if isinstance(func, tvm.script.scope_handler.AssertHandler): + self.report_error( + "A standalone `tir.Assert` is not allowed. Use `assert condition, message` " + "instead.", + node.call.func_name.span, + ) if isinstance(func, Intrin) and func.stmt: - # pattern 1 return func.handle(arg_list) elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol: - # pattern 2 func.enter_scope(node, self.context) - func.body = self.parse_body() + func.body = self.parse_body(node) return func.exit_scope(node, self.context, arg_list) elif isinstance(func, SpecialStmt) and not func.def_symbol: - # pattern 3 func.handle(node, self.context, arg_list) return - self.report_error("Invalid Expr stmt") + self.report_error(f"Invalid Expr stmt {type(func).__name__}.", node.call.func_name.span) - def visit_BinOp(self, node): - """BinOp visitor - AST abstract grammar: - BinOp(expr left, operator op, expr right) - """ + def transform_Slice(self, node): + start = self.transform(node.start) + end = self.transform(node.end) + if not (isinstance(node.step, ast.Constant) and node.step.value == 1): + self.report_error("Only step size 1 is supported for slices.", node.step.span) + extent = end - start + if isinstance(extent, tvm.tir.PrimExpr): + ana = tvm.arith.Analyzer() + extent = ana.simplify(extent) + return tvm.ir.Range.from_min_extent(start, extent) - lhs = self.visit(node.left) - rhs = self.visit(node.right) - if not isinstance(node.op, tuple(TVMScriptParser._binop_maker.keys())): - self.report_error("BinOp " + str(type(node.op)) + " is not supported now") - return TVMScriptParser._binop_maker[type(node.op)](lhs, rhs) + def transform_Subscript(self, node): + """Array access visitor. - def visit_Compare(self, node): - """Compare visitor - AST abstract grammar: - Compare(expr left, expr right, ops=) - """ - - ops = [self.visit(node.left)] - ops += [self.visit(comparator) for comparator in node.comparators] - res = [] - for i in range(len(node.ops)): - lhs = ops[i] - rhs = ops[i + 1] - res.append(TVMScriptParser._binop_maker[type(node.ops[i])](lhs, rhs)) - return _all(*res) - - def visit_BoolOp(self, node): - """BoolOp visitor - AST abstract grammar: - BoolOp(boolop op, expr* values) - """ - - values = [self.visit(value) for value in node.values] - return TVMScriptParser._binop_maker[type(node.op)](*values) - - def visit_UnaryOp(self, node): - """UnaryOp visitor - AST abstract grammar: - UnaryOp(unaryop op, expr operand) - """ - - operand = self.visit(node.operand) - if not isinstance(node.op, tuple(TVMScriptParser._unaryop_maker.keys())): - self.report_error("UnaryOp " + str(type(node.op)) + " is not supported now") - return TVMScriptParser._unaryop_maker[type(node.op)](operand) - - def visit_Subscript(self, node): - """Subscript visitor - AST abstract grammar: - Subscript(expr value, slice slice, expr_context ctx) - slice = Slice(expr? lower, expr? upper, expr? step) - | ExtSlice(slice* dims) - | Index(expr value) - By now 2 patterns of Subscript are supported: + By now only 2 types of Subscript are supported: 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) Var[index] Buffer element access() 2. meta[type_key][index], Meta info access """ - symbol = self.visit(node.value) + symbol = self.transform(node.params[0]) if symbol is None: - self.report_error(node.value.id + " is not defined") - if isinstance(symbol, (tvm.tir.expr.Var, tvm.tir.Buffer)): - if isinstance(node.slice, ast.Index): - # BufferLoad & BufferStore, Buffer/Var[index, index, ...] - indexes = self.visit(node.slice.value) - indexes = list(indexes) if isinstance(indexes, tuple) else [indexes] - if isinstance(node.ctx, ast.Load): - if isinstance(symbol, tvm.tir.expr.Var): - return tvm.tir.Load("float32", symbol, indexes, True) - else: - return tvm.tir.BufferLoad(symbol, indexes) - else: - return symbol, indexes - else: - # Buffer Region, now used in tir.realize(buffer[bounds]) - doms = [] - slice_nodes = [] - if isinstance(node.slice, ast.Slice): - # Buffer[begin:end] - slice_nodes.append(node.slice) - elif isinstance(node.slice, ast.ExtSlice): - # Buffer[begin:end, begin:end] - slice_nodes.extend(node.slice.dims) - - for dim in slice_nodes: - if not hasattr(dim, "step"): - self.report_error("slice of Buffer Region ought to be begin:end") - if dim.step is not None: - self.report_error("step is not allowed in Buffer Region") - upper = self.visit(dim.upper) - lower = self.visit(dim.lower) - extent = upper - lower - if isinstance(extent, _expr.PrimExpr): - ana = tvm.arith.Analyzer() - extent = ana.simplify(extent) - doms.append(tvm.ir.Range.from_min_extent(lower, extent)) - return symbol, doms - else: - res = symbol[self.visit(slice)] - if res is None: - self.report_error("Only buffer variable and meta can be subscriptable") - return res + self.report_error(f"Variable {node.value.id} is not defined.", node.params[0].span) - def visit_Attribute(self, node): - """Attribute visitor - AST abstract grammar: - Attribute(expr value, identifier attr, expr_context ctx) + indexes = [self.transform(x) for x in node.params[1].values] + if isinstance(indexes[0], tvm.ir.Range): + return symbol, indexes + + if isinstance(symbol, tvm.tir.expr.Var): + return tvm.tir.Load("float32", symbol, indexes, True) + if isinstance(symbol, tvm.tir.Buffer): + return tvm.tir.BufferLoad(symbol, indexes) + + self.report_error( + f"Cannot subscript from a {type(symbol).__name__}. Only variables and " + "buffers are supported.", + node.params[0].span, + ) + + def transform_Attr(self, node): + """Visitor for field access of the form `x.y`. + + This visitor is used to lookup function and symbol names. We have two + cases to handle here: + 1. If we have a statement of the form `tir.something`, then we lookup + `tir.somthing` in the `Registry`. If the function is not in the + registry, then we try to find a `tvm.ir.op.Op` with the same name. + 2. All other names `tvm.something` are lookup up in this current python + namespace. """ - if isinstance(node.value, ast.Name): - if node.value.id == "tir": - func_name = "tir." + node.attr + if isinstance(node.object, ast.Var): + if node.object.id.name == "tir": + func_name = "tir." + node.field.name res = Registry.lookup(func_name) if res is not None: return res try: return tvm.ir.op.Op.get(func_name) - except AttributeError: - self.report_error("Unregistered function tir." + node.attr) - elif node.value.id == "ty": - if not hasattr(ty, node.attr): - self.report_error("invalid type annotation ty." + node.attr) - return getattr(ty, node.attr) - - symbol = self.visit(node.value) + except TVMError as e: + # Check if we got an attribute error + if e.args[0].find("AttributeError"): + self.report_error( + f"Unregistered function `tir.{node.field.name}`.", node.field.span + ) + else: + raise e + + symbol = self.transform(node.object) if symbol is None: - self.report_error("Unsupported Attribute expression") - if not hasattr(symbol, node.attr): - self.report_error("Type " + type(symbol) + " has not attr " + node.attr) - res = getattr(symbol, node.attr) + self.report_error("Unsupported Attribute expression.", node.object.span) + if not hasattr(symbol, node.field.name): + self.report_error( + f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span + ) + res = getattr(symbol, node.field.name) return res - def visit_Dict(self, node): - """Dict visitor - AST abstract grammar: - Dict(expr* keys, expr* values) + def transform_TypeAttr(self, node): + """Visitor for field access of the form `x.y` for types. + + We have two cases here: + 1. If the type is of the form `ty.something`, we look up the type in + the `ty` namespace in this module. + 2. If the type is of the form `tvm.x.something` then we look up + `tvm.x.something` in this modules namespace. """ + if isinstance(node.object, ast.TypeVar): + if node.object.id.name == "ty": + if not hasattr(ty, node.field.name): + self.report_error(f"Invalid type annotation `ty.{node.field.name}`.", node.span) + return getattr(ty, node.field.name) - keys = [self.visit(key) for key in node.keys] - values = [self.visit(value) for value in node.values] + symbol = self.transform(node.object) + if symbol is None: + self.report_error("Unsupported Attribute expression", node.object.span) + if not hasattr(symbol, node.field): + self.report_error( + f"Type {type(symbol)} does not have a field called `{node.field}`.", node.span + ) + res = getattr(symbol, node.field) + return res - return {key: value for key, value in zip(keys, values)} + def transform_DictLiteral(self, node): + """Dictionary literal visitor. - def visit_Tuple(self, node): - """Tuple visitor - AST abstract grammar: - Tuple(expr* elts, expr_context ctx) + Handles dictionary literals of the form `{x:y, z:2}`. """ - return tuple(self.visit(element) for element in node.elts) + keys = [self.transform(key) for key in node.keys] + values = [self.transform(value) for value in node.values] - def visit_List(self, node): - """List visitor - AST abstract grammar: - List(expr* elts, expr_context ctx) + return dict(zip(keys, values)) + + def transform_Tuple(self, node): + """Tuple visitor. + + Handles tuples of the form `(x, y, 2)`. """ - return [self.visit(element) for element in node.elts] + return tuple(self.transform(element) for element in node.values) - def visit_keyword(self, node): - """Keyword visitor - AST abstract grammar: - keyword = (identifier? arg, expr value) + def transform_ArrayLiteral(self, node): + """List literal visitor. + + Handles lists of the form `[x, 2, 3]`. """ - return node.arg, self.visit(node.value) + return [self.transform(element) for element in node.values] - def visit_Name(self, node): - """Name visitor - AST abstract grammar: - Name(identifier id, expr_context ctx) + def transform_Var(self, node): + """Variable visitor + + Handles variables like `x` in `x = 2`. """ - name = node.id + name = node.id.name if name == "meta": return self.meta symbol = Registry.lookup(name) @@ -794,28 +812,51 @@ def visit_Name(self, node): symbol = self.context.lookup_symbol(name) if symbol is not None: return symbol - self.report_error("Unknown identifier %s" % name) + self.report_error(f"Unknown identifier {name}.", node.span) + + def transform_TypeVar(self, node): + """Type variable visitor. - # note that after Python3.8, ast.NameConstant, ast.Num, ast.Str are no longer used - def visit_Constant(self, node): + Equivalent to `transform_Var` but for types. + """ + name = node.id.name + symbol = Registry.lookup(name) or self.context.lookup_symbol(name) + if symbol is not None: + return symbol + self.report_error(f"Unknown identifier {name}.", node.span) + + def transform_Constant(self, node): + """Constant value visitor. + + Constant values include `None`, `"strings"`, `2` (integers), `4.2` + (floats), and `true` (booleans). + """ return node.value - def visit_NameConstant(self, node): + def transform_TypeConstant(self, node): + """Constant value visitor for types. + + See `transform_Constant`. + """ return node.value - def visit_Num(self, node): - return node.n + def transform_Return(self, node): + self.report_error( + "TVM script does not support return statements. Instead the last statement in any " + "block is implicitly returned.", + node.span, + ) - def visit_Str(self, node): - return node.s +def from_source(src): + """Parse function or string into TIR. -def from_source(src, func_lineno=0): - """Parse the src into TIR + If possible, pass the TVM script in as a function so that line numbers and + filename will be accurate. Parameters ---------- - src : str + src : [str, function, class] Pruned source of original script func_lineno : Optional[int] The line number of the first line of the script to be parsed @@ -824,32 +865,12 @@ def from_source(src, func_lineno=0): functions : PrimFunc or IRModule The PrimFunc or IRModule in IR. """ - - root = ast.parse(src) - parser = TVMScriptParser(src, func_lineno) - - try: - return parser.visit(root) - except TVMScriptParserError as e: - raise e - except TVMError as e: - # TVM internal c++ error, we have to process the error message and inject line info - inject_e = str(e).split("\n") - msg = inject_e[-1].split(":", maxsplit=1)[1].strip() - inject_e = inject_e[:-1] - inject_e.extend( - parser.wrap_line_col(msg, parser.current_lineno, parser.current_col_offset).split("\n") - ) - inject_e[-1] = "TVM" + inject_e[-1][6:] - raise TVMError("\n".join(inject_e)) from e - except Exception as e: - inject_e = parser.wrap_line_col(str(e), parser.current_lineno, parser.current_col_offset) - raise TVMScriptParserError(inject_e) from e - - -def _parse(script_in): - """Helper function to parse TVM script into TIR""" - return from_source(inspect.getsource(script_in), inspect.getsourcelines(script_in)[1]) + if isinstance(src, str): + start_line = 0 + else: + _, start_line = inspect.getsourcelines(src) + parser = TVMScriptParser(start_line) + return to_ast(src, TVMDiagnosticCtx(), parser) def create_module(functions=None): @@ -901,11 +922,11 @@ def tir(script_in): """ if inspect.isfunction(script_in): - result = _parse(script_in) + result = from_source(script_in) elif inspect.isclass(script_in): result = TVMScriptClass(script_in) else: - raise TypeError("Only function and class are supported") + raise TypeError("Only function and class definitions are supported.") result.__name__ = script_in.__name__ result.__qualname__ = script_in.__qualname__ return result @@ -932,4 +953,4 @@ def __init__(self, script_in): def __call__(self, *args, **kwargs): # call the parser to transform tvm script into TIR - return _parse(self.script) + return from_source(self.script) diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index 251df8c6d6cb..15197eaf50af 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -17,7 +17,7 @@ """TVM Script Parser Scope Handler Classes""" # pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level -from typed_ast import ast3 as ast +from synr import ast import tvm.tir from .utils import get_param_list from .registry import register @@ -92,7 +92,7 @@ def enter_scope(self, node, context): context.report_error("Unexpected number of vars") name = names[0] elif isinstance(node, ast.Assign): - name = node.targets[0].id + name = node.lhs.id.name else: raise Exception("Internal Bug") @@ -186,15 +186,15 @@ def enter_scope(self, node, context): assert isinstance(node, ast.For) loop_var_names = list() - if isinstance(node.target, ast.Name): - loop_var_names.append(node.target.id) - elif isinstance(node.target, ast.Tuple): - for elt in node.target.elts: - if not isinstance(elt, ast.Name): - context.report_error("Invalid loop var") - loop_var_names.append(elt.id) + if isinstance(node.lhs, ast.Var): + loop_var_names.append(node.lhs.id.name) + elif isinstance(node.lhs, ast.Tuple): + for elt in node.lhs.values: + if not isinstance(elt, ast.Var): + context.report_error("Invalid loop var", elt.span) + loop_var_names.append(elt.id.name) else: - context.report_error("Invalid loop var") + context.report_error("Invalid loop var", node.lhs) self.loop_vars = [tvm.te.var(name, dtype="int32") for name in loop_var_names] for loop_var in self.loop_vars: diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 31fe0ed7cebf..f69475e37cfa 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -17,7 +17,7 @@ """TVM Script Parser Special Stmt Classes""" # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements # pylint: disable=relative-beyond-top-level -from typed_ast import ast3 as ast +from synr import ast import tvm.tir from tvm import te @@ -69,7 +69,9 @@ def match_buffer( assert isinstance(self.node, ast.Assign) if param not in self.context.func_params: - self.context.report_error("Can not bind non-input param to buffer") + self.context.report_error( + "Can not bind non-input param to buffer", self.node.rhs.params[0].span + ) if strides is None: strides = [] align = align.value if not isinstance(align, int) else align @@ -79,7 +81,7 @@ def match_buffer( buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.targets[0].id, + self.node.lhs.id.name, data, strides, elem_offset, @@ -89,7 +91,7 @@ def match_buffer( buffer_type, ) self.context.func_buffer_map[param] = buffer - self.context.update_symbol(self.node.targets[0].id, buffer) + self.context.update_symbol(self.node.lhs.id.name, buffer) super().__init__(match_buffer, def_symbol=True) @@ -127,7 +129,7 @@ def buffer_decl( buffer = tvm.tir.decl_buffer( shape, dtype, - self.node.targets[0].id, + self.node.lhs.id.name, data, strides, elem_offset, @@ -136,7 +138,7 @@ def buffer_decl( offset_factor, buffer_type, ) - self.context.update_symbol(self.node.targets[0].id, buffer) + self.context.update_symbol(self.node.lhs.id.name, buffer) return buffer super().__init__(buffer_decl, def_symbol=True) @@ -149,7 +151,7 @@ class VarDef(SpecialStmt): def __init__(self): def var(dtype): assert isinstance(self.node, ast.Assign) - v = te.var(self.node.targets[0].id, dtype) + v = te.var(self.node.lhs.id.name, dtype) self.context.update_symbol(v.name, v) super().__init__(var, def_symbol=True) @@ -162,7 +164,7 @@ class EnvThread(SpecialStmt): def __init__(self): def env_thread(env_name): assert isinstance(self.node, ast.Assign) - v = te.var(self.node.targets[0].id) + v = te.var(self.node.lhs.id.name) self.context.func_var_env_dict[v] = env_name self.context.update_symbol(v.name, v) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index dd8621d0fbfe..048a9544d6df 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -14,120 +14,169 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - -import pytest - import tvm from tvm import tir -from tvm.script import ty -from tvm.script.parser import TVMScriptParserError +from tvm.script import ty, from_source +from tvm.ir.diagnostics import override_renderer +import inspect -@tvm.script.tir -class Module1: - def buffer_bind_missing_args(a: ty.handle) -> None: - A = tir.match_buffer((16, 16), "float32") +def buffer_bind_missing_args(a: ty.handle) -> None: + A = tir.match_buffer((16, 16), "float32") # error -@tvm.script.tir -class Module2: - def range_missing_args(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") +def test_buffer_bind(): + check_error(buffer_bind_missing_args, 2) - tir.attr(A, "realize_scope", "") - tir.realize(A[0:16, 0:16]) - for i in tir.serial(16): - for j in tir.serial(0, 16): - A[i, j] = 0.0 +def range_missing_args(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") -@tvm.script.tir -class Module3: - def undefined_buffer(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") + tir.attr(A, "realize_scope", "") + tir.realize(A[0:16, 0:16], "") + for i in tir.serial(16): # error + for j in tir.serial(0, 16): + A[i, j] = 0.0 - tir.attr(A, "realize_scope", "") - tir.realize(C[0:16, 0:16]) - for i in tir.serial(16): - for j in tir.serial(0, 16): - A[i, j] = 0.0 +def test_range_missing_args(): + check_error(range_missing_args, 6) -@tvm.script.tir -class Module4: - def unsupported_stmt(a: ty.int32) -> None: - if a > 0: - print("I love tvm") +def undefined_buffer(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") -@tvm.script.tir -class Module5: - def unsupported_function_call(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") + tir.attr(A, "realize_scope", "") + tir.realize(C[0:16, 0:16], "") # error + for i in tir.serial(16): + for j in tir.serial(0, 16): + A[i, j] = 0.0 - tir.attr(A, "realize_scope", "") - tir.realize(A[0:16, 0:16]) - for i in tir.const_range(16): - for j in tir.serial(0, 16): - A[i, j] = 0.0 +def test_undefined_buffer(): + check_error(undefined_buffer, 5) -@tvm.script.tir -class Module6: - def missing_type_annotation(a) -> None: - pass + +def unsupported_stmt(a: ty.int32) -> None: + if a > 0: + print("I love tvm") # error + + +def test_unsupported_stmt(): + check_error(unsupported_stmt, 3) + + +def unsupported_function_call(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + + tir.attr(A, "realize_scope", "") + tir.realize(A[0:16, 0:16], "") + for i in tir.const_range(16): # error + for j in tir.serial(0, 16): + A[i, j] = 0.0 + + +def test_unsupported_function_call(): + check_error(unsupported_function_call, 6) + + +def missing_type_annotation(a) -> None: # error + tir.evaluate(0.0) + + +def test_missing_type_annotation(): + check_error(missing_type_annotation, 1) + + +def invalid_expr_stmt() -> None: + tir.max(1, 2) # error -@tvm.script.tir -class Module7: - def invalid_concise_scoping() -> None: - tir.Assert(1.0 > 0.0, "aaaa") - tir.evaluate(0.0) +def test_invalid_expr_stmt(): + check_error(invalid_expr_stmt, 2) -@tvm.script.tir -class Module8: - def invalid_expr_stmt() -> None: - tir.max(1, 2) +def invalid_for_function(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + for i in tir.evaluate(0.0): # error + for j in tir.serial(0, 16): + A[i, j] = 0.0 -@tvm.script.tir -class Module9: - def invalid_for_function(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") - for i in tir.evaluate(0.0): - for j in tir.serial(0, 16): - A[i, j] = 0.0 +def test_invalid_for_function(): + check_error(invalid_for_function, 4) -@tvm.script.tir -class Module10: - def invalid_block_function(a: ty.handle) -> None: - A = tir.match_buffer(a, (16, 16), "float32") +def invalid_block_function(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") - with tir.evaluate(0.0): - pass + with tir.evaluate(0.0): # error + tir.evaluate(1.0) -def wrap_error(module, lineno): - with pytest.raises(TVMScriptParserError) as error: - mod = module() - assert error is not None - e = error.value - print(e) - msg = str(e).split("\n")[-1].split(":", maxsplit=1)[0].strip().split(" ")[-1].strip() - assert int(msg) == lineno +def test_invalid_block_function(): + check_error(invalid_block_function, 4) + + +def return_not_allowed(a: ty.handle) -> None: + return tir.evaluate(0) # error + + +def test_return_not_allowed(): + check_error(return_not_allowed, 2) + + +def tir_assert(a: ty.handle) -> None: + tir.Assert(0, "") # error + + +def test_tir_assert(): + check_error(tir_assert, 2) + + +def no_body(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + tir.realize(A, "") # error + + +def test_no_body(): + check_error(no_body, 3) + + +def check_error(module, rel_lineno): + # Override the default renderer to accumulate errors + _, start_line = inspect.getsourcelines(module) + lineno = start_line + rel_lineno - 1 + errors = [] + + def render(e): + for d in e.diagnostics: + errors.append(d) + + override_renderer(render) + # The diagnostic context throws an exception when it gets an error + try: + mod = from_source(module) + except tvm.error.DiagnosticError as e: + pass + assert len(errors) == 1, errors + for d in errors: + assert ( + d.span.line == lineno + ), f"Expected error to be on line {lineno}, but it was on {d.span.line}" if __name__ == "__main__": - wrap_error(Module1, 29) - wrap_error(Module2, 39) - wrap_error(Module3, 50) - wrap_error(Module4, 60) - wrap_error(Module5, 70) - wrap_error(Module6, 77) - wrap_error(Module7, 84) - wrap_error(Module8, 91) - wrap_error(Module9, 99) - wrap_error(Module10, 109) + test_buffer_bind() + test_range_missing_args() + test_undefined_buffer() + test_unsupported_stmt() + test_unsupported_function_call() + test_missing_type_annotation() + test_invalid_expr_stmt() + test_invalid_for_function() + test_invalid_block_function() + test_return_not_allowed() + test_tir_assert() + test_no_body() diff --git a/tests/scripts/task_ci_python_setup.sh b/tests/scripts/task_ci_python_setup.sh index 6463142a28c0..fe88ac650cc8 100755 --- a/tests/scripts/task_ci_python_setup.sh +++ b/tests/scripts/task_ci_python_setup.sh @@ -30,4 +30,4 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.1.2 +python3 -m pip install --user tlcpack-sphinx-addon==0.1.2 synr==0.2.1