From cd45ed1a277411ee636205f2f3c0581b9c50b109 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Wed, 10 Mar 2021 22:24:14 +0800 Subject: [PATCH] [TensorIR] TVMScript Parser/Printer (#317) [TensorIR] TVMScript Parser/Printer Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Junru Shao Co-authored-by: Tianqi Chen Co-authored-by: Ruihang Lai Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Wuwei Lin --- include/tvm/tir/analysis.h | 15 + python/tvm/script/context_maintainer.py | 115 ++++- python/tvm/script/intrin.py | 20 +- python/tvm/script/node.py | 150 ++++++ python/tvm/script/parser.py | 173 ++++--- python/tvm/script/registry.py | 20 +- python/tvm/script/scope_handler.py | 444 +++++++++++++++--- python/tvm/script/special_stmt.py | 360 +++++++++++++- python/tvm/script/utils.py | 95 +++- python/tvm/tir/analysis/analysis.py | 23 + src/printer/tir_text_printer.cc | 3 +- src/printer/tvmscript_printer.cc | 232 ++++++++- .../analysis/block_access_region_detector.cc | 245 ++++++++++ src/tir/ir/script/script_complete.cc | 122 +++++ ...st_tir_analysis_get_block_access_region.py | 57 +++ .../unittest/test_tvmscript_error_report.py | 205 ++++++++ .../unittest/test_tvmscript_roundtrip.py | 170 +++++++ tests/scripts/task_ci_python_setup.sh | 2 +- 18 files changed, 2256 insertions(+), 195 deletions(-) create mode 100644 python/tvm/script/node.py create mode 100644 src/tir/analysis/block_access_region_detector.cc create mode 100644 src/tir/ir/script/script_complete.cc create mode 100644 tests/python/unittest/test_tir_analysis_get_block_access_region.py diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index e5b2c2b6957c..a4149548da1a 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -141,6 +141,21 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); */ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); +/*! 
+ * \brief Auto detect the block read/write region according to body stmt + * It will detect the read/write region as an array in order of appearance in AST + * \param block The block to be detected + * \param buffer_var_map The outside buffers which may be accessed the block. + * It is a map from buffer var to the buffer. + * \return Array of access regions. + * There are three arrays of BufferRegion: + * - first: read regions + * - second: write regions + * - third: opaque regions + */ +Array> GetBlockAccessRegion(const Block& block, + const Map& buffer_var_map); + // Pass variants of verification analysis // directly throws RuntimeError when verification fails. namespace transform { diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 955266c4a3e0..48fd331ec563 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -16,59 +16,138 @@ # under the License. """TVM Script Context Maintainer for TIR""" -from tvm.te import schedule +from typing import List, Mapping, Union, Optional, Dict, Callable +import synr + + +import tvm +from tvm.ir import Span +from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion +from tvm.runtime import Object +from .node import BufferSlice + + +class BlockInfo: + """Information for block and block_realize signature""" + + alloc_buffers: List[Buffer] + match_buffers: List[MatchBufferRegion] + iter_bindings: Mapping[Var, PrimExpr] + reads: Optional[List[BufferSlice]] + writes: Optional[List[BufferSlice]] + annotations: Optional[Mapping[str, Object]] + predicate: Optional[PrimExpr] + init: Optional[Stmt] + + def __init__(self): + self.alloc_buffers = [] + self.match_buffers = [] + self.iter_bindings = {} + self.reads = None + self.writes = None + self.annotations = None + self.predicate = None + self.init = None class ContextMaintainer: """Maintain all the necessary context info""" - def __init__(self, parser): + # scope context + 
node_stack: List[List[synr.ast.Node]] + block_info_stack: List[BlockInfo] + loop_stack: List[List[Var]] + symbols: List[Dict[str, Union[Var, Buffer]]] + # function context + func_params: List[Var] + func_buffer_map: Mapping[Var, Buffer] + func_dict_attr: Mapping[str, Object] + func_var_env_dict: Mapping[Var, str] + # parser and analyzer + _report_error: Callable[[str, Union[Span, synr.ast.Span]], None] + analyzer: tvm.arith.Analyzer + + def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]): # scope context self.node_stack = [] # AST nodes of scopes + self.block_info_stack = [] # Block info of scopes + self.loop_stack = [] # stack of loop vars self.symbols = [] # symbols of scopes # function context self.func_params = [] # parameter list of function self.func_buffer_map = {} # buffer_map of function self.func_dict_attr = {} # func_attr of function self.func_var_env_dict = {} # map from var to env_name - # parser - self.parser = parser + # parser and analyzer + self._report_error = _report_error + self.analyzer = tvm.arith.Analyzer() - def pop_scope(self): - """Pop the inner most scope""" - self.symbols.pop() - self.node_stack.pop() + def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None): + """Creating a new scope - def new_scope(self, nodes=None): - """Creating a new scope""" + Parameters + ---------- + nodes : Optional[List[synr.ast.Node]] + The synr AST nodes in new scope + """ if nodes is None: nodes = [] self.node_stack.append(list(reversed(nodes))) self.symbols.append(dict()) - def update_symbol(self, name, symbol): + def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None): + """Creating a new block scope, the function will call `enter_scope` implicitly + + Parameters + ---------- + nodes : Optional[List[synr.ast.Node]] + The synr AST nodes in new scope + """ + self.enter_scope(nodes) + # Create a new loop stack for the new block + self.loop_stack.append([]) + # Create a new BlockInfo for the new 
block + self.block_info_stack.append(BlockInfo()) + + def exit_scope(self): + """Pop the inner most scope""" + self.symbols.pop() + self.node_stack.pop() + + def exit_block_scope(self): + """Pop the inner most block scope, the function will call `exit_scope` implicitly""" + self.exit_scope() + # Pop loop stack + self.loop_stack.pop() + # Pop block_info + self.block_info_stack.pop() + + def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node): """Append a symbol into current scope""" - if isinstance(symbol, schedule.Buffer): + if isinstance(symbol, Buffer): if name in self.symbols[0]: - self.parser.report_error("Duplicate Buffer name") + self.report_error("Duplicate Buffer name: " + symbol.name, node.span) self.symbols[0][name] = symbol else: self.symbols[-1][name] = symbol - def remove_symbol(self, name): + def remove_symbol(self, name: str): """Remove a symbol""" for symbols in reversed(self.symbols): if name in symbols: symbols.pop(name) return - raise RuntimeError("Internal error of tvm script parser: no symbol named" + name) + raise RuntimeError("Internal error of tvm script parser: no symbol named " + name) - def lookup_symbol(self, name): + def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]: """Look up symbol by name""" for symbols in reversed(self.symbols): if name in symbols: return symbols[name] return None - def report_error(self, message, span): - self.parser.report_error(message, span) + def report_error(self, message: str, span: Union[Span, synr.ast.Span]): + self._report_error(message, span) + + def current_block_scope(self) -> BlockInfo: + return self.block_info_stack[-1] diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py index 053cd4a45846..48f50a2da442 100644 --- a/python/tvm/script/intrin.py +++ b/python/tvm/script/intrin.py @@ -16,9 +16,11 @@ # under the License. 
"""TVM Script Parser Intrinsic Classes""" # pylint: disable=redefined-builtin, relative-beyond-top-level +from typing import List, Any + import tvm.tir from .registry import register -from .utils import get_param_list, from_synr_span +from .utils import get_param_list, tvm_span_from_synr class Intrin: @@ -29,8 +31,8 @@ def __init__(self, intrin, stmt=False): def signature(self): return "tir." + self.intrin.__name__, get_param_list(self.intrin) - def handle(self, arg_list, span): - return self.intrin(*arg_list, span=from_synr_span(span)) + def handle(self, arg_list: List[Any], span: tvm.ir.Span): + return self.intrin(*arg_list, span=tvm_span_from_synr(span)) @register @@ -98,6 +100,16 @@ def float64(imm, span): return tvm.tir.Cast("float64", imm, span) +@register +def min_value(dtype, span): + return tvm.tir.min_value(dtype, span) + + +@register +def max_value(dtype, span): + return tvm.tir.max_value(dtype, span) + + @register def floordiv(x, y, span): return tvm.tir.floordiv(x, y, span) @@ -145,7 +157,7 @@ def get_axis(begin, end, iter_type, span): block_var_dom = tvm.ir.Range.from_min_extent(begin, extent) iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4} - return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span) + return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span) @register diff --git a/python/tvm/script/node.py b/python/tvm/script/node.py new file mode 100644 index 000000000000..27b1203d3e3e --- /dev/null +++ b/python/tvm/script/node.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin +"""TVM Script nodes.""" + +from typing import Optional, Union, List, Callable +import synr + +from tvm.runtime import ObjectGeneric +from tvm.tir import PrimExpr, Buffer, BufferLoad +from tvm.ir import Span + + +class Slice: + """A helper class to present slice information for BufferSlice + + Parameters + ---------- + start : Union[PrimExpr, int] + The start index. + + stop : Optional[Union[PrimExpr, int]] + The stop index, None means the Slice is an element-wise index + + span : Optional[Span] + The location of the slice in the source. + """ + + start: Union[PrimExpr, int] + stop: Optional[Union[PrimExpr, int]] + span: Optional[Span] + + def __init__( + self, + start: Union[PrimExpr, int], + stop: Optional[Union[PrimExpr, int]] = None, + span: Optional[Span] = None, + ): + self.start = start + self.stop = stop + self.span = span + + +class BufferSlice(ObjectGeneric): + """A generic object for representing general buffer access. Following cases are supported: + - element wise access buffer[i, j], which can be convert to BufferLoad if necessary + - slice access buffer[i: i + 1, j : j + 2] + - union of element and slice buffer[i, j: j + 2] + + This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize + + Parameters + ---------- + buffer : Buffer + The buffer. + + indices : List[Union[Slice, PrimExpr, int]] + The access indexes can be slice, PrimExpr or int. 
+ + report_error: Callable[[str, Union[Span, synr.ast.Span]], None] + The error report func + + span : Optional[Span] + The location of the buffer access in the source. + """ + + buffer: Buffer + slices: List[Slice] + report_error: Callable[[str, Union[Span, synr.ast.Span]], None] + span: Optional[Span] + + def __init__( + self, + buffer: Buffer, + indices: List[Union[Slice, PrimExpr, int]], + report_error: Callable[[str, Union[Span, synr.ast.Span]], None], + span: Optional[Span] = None, + ): + def check_index(index: Union[int, PrimExpr]): + """ Check input index is non-negative integer or PrimExpr""" + if isinstance(index, int): + if index < 0: + report_error("Negative index is not allowed during buffer access", span) + elif isinstance(index, PrimExpr): + if index.dtype != "int32": + report_error( + "index expects an int32 type PrimExpr but gets " + str(index.dtype), + index.span, + ) + else: + report_error( + "Unsupported index type, expects int or tvm.tir.PrimExpr, but gets " + + str(type(index)), + span, + ) + + slices: List[Slice] = [] + for index in indices: + if isinstance(index, Slice): + check_index(index.start) + check_index(index.stop) + slices.append(index) + elif isinstance(index, (PrimExpr, int)): + check_index(index) + slices.append(Slice(index)) + else: + report_error( + "Unsupported index type for BufferSlice, " + + "expects int, tvm.tir.PrimExpr, tvm.tir.Slice, but gets " + + str(type(index)), + span, + ) + + self.buffer = buffer + self.slices = slices + self.report_error = report_error + self.span = span + + def __str__(self): + regions: List[str] = [] + for s in self.slices: + if s.stop is None: + regions.append(str(s.start)) + else: + regions.append(str(s.start) + ": " + str(s.stop)) + + return self.buffer.name + "[" + ", ".join(regions) + "]" + + def asobject(self) -> BufferLoad: + """Convert object.""" + for s in self.slices: + if s.stop is not None: + self.report_error("BufferLoad only accepts elementwise access", self.span) + + indices = 
[s.start for s in self.slices] + return BufferLoad(self.buffer, indices, span=self.span) diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 33b0bab0d7e7..3acbf04eb700 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -24,6 +24,7 @@ import json import operator import inspect +from typing import Union from synr import ast, Transformer, to_ast import tvm @@ -32,6 +33,7 @@ from tvm.ir import GlobalVar from . import context_maintainer, ty +from .context_maintainer import BlockInfo from .meta_unparser import MetaUnparser from .registry import Registry from .intrin import Intrin @@ -39,7 +41,8 @@ from .scope_handler import ScopeHandler, WithScopeHandler, ForScopeHandler from . import _ffi_api from .diagnostics import TVMDiagnosticCtx -from .utils import from_synr_span +from .utils import tvm_span_from_synr, synr_span_from_tvm, call_with_error_reporting +from .node import Slice, BufferSlice class CallArgumentReader(object): @@ -158,7 +161,7 @@ def __init__(self, base_lienno): def init_function_parsing_env(self): """Initialize function parsing environment""" - self.context = context_maintainer.ContextMaintainer(self) # scope emitter + self.context = context_maintainer.ContextMaintainer(self.report_error) # scope emitter def init_meta(self, meta_dict): if meta_dict is not None: @@ -182,7 +185,7 @@ def transform(self, node): return transform_res - def report_error(self, message, span): + def report_error(self, message: str, span: Union[ast.Span, tvm.ir.Span]): """Report an error occuring at a location. This just dispatches to synr's DiagnosticContext. 
@@ -191,9 +194,11 @@ def report_error(self, message, span): ---------- message : str Error message - span : synr.ast.Span + span : synr.ast.Span or tvm.ir.Span Location of the error """ + if isinstance(span, tvm.ir.Span): + span = synr_span_from_tvm(span) self.error(message, span) def parse_body(self, parent): @@ -221,7 +226,7 @@ def parse_body(self, parent): ) else: return ( - tvm.tir.SeqStmt(body, from_synr_span(ast.Span.union(spans))) + tvm.tir.SeqStmt(body, tvm_span_from_synr(ast.Span.union(spans))) if len(body) > 1 else body[0] ) @@ -270,6 +275,13 @@ def parse_arg_list(self, func, node_call): internal_args.append(reader.get_kwarg(i + 1 + len(pos_only), arg_name, default=default)) if varargs is not None: internal_args.extend(reader.get_varargs(len(pos_only) + len(kwargs) + 1)) + elif len(args) + len(kw_args) > len(pos_only) + len(kwargs): + self.report_error( + "Arguments mismatched. " + + f"Expects {len(pos_only) + len(kwargs)} args but gets " + + f"{len(args) + len(kw_args)}", + node_call.span, + ) return internal_args def parse_type(self, type_node, parent): @@ -401,25 +413,47 @@ def my_function(x: ty.handle): # 1. 
Argument types """ self.init_function_parsing_env() - self.context.new_scope(nodes=node.body.stmts) + self.context.enter_scope(nodes=node.body.stmts) # add parameters of function for arg in node.params: arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg)) - self.context.update_symbol(arg.name, arg_var) + self.context.update_symbol(arg.name, arg_var, node) self.context.func_params.append(arg_var) - # fetch the body and return a tir.PrimFunc + # New Scope : Implicit root block + self.context.enter_block_scope(nodes=node.body.stmts) + + # fetch the body of root block + body = self.parse_body(node.body) + # Emit Scope : Implicit root block + root_info: BlockInfo = self.context.current_block_scope() + self.context.exit_block_scope() + + # return a tir.PrimFunc + dict_attr = self.context.func_dict_attr func = tvm.tir.PrimFunc( self.context.func_params, - self.parse_body(node.body), + body, ret_type=self.parse_type(node.ret_type, node), buffer_map=self.context.func_buffer_map, - attrs=tvm.ir.make_node("DictAttrs", **self.context.func_dict_attr), - span=from_synr_span(node.span), + attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None, + span=tvm_span_from_synr(node.span), + ) + + # Fix the PrimFunc + # 1. generate root block if necessary + # 2. 
generate surrounding loops for blocks if necessary + + func = call_with_error_reporting( + self.report_error, + node.span, + _ffi_api.Complete, + func, + root_info.alloc_buffers, ) - self.context.pop_scope() + self.context.exit_scope() return func def transform_Assign(self, node): @@ -470,12 +504,12 @@ def transform_Assign(self, node): var = tvm.te.var( node.lhs.id.name, self.parse_type(node.ty, node.lhs), - span=from_synr_span(node.lhs.span), + span=tvm_span_from_synr(node.lhs.span), ) - self.context.update_symbol(var.name, var) + self.context.update_symbol(var.name, var, node) body = self.parse_body(node) self.context.remove_symbol(var.name) - return tvm.tir.LetStmt(var, value, body, span=from_synr_span(node.span)) + return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span)) self.report_error("Unsupported Assign stmt", node.span) @@ -484,28 +518,28 @@ def transform_SubscriptAssign(self, node): symbol = self.transform(node.params[0]) indexes = self.transform(node.params[1]) rhs = self.transform(node.params[2]) - rhs_span = from_synr_span(node.params[2].span) + rhs_span = tvm_span_from_synr(node.params[2].span) if isinstance(symbol, tvm.tir.Buffer): # BufferStore return tvm.tir.BufferStore( symbol, tvm.runtime.convert(rhs, span=rhs_span), indexes, - span=from_synr_span(node.span), + span=tvm_span_from_synr(node.span), ) else: if len(indexes) != 1: self.report_error( f"Store is only allowed with one index, but {len(indexes)} were provided.", - Span.union([x.span for x in indexes]), + tvm.ir.Span.union([x.span for x in indexes]), ) # Store return tvm.tir.Store( symbol, tvm.runtime.convert(rhs, span=rhs_span), indexes[0], - tvm.runtime.convert(True, span=from_synr_span(node.span)), - span=from_synr_span(node.span), + tvm.runtime.convert(True, span=tvm_span_from_synr(node.span)), + span=tvm_span_from_synr(node.span), ) def transform_Assert(self, node): @@ -520,7 +554,7 @@ def transform_Assert(self, node): message = self.transform(node.msg) body = 
self.parse_body(node) return tvm.tir.AssertStmt( - condition, tvm.runtime.convert(message), body, span=from_synr_span(node.span) + condition, tvm.runtime.convert(message), body, span=tvm_span_from_synr(node.span) ) def transform_For(self, node): @@ -529,7 +563,8 @@ def transform_For(self, node): For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) By now 1 pattern of For is supported: 1. for scope handler - for name in tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll() + for name in tir.serial()/tir.parallel()/tir.vectorized()/tir.unroll()/tir.range()/ + tir.grid()/tir.thread_binding() """ if not isinstance(node.rhs, ast.Call): @@ -543,14 +578,14 @@ def transform_For(self, node): old_lineno, old_col_offset = self.current_lineno, self.current_col_offset self.current_lineno = node.span.start_line self.current_col_offset = node.span.start_column - self.context.new_scope(nodes=node.body.stmts) + self.context.enter_scope(nodes=node.body.stmts) # for scope handler process the scope arg_list = self.parse_arg_list(func, node.rhs) func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) func.body = self.parse_body(node) res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) # exit the scope - self.context.pop_scope() + self.context.exit_scope() self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res @@ -561,9 +596,9 @@ def transform_With(self, node): withitem = (expr context_expr, expr? optional_vars) By now 2 patterns of With is supported: 1. with scope handler with symbol def - with tir.allocate() as targets: + with tir.block(*axes)/tir.allocate() as targets: 2. 
with scope handler without symbol def - with tir.let()/tir.Assert()/tir.attr()//tir.realize() + with tir.let()/tir.Assert()/tir.attr()/tir.realize() """ if not isinstance(node.rhs, ast.Call): @@ -582,14 +617,14 @@ def transform_With(self, node): old_lineno, old_col_offset = self.current_lineno, self.current_col_offset self.current_lineno = node.body.span.start_line self.current_col_offset = node.body.span.start_column - self.context.new_scope(nodes=node.body.stmts) + self.context.enter_block_scope(nodes=node.body.stmts) # with scope handler process the scope arg_list = self.parse_arg_list(func, node.rhs) func.enter_scope(node, self.context, arg_list, node.rhs.func_name.span) func.body = self.parse_body(node) res = func.exit_scope(node, self.context, arg_list, node.rhs.func_name.span) # exit the scope - self.context.pop_scope() + self.context.exit_block_scope() self.current_lineno, self.current_col_offset = old_lineno, old_col_offset return res @@ -601,19 +636,21 @@ def transform_If(self, node): condition = self.transform(node.condition) # then body - self.context.new_scope(nodes=node.true.stmts) + self.context.enter_scope(nodes=node.true.stmts) then_body = self.parse_body(node) - self.context.pop_scope() + self.context.exit_scope() # else body if len(node.false.stmts) > 0: - self.context.new_scope(nodes=node.false.stmts) + self.context.enter_scope(nodes=node.false.stmts) else_body = self.parse_body(node) - self.context.pop_scope() + self.context.exit_scope() else: else_body = None - return tvm.tir.IfThenElse(condition, then_body, else_body, span=from_synr_span(node.span)) + return tvm.tir.IfThenElse( + condition, then_body, else_body, span=tvm_span_from_synr(node.span) + ) def transform_Call(self, node): """Call visitor @@ -633,18 +670,26 @@ def transform_Call(self, node): lhs = self.transform(node.params[0]) rhs = self.transform(node.params[1]) return self._binop_maker[node.func_name.name]( - lhs, rhs, span=from_synr_span(node.span) + lhs, rhs, 
span=tvm_span_from_synr(node.span) ) if node.func_name.name in self._unaryop_maker: rhs = self.transform(node.params[0]) - return self._unaryop_maker[node.func_name.name](rhs, span=from_synr_span(node.span)) + return self._unaryop_maker[node.func_name.name]( + rhs, span=tvm_span_from_synr(node.span) + ) self.report_error(f"Unsupported operator {node.func_name.name}.", node.func_name.span) else: func = self.transform(node.func_name) if isinstance(func, Intrin) and not func.stmt: # pattern 1 arg_list = self.parse_arg_list(func, node) - return func.handle(arg_list, node.func_name.span) + return call_with_error_reporting( + self.report_error, + node.func_name.span, + func.handle, + arg_list, + node.func_name.span, + ) else: args = [self.transform(arg) for arg in node.params] kw_args = { @@ -653,7 +698,7 @@ def transform_Call(self, node): if isinstance(func, tvm.tir.op.Op): # pattern 2 return tvm.tir.Call( - kw_args["dtype"], func, args, span=from_synr_span(node.span) + kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span) ) elif callable(func): # pattern 3 @@ -700,7 +745,13 @@ def f(): ) if isinstance(func, Intrin) and func.stmt: - return func.handle(arg_list, node.call.func_name.span) + return call_with_error_reporting( + self.report_error, + node.call.func_name.span, + func.handle, + arg_list, + node.call.func_name.span, + ) elif isinstance(func, WithScopeHandler) and func.concise_scope and not func.def_symbol: func.enter_scope(node, self.context, arg_list, node.call.func_name.span) func.body = self.parse_body(node) @@ -716,11 +767,7 @@ def transform_Slice(self, node): end = self.transform(node.end) if not (isinstance(node.step, ast.Constant) and node.step.value == 1): self.report_error("Only step size 1 is supported for slices.", node.step.span) - extent = end - start - if isinstance(extent, tvm.tir.PrimExpr): - ana = tvm.arith.Analyzer() - extent = ana.simplify(extent) - return tvm.ir.Range.from_min_extent(start, extent, span=from_synr_span(node.span)) + 
return Slice(start, end) def transform_Subscript(self, node): """Array access visitor. @@ -728,7 +775,7 @@ def transform_Subscript(self, node): By now only 2 types of Subscript are supported: 1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore) Var[index] Buffer element access() - 2. meta[type_key][index], Meta info access + 2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...])) """ symbol = self.transform(node.params[0]) @@ -736,19 +783,26 @@ def transform_Subscript(self, node): self.report_error(f"Variable {node.value.id} is not defined.", node.params[0].span) indexes = [self.transform(x) for x in node.params[1].values] - if isinstance(indexes[0], tvm.ir.Range): - return symbol, indexes - if isinstance(symbol, tvm.tir.expr.Var): - return tvm.tir.Load("float32", symbol, indexes, True, span=from_synr_span(node.span)) - if isinstance(symbol, tvm.tir.Buffer): - return tvm.tir.BufferLoad(symbol, indexes, span=from_synr_span(node.span)) - - self.report_error( - f"Cannot subscript from a {type(symbol).__name__}. Only variables and " - "buffers are supported.", - node.params[0].span, - ) + for index in indexes: + if not isinstance(index, (tvm.tir.PrimExpr, int)): + self.report_error( + "Buffer load indexes expect int or PrimExpr, but get " + type(index), + node.span, + ) + return tvm.tir.Load( + "float32", symbol, indexes, True, span=tvm_span_from_synr(node.span) + ) + elif isinstance(symbol, tvm.tir.Buffer): + return BufferSlice( + symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span) + ) + else: + self.report_error( + f"Cannot subscript from a {type(symbol).__name__}. Only variables and " + "buffers are supported.", + node.params[0].span, + ) def transform_Attr(self, node): """Visitor for field access of the form `x.y`. @@ -756,7 +810,7 @@ def transform_Attr(self, node): This visitor is used to lookup function and symbol names. We have two cases to handle here: 1. 
If we have a statement of the form `tir.something`, then we lookup - `tir.somthing` in the `Registry`. If the function is not in the + `tir.something` in the `Registry`. If the function is not in the registry, then we try to find a `tvm.ir.op.Op` with the same name. 2. All other names `tvm.something` are lookup up in this current python namespace. @@ -875,7 +929,7 @@ def transform_Constant(self, node): Constant values include `None`, `"strings"`, `2` (integers), `4.2` (floats), and `true` (booleans). """ - return tvm.runtime.convert(node.value, span=from_synr_span(node.span)) + return tvm.runtime.convert(node.value, span=tvm_span_from_synr(node.span)) def transform_TypeConstant(self, node): """Constant value visitor for types. @@ -902,8 +956,7 @@ def from_source(src): ---------- src : [str, function, class] Pruned source of original script - func_lineno : Optional[int] - The line number of the first line of the script to be parsed + Returns ------- functions : PrimFunc or IRModule diff --git a/python/tvm/script/registry.py b/python/tvm/script/registry.py index 389570115935..245cc01051d5 100644 --- a/python/tvm/script/registry.py +++ b/python/tvm/script/registry.py @@ -16,7 +16,8 @@ # under the License. 
"""TVM Script Parser Function Registry """ # pylint: disable=inconsistent-return-statements, relative-beyond-top-level, import-outside-toplevel -import inspect +import types +from typing import Union, Callable, Dict, Optional, Any class Registry(object): @@ -24,10 +25,10 @@ class Registry(object): All these maps are static """ - registrations = dict() + registrations: Dict[str, type] = dict() @staticmethod - def lookup(name): + def lookup(name: str) -> Optional[Any]: if name in Registry.registrations: # every time we create a new handler # since we may want to keep some local info inside it @@ -35,12 +36,14 @@ def lookup(name): return None -def register(inputs): +def register(inputs: Union[Callable, type]) -> type: """Register Intrin/ScopeHandler/SpecialStmt""" - if inspect.isfunction(inputs): + registration: type + if isinstance(inputs, types.FunctionType): + # is function from .intrin import Intrin - def create_new_intrin(func): + def create_new_intrin(func) -> type: class NewIntrin(Intrin): def __init__(self): super().__init__(func) @@ -48,11 +51,12 @@ def __init__(self): return NewIntrin registration = create_new_intrin(inputs) - elif inspect.isclass(inputs): + elif isinstance(inputs, type): + # is class registration = inputs else: raise ValueError() - key = registration().signature()[0] + key: str = registration().signature()[0] Registry.registrations[key] = registration return registration diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index 9449cbdc156c..8c6837895056 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -16,32 +16,59 @@ # under the License. 
"""TVM Script Parser Scope Handler Classes""" # pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level +from typing import Tuple, Any, Callable, Optional, List, Union, Mapping +import synr from synr import ast import tvm.tir -from .utils import get_param_list, from_synr_span +from tvm.runtime import Object +from tvm.ir import Span, Range +from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion + +from .context_maintainer import ContextMaintainer +from .utils import ( + get_param_list, + tvm_span_from_synr, + buffer_slice_to_region, + call_with_error_reporting, +) from .registry import register +from .node import BufferSlice class ScopeHandler: """Base class for all scope handlers""" - def __init__(self, func): - self.func = func - self.body = None - self.node = None - self.context = None + def __init__(self, func: Callable): + self.func: Callable = func + self.body: Optional[Stmt] = None + self.node: Optional[synr.ast.Node] = None + self.context: Optional[ContextMaintainer] = None - def signature(self): + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: return "tir." 
+ self.func.__name__, get_param_list(self.func) - def enter_scope(self, node, context, arg_list, span): + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): pass - def exit_scope(self, node, context, arg_list, span): + def exit_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): self.node = node self.context = context - return self.func(*arg_list, span=from_synr_span(span)) + return call_with_error_reporting( + context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span) + ) class WithScopeHandler(ScopeHandler): @@ -57,22 +84,19 @@ def get_optional_var_names(node, context): """Get list of names from ast.With's optional_vars""" assert isinstance(node, ast.With) - var_names = None - if isinstance(node.items[0].optional_vars, ast.Name): - var_names = [node.items[0].optional_vars.id] - elif isinstance(node.items[0].optional_vars, (ast.List, ast.Tuple)): - for var in node.items[0].optional_vars.elts: - if not isinstance(var, ast.Name): - context.report_error("Invalid optional var definition") - var_names = [var.id for var in node.items[0].optional_vars.elts] + if isinstance(node.lhs, list): + for var in node.lhs: + if not isinstance(var, ast.Var): + context.report_error("Invalid optional var definition", node.span) + var_names = [var.id.name for var in node.lhs] else: - context.report_error("Invalid optional var definition") + context.report_error("Invalid optional var definition", node.span) return var_names @register class Allocate(WithScopeHandler): - """ With scope handler tir.alloc_with_scope(var, extents, dtype, scope, condition) """ + """ With scope handler tir.allocate(extents, dtype, scope, condition) """ def __init__(self): def allocate(extents, dtype, scope, condition=True, span=None): @@ -86,7 +110,13 @@ def allocate(extents, dtype, scope, condition=True, span=None): super().__init__(allocate, 
concise_scope=True, def_symbol=True) self.buffer_var = None - def enter_scope(self, node, context, arg_list, span): + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): # define buffer vars in symbol table if isinstance(node, ast.With): names = WithScopeHandler.get_optional_var_names(node, context) @@ -98,13 +128,13 @@ def enter_scope(self, node, context, arg_list, span): else: raise Exception("Internal Bug") - def setup_buffer_var(extents, dtype, scope, condition=True, span=None): + def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): """Setup buffer var for a given type.""" buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) - setup_buffer_var(*arg_list, span=from_synr_span(node.lhs.id.span)) - context.update_symbol(name, self.buffer_var) + setup_buffer_var(*arg_list, span=tvm_span_from_synr(node.lhs.id.span)) + context.update_symbol(name, self.buffer_var, node) @register @@ -115,10 +145,10 @@ def __init__(self): def launch_thread(env_var, extent, span): extent = tvm.runtime.convert(extent, span=span) return tvm.tir.AttrStmt( - tvm.tir.IterVar( + IterVar( None, env_var, - getattr(tvm.tir.IterVar, "ThreadIndex"), + getattr(IterVar, "ThreadIndex"), self.context.func_var_env_dict[env_var], span=span, ), @@ -136,8 +166,19 @@ class Realize(WithScopeHandler): """ With scope handler tir.realize(buffer_bounds, scope, condition) """ def __init__(self): - def realize(buffer_bounds, scope, condition=True, span=None): - buffer, bounds = buffer_bounds + def realize( + buffer_slice: BufferSlice, scope: str, condition: bool = True, span: bool = None + ): + assert self.context + buffer: Buffer = buffer_slice.buffer + bounds: List[Range] = [] + for s in buffer_slice.slices: + min: Union[PrimExpr, int] = s.start + extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start + if isinstance(extent, 
PrimExpr): + extent = self.context.analyzer.simplify(extent) + bounds.append(Range.from_min_extent(min, extent, span=s.span)) + scope = tvm.runtime.convert(scope, span=span) return tvm.tir.AttrStmt( buffer, @@ -185,92 +226,365 @@ def let(var, value, span): super().__init__(let, concise_scope=False, def_symbol=False) +@register +class Block(WithScopeHandler): + """ With scope handler tir.block(extents, name) as iter_vars""" + + def __init__(self): + def block(axes=None, name_hint: str = "", span: Optional[Span] = None): + assert self.node + assert self.context + assert self.body + block_info = self.context.block_info_stack[-1] + if axes is None: + axes = [] + if len(axes) != len(self.block_vars): + self.context.report_error( + "Inconsistent number of block vars, " + + f"gets {len(axes)} axes but {len(self.block_vars)} block vars.", + self.node.span, + ) + block_iters: List[IterVar] = [] + for i, axis in enumerate(axes): + axis = tvm.runtime.convert(axis) + if isinstance(axis, tvm.tir.PrimExpr): + block_var_dom = Range.from_min_extent(0, axis) + block_iters.append(IterVar(block_var_dom, self.block_vars[i], 0)) + elif isinstance(axis, Range): + block_iters.append(IterVar(axis, self.block_vars[i], 0)) + elif isinstance(axis, IterVar): + block_iters.append(IterVar(axis.dom, self.block_vars[i], axis.iter_type)) + else: + self.context.report_error( + "Invalid argument of tir.block(), " + + f"expects PrimExpr, Range or IterVar, but gets {type(axis)}", + self.node.span, + ) + + # create block read/write regions + + reads: List[BufferRegion] = ( + [buffer_slice_to_region(read) for read in block_info.reads] + if block_info.reads + else [] + ) + writes: List[BufferRegion] = ( + [buffer_slice_to_region(write) for write in block_info.writes] + if block_info.writes + else [] + ) + inner = tvm.tir.Block( + block_iters, + reads, + writes, + name_hint, + self.body, + block_info.init, + block_info.alloc_buffers, + block_info.match_buffers, + block_info.annotations, + span, + ) + # 
create block var iter binding + values: List[PrimExpr] + if not block_info.iter_bindings: + values = self.context.loop_stack[-2].copy() + if len(values) == 0: + values = [tvm.tir.const(float("nan"), dtype="float32")] * len(block_iters) + elif len(values) != len(block_iters): + self.context.report_error( + "Autocomplete block iter var binding expects larger number of loops", + self.node.span, + ) + else: + for block_var in self.block_vars: + if block_var not in block_info.iter_bindings: + self.context.report_error( + "Missing block iter var binding for " + block_var.name, + self.node.span, + ) + values = [block_info.iter_bindings[block_var] for block_var in self.block_vars] + predicate = ( + tvm.tir.const(True, "bool") + if block_info.predicate is None + else block_info.predicate + ) + body = tvm.tir.BlockRealize(values, predicate, inner, span) + return body + + super().__init__(func=block, concise_scope=False, def_symbol=True) + self.block_vars = None + + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): + # define block vars + assert isinstance(node, ast.With) + + var_names = WithScopeHandler.get_optional_var_names(node, context) + self.block_vars = [tvm.te.var(name) for name in var_names] + for block_var in self.block_vars: + context.update_symbol(block_var.name, block_var, node) + + +@register +class InitBlock(WithScopeHandler): + """ With scope handler tir.init()""" + + def __init__(self): + def init(span: Span = None): + assert self.context + if self.context.block_info_stack[-2].init is not None: + self.context.report_error("Duplicate init block declaration", span) + self.context.block_info_stack[-2].init = self.body + + super().__init__(func=init, concise_scope=False, def_symbol=True) + + class ForScopeHandler(ScopeHandler): """Base class for all for scope handlers""" def __init__(self, func): super().__init__(func) - self.loop_vars = None - - def enter_scope(self, node, context, 
arg_list, span): + self.loop_vars: Optional[List[Var]] = None + + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): assert isinstance(node, ast.For) loop_var_names = list() spans = list() if isinstance(node.lhs, ast.Var): loop_var_names.append(node.lhs.id.name) - spans.append(from_synr_span(node.lhs.id.span)) - elif isinstance(node.lhs, ast.Tuple): - for elt in node.lhs.values: + spans.append(tvm_span_from_synr(node.lhs.id.span)) + elif isinstance(node.lhs, list): + for elt in node.lhs: if not isinstance(elt, ast.Var): context.report_error("Invalid loop var", elt.span) loop_var_names.append(elt.id.name) - spans.append(from_synr_span(elt.id.span)) + spans.append(tvm_span_from_synr(elt.id.span)) else: - context.report_error("Invalid loop var", node.lhs.span) + context.report_error("Invalid loop var in loop", span) self.loop_vars = [ tvm.te.var(name, dtype="int32", span=span) for name, span in zip(loop_var_names, spans) ] for loop_var in self.loop_vars: - context.update_symbol(loop_var.name, loop_var) + context.update_symbol(loop_var.name, loop_var, node) + context.loop_stack[-1].append(loop_var) + + def exit_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): + assert self.loop_vars + for _ in self.loop_vars: + context.loop_stack[-1].pop() + return super().exit_scope(node, context, arg_list, span) + + def create_loop( + self, + begin: PrimExpr, + end: PrimExpr, + kind: int, + thread_binding: Optional[str] = None, + annotations: Optional[Mapping[str, Object]] = None, + span: Optional[Span] = None, + ) -> tvm.tir.For: + """ + Helper function for creating For in TVM Script parser. + + Parameters + ---------- + begin : PrimExpr + The beginning value. + + end : PrimExpr + The endding value. + + kind : ForKind + The type of the for. + + thread_binding: Optional[str] + The thread this loop binds to. 
+ + annotations : Optional[Mapping[str, Object]] + Additional annotation hints. + + span : Optional[Span] + The location of this for in the source code. + + Returns + ------- + for : For + The constructed For. + """ + assert self.node + assert self.context + assert self.loop_vars + if len(self.loop_vars) != 1: + self.context.report_error( + f"Expect exact only one loop var, but get {self.loop_vars}", self.node.span + ) + extent = end if begin == 0 else self.context.analyzer.simplify(end - begin) + annos: Mapping[str, Object] + if annotations is None: + annos = {} + else: + annos = { + key: tvm.tir.StringImm(val) if isinstance(val, str) else val + for key, val in annotations.items() + } + return tvm.tir.For( + self.loop_vars[0], + begin, + extent, + kind, + self.body, + thread_binding=thread_binding, + annotations=annos, + span=span, + ) @register class Serial(ForScopeHandler): - """ For scope handler tir.serial(begin, end)""" + """ For scope handler tir.serial(begin, end, annotations)""" def __init__(self): - def serial(begin, end, span): - if len(self.loop_vars) != 1: - self.context.report_error("Expect exact 1 loop var", span) - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body, span=span) + def serial( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, + span: Optional[Span] = None, + ): + return self.create_loop(begin, end, 0, annotations=annotations, span=span) super().__init__(serial) @register class Parallel(ForScopeHandler): - """ For scope handler tir.parallel(begin, end)""" + """ For scope handler tir.parallel(begin, end, annotations)""" def __init__(self): - def parallel(begin, end, span): - if len(self.loop_vars) != 1: - self.context.report_error("Expect exact 1 loop var") - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(self.loop_vars[0], begin, 
extent, 1, self.body, span=span) + def parallel( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, + span: Optional[Span] = None, + ): + return self.create_loop(begin, end, 1, annotations=annotations, span=span) super().__init__(parallel) @register class Vectorized(ForScopeHandler): - """ For scope handler tir.vectorized(begin, end)""" + """ For scope handler tir.vectorized(begin, end, annotations)""" def __init__(self): - def vectorized(begin, end, span): - if len(self.loop_vars) != 1: - self.context.report_error("Expect exact 1 loop var") - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body, span=span) + def vectorized( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, + span: Optional[Span] = None, + ): + return self.create_loop(begin, end, 2, annotations=annotations, span=span) super().__init__(vectorized) @register class Unroll(ForScopeHandler): - """ For scope handler tir.unroll(begin, end)""" + """ For scope handler tir.unroll(begin, end, annotations)""" def __init__(self): - def unroll(begin, end, span): - if len(self.loop_vars) != 1: - self.context.report_error("Expect exact 1 loop var") - ana = tvm.arith.Analyzer() - extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body, span=span) + def unroll( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, + span: Optional[Span] = None, + ): + return self.create_loop(begin, end, 3, annotations=annotations, span=span) super().__init__(unroll) + + +@register +class ThreadBinding(ForScopeHandler): + """ For scope handler tir.thread_binding(begin, end, thread, annotations)""" + + def __init__(self): + def thread_binding( + begin: PrimExpr, + end: PrimExpr, + thread: str, + annotations: Optional[Mapping[str, Object]] = None, + 
span: Optional[Span] = None, + ): + thread_iter_var = IterVar(None, None, 1, thread, span=span) + return self.create_loop( + begin, + end, + 4, + thread_binding=thread_iter_var, + annotations=annotations, + span=span, + ) + + super().__init__(thread_binding) + + +@register +class RangeHandler(ForScopeHandler): + """For scope handler range(begin, end, annotations) + Note that tir.range is totally the same as tir.serial + """ + + def __init__(self): + def for_range( + begin: PrimExpr, + end: PrimExpr, + annotations: Optional[Mapping[str, Object]] = None, + span: Optional[Span] = None, + ): + return self.create_loop(begin, end, 0, annotations=annotations, span=span) + + super().__init__(for_range) + + def signature(self): + return "range", get_param_list(self.func) + + +@register +class Grid(ForScopeHandler): + """ For scope handler tir.grid(extents)""" + + def __init__(self): + def grid(*extents: List[PrimExpr], span: Span): + assert self.node + assert self.context + assert self.loop_vars + if len(self.loop_vars) != len(extents): + self.context.report_error( + "Inconsistent number of loop vars and extents", self.node.span + ) + body = self.body + for loop_var, extent in zip(reversed(self.loop_vars), reversed(extents)): + body = tvm.tir.For(loop_var, 0, extent, 0, body, span=span) + return body + + super().__init__(grid) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 62ce1ea19d89..f0a1c800db1f 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -17,30 +17,62 @@ """TVM Script Parser Special Stmt Classes""" # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements # pylint: disable=relative-beyond-top-level +from typing import Callable, List, Optional, Tuple, Any, Mapping, Union + +import synr from synr import ast import tvm.tir +from tvm.runtime import Object from tvm import te -from .utils import get_param_list, from_synr_span +from tvm.ir import Span +from 
tvm.tir import IntImm +from .utils import ( + get_param_list, + tvm_span_from_synr, + buffer_slice_to_region, + call_with_error_reporting, +) from .registry import register +from .context_maintainer import ContextMaintainer +from .node import BufferSlice + + +def convert_to_int(value, arg_name, report_error, span): + if isinstance(value, IntImm): + return value.value + if isinstance(value, int): + return value + report_error( + f"Expects int or IntImm for {arg_name}, but gets {str(type(value))}", + span, + ) class SpecialStmt: """Base class for all Special Stmts""" - def __init__(self, func, def_symbol): - self.func = func - self.def_symbol = def_symbol - self.node = None - self.context = None + def __init__(self, func: Callable, def_symbol: bool): + self.func: Callable = func + self.def_symbol: bool = def_symbol + self.node: Optional[synr.ast.Node] = None + self.context: Optional[ContextMaintainer] = None - def signature(self): + def signature(self) -> Tuple[str, Tuple[list, list, Any]]: return "tir." + self.func.__name__, get_param_list(self.func) - def handle(self, node, context, arg_list, span): + def handle( + self, + node: ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): self.node = node self.context = context - return self.func(*arg_list, span=from_synr_span(span)) + return call_with_error_reporting( + context.report_error, span, self.func, *arg_list, span=tvm_span_from_synr(span) + ) @register @@ -67,17 +99,20 @@ def match_buffer( buffer_type="default", span=None, ): - assert isinstance(self.node, ast.Assign) - + if not isinstance(self.node, ast.Assign): + self.context.report_error( + "Need assign the match_buffer to a buffer, e.g. 
A = match_buffer(...)", + self.node.span, + ) if param not in self.context.func_params: self.context.report_error( "Can not bind non-input param to buffer", self.node.rhs.params[0].span ) if strides is None: strides = [] - align = align.value if not isinstance(align, int) else align - offset_factor = ( - offset_factor.value if not isinstance(offset_factor, int) else offset_factor + align = convert_to_int(align, "align", self.context.report_error, self.node.span) + offset_factor = convert_to_int( + offset_factor, "offset_factor", self.context.report_error, self.node.span ) buffer = tvm.tir.decl_buffer( shape, @@ -93,7 +128,7 @@ def match_buffer( span=span, ) self.context.func_buffer_map[param] = buffer - self.context.update_symbol(self.node.lhs.id.name, buffer) + self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) super().__init__(match_buffer, def_symbol=True) @@ -121,13 +156,17 @@ def buffer_decl( buffer_type="default", span=None, ): - assert isinstance(self.node, ast.Assign) + if not isinstance(self.node, ast.Assign): + self.context.report_error( + "Need assign the buffer_decl to a buffer, e.g. 
A = buffer_decl(...)", + self.node.span, + ) if strides is None: strides = [] - align = align.value if not isinstance(align, int) else align - offset_factor = ( - offset_factor.value if not isinstance(offset_factor, int) else offset_factor + align = convert_to_int(align, "align", self.context.report_error, self.node.span) + offset_factor = convert_to_int( + offset_factor, "offset_factor", self.context.report_error, self.node.span ) buffer = tvm.tir.decl_buffer( shape, @@ -142,12 +181,289 @@ def buffer_decl( buffer_type, span=span, ) - self.context.update_symbol(self.node.lhs.id.name, buffer) + self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) return buffer super().__init__(buffer_decl, def_symbol=True) +@register +class AllocBuffer(SpecialStmt): + """Special function alloc_buffer(shape, dtype, data, strides, elem_offset, scope, align, + offset_factor, buffer_type) + + Example + ------- + .. code-block:: python + + A = tir.alloc_buffer((128, 128), dtype="float32") + + """ + + def __init__(self): + def alloc_buffer( + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="", + align=-1, + offset_factor=0, + buffer_type="default", + span=None, + ): + if not isinstance(self.node, ast.Assign): + self.context.report_error( + "Need assign the alloc_buffer to a buffer, e.g. 
A = alloc_buffer(...)", + self.node.span, + ) + + if strides is None: + strides = [] + align = convert_to_int(align, "align", self.context.report_error, self.node.span) + offset_factor = convert_to_int( + offset_factor, "offset_factor", self.context.report_error, self.node.span + ) + buffer = tvm.tir.decl_buffer( + shape, + dtype, + self.node.lhs.id.name, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + span=span, + ) + self.context.current_block_scope().alloc_buffers.append(buffer) + self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + + super().__init__(alloc_buffer, def_symbol=True) + + +@register +class BlockVarBind(SpecialStmt): + """Special function bind(block_iter, binding_value) + + Example + ------- + .. code-block:: python + + tir.bind(vx, i) + + """ + + def __init__(self): + def bind(iter_var, values, span=None): + block_scope = self.context.current_block_scope() + if iter_var in block_scope.iter_bindings: + self.context.report_error("Duplicate iter_var bindings of " + str(iter_var), span) + block_scope.iter_bindings[iter_var] = values + + super().__init__(bind, def_symbol=False) + + +@register +class BlockReads(SpecialStmt): + """Special function reads([read_buffer_regions]) + + Example + ------- + .. code-block:: python + + tir.reads([A[vi: vi + 4, vk: vk + 4], B[vk: vk + 4, vj]]) + + """ + + def __init__(self): + def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: Span = None): + assert self.context + block_scope = self.context.current_block_scope() + if block_scope.reads is not None: + self.context.report_error( + "Duplicate write region declaration, " + + "previous one is " + + str(", ".join(str(x) for x in block_scope.reads)), + span, + ) + if isinstance(read_regions, list): + pass + elif isinstance(read_regions, BufferSlice): + read_regions = [read_regions] + else: + self.context.report_error( + "Error input type. 
" + + f"Expects BufferSlice or List[BufferSlice], but gets {type(read_regions)}", + span, + ) + block_scope.reads = read_regions + + super().__init__(reads, def_symbol=False) + + +@register +class BlockWrites(SpecialStmt): + """Special function writes([write_buffer_regions]) + + Example + ------- + .. code-block:: python + + tir.writes([C[vi: vi + 4, vj]) + + """ + + def __init__(self): + def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: Span = None): + assert self.context + block_scope = self.context.current_block_scope() + if block_scope.writes is not None: + self.context.report_error( + "Duplicate write region declaration, " + + "previous one is " + + str(", ".join(str(x) for x in block_scope.writes)), + span, + ) + if isinstance(write_region, list): + pass + elif isinstance(write_region, BufferSlice): + write_region = [write_region] + else: + self.context.report_error( + "Error input type. " + + f"Expects BufferSlice or List[BufferSlice], but gets {type(write_region)}", + span, + ) + block_scope.writes = write_region + + super().__init__(writes, def_symbol=False) + + +@register +class BlockAttr(SpecialStmt): + """Special function block_attr({attr_key: attr_value}) + + Example + ------- + .. code-block:: python + + tir.block_attr({"double_buffer_scope": 1}) + + """ + + def __init__(self): + def block_attr(attrs: Mapping[str, Object], span: Span = None): + assert self.context + block_scope = self.context.current_block_scope() + if block_scope.annotations is not None: + self.context.report_error( + "Duplicate block annotations declaration, " + + "previous one is " + + str(block_scope.annotations), + span, + ) + attrs = { + key: tvm.tir.StringImm(val) if isinstance(val, str) else val + for key, val in attrs.items() + } + block_scope.annotations = attrs + + super().__init__(block_attr, def_symbol=False) + + +@register +class BlockPredicate(SpecialStmt): + """Special function where(predicate) + + Example + ------- + .. 
code-block:: python + + tir.where(i < 4) + + """ + + def __init__(self): + def where(predicate, span=None): + block_scope = self.context.current_block_scope() + if block_scope.predicate is not None: + self.context.report_error( + "Duplicate block predicate declaration, " + + "previous one is " + + str(block_scope.predicate), + span, + ) + + block_scope.predicate = predicate + + super().__init__(where, def_symbol=False) + + +@register +class BlockMatchBufferRegion(SpecialStmt): + """Special function match_buffer_region(source, strides, elem_offset, align, offset_factor) + + Example + ------- + .. code-block:: python + + B = tir.match_buffer_region(A[0: 4]) + + """ + + def __init__(self): + def match_buffer_region( + source, + strides=None, + elem_offset=None, + align=-1, + offset_factor=0, + span=None, + ): + if not isinstance(self.node, ast.Assign): + self.context.report_error( + "Need assign the match_buffer_region to a buffer, " + + "e.g. A = match_buffer_region(...)", + self.node.span, + ) + + if strides is None: + strides = [] + align = convert_to_int(align, "align", self.context.report_error, self.node.span) + offset_factor = convert_to_int( + offset_factor, "offset_factor", self.context.report_error, self.node.span + ) + + if not isinstance(source, BufferSlice): + self.context.report_error( + "match_buffer_region needs a buffer region as source", + span=span, + ) + buffer_region = buffer_slice_to_region(source) + shape = [r.extent for r in buffer_region.region] + buffer = tvm.tir.decl_buffer( + shape, + buffer_region.buffer.dtype, + self.node.lhs.id.name, + data=None, + strides=strides, + elem_offset=elem_offset, + scope=buffer_region.buffer.scope, + data_alignment=align, + offset_factor=offset_factor, + span=span, + ) + self.context.current_block_scope().match_buffers.append( + tvm.tir.MatchBufferRegion(buffer, buffer_region) + ) + self.context.update_symbol(self.node.lhs.id.name, buffer, self.node) + + super().__init__(match_buffer_region, def_symbol=True) 
+ + @register class VarDef(SpecialStmt): """ Special function for defining a Var""" @@ -156,7 +472,7 @@ def __init__(self): def var(dtype, span): assert isinstance(self.node, ast.Assign) v = te.var(self.node.lhs.id.name, dtype, span=span) - self.context.update_symbol(v.name, v) + self.context.update_symbol(v.name, v, self.node) super().__init__(var, def_symbol=True) @@ -170,7 +486,7 @@ def env_thread(env_name, span): assert isinstance(self.node, ast.Assign) v = te.var(self.node.lhs.id.name, span=span) self.context.func_var_env_dict[v] = env_name - self.context.update_symbol(v.name, v) + self.context.update_symbol(v.name, v, self.node) super().__init__(env_thread, def_symbol=True) diff --git a/python/tvm/script/utils.py b/python/tvm/script/utils.py index a6ba9d087aa6..503534be4305 100644 --- a/python/tvm/script/utils.py +++ b/python/tvm/script/utils.py @@ -16,15 +16,32 @@ # under the License. """Helper functions in TVM Script Parser""" +from typing import Callable, List, Any, Optional, Tuple, Union + import inspect -from ..ir import Span, SourceName +import synr + +from tvm.arith import Analyzer +from tvm.ir import Range, Span, SourceName +from tvm.tir import PrimExpr, BufferRegion +from tvm.error import DiagnosticError +from .node import BufferSlice -def get_param_list(func): +def get_param_list( + func: Callable, +) -> Tuple[List[str], List[Tuple[str, Tuple[Any, ...]]], Optional[str]]: """Get the parameter list from definition of function""" - full_arg_spec = inspect.getfullargspec(func) + full_arg_spec: inspect.FullArgSpec = inspect.getfullargspec(func) - args, defaults = full_arg_spec.args, full_arg_spec.defaults + args: List[str] + defaults: Optional[Tuple[Any, ...]] + kwonlyargs: List[str] + args, defaults, kwonlyargs = ( + full_arg_spec.args, + full_arg_spec.defaults, + full_arg_spec.kwonlyargs, + ) if defaults is None: defaults = tuple() @@ -33,14 +50,17 @@ def get_param_list(func): raise RuntimeError( "TVM Script register error : variable keyword argument 
is not supported now" ) - if not len(full_arg_spec.kwonlyargs) == 0: + + if len(kwonlyargs) == 1 and kwonlyargs[0] == "span": + pass + elif not len(kwonlyargs) == 0: raise RuntimeError("TVM Script register error : keyword only argument is not supported now") - pos_only = list() + pos_only: List[str] = list() for arg in args[: len(args) - len(defaults)]: if arg != "span": pos_only.append(arg) - kwargs = list() + kwargs: List[Tuple[str, Tuple[Any, ...]]] = list() for default, arg in zip(defaults, args[len(args) - len(defaults) :]): if arg != "span": kwargs.append((arg, default)) @@ -48,7 +68,37 @@ def get_param_list(func): return pos_only, kwargs, full_arg_spec.varargs -def from_synr_span(span): +def buffer_slice_to_region( + buffer_slice: BufferSlice, analyzer: Optional[Analyzer] = None +) -> BufferRegion: + """Construct BufferRegion from BufferSlice + + Parameters + ---------- + buffer_slice : BufferSlice + The input BufferSlice + + analyzer : Optional[tvm.arith.Analyzer] + The analyzer for simplifying. If not provided, the method will construct a new one + + Returns + ------- + buffer_region : BufferRegion + The constructed BufferRegion. 
+ """ + region: List[Range] = [] + for s in buffer_slice.slices: + start: Union[PrimExpr, int] = s.start + extent: Union[PrimExpr, int] = 1 if s.stop is None else s.stop - s.start + if not analyzer: + analyzer = Analyzer() + if isinstance(extent, PrimExpr): + extent = analyzer.simplify(extent) + region.append(Range.from_min_extent(start, extent, span=s.span)) + return BufferRegion(buffer_slice.buffer, region) + + +def tvm_span_from_synr(span: synr.ast.Span) -> Span: """Convert a synr span to a TVM span""" return Span( SourceName(span.filename), @@ -57,3 +107,32 @@ def from_synr_span(span): span.start_column, span.end_column, ) + + +def synr_span_from_tvm(span: Span) -> synr.ast.Span: + """Convert a TVM span to a synr span""" + return synr.ast.Span( + span.source_name.name, + span.line, + span.column, + span.end_line, + span.end_column, + ) + + +def call_with_error_reporting( + report_error, + node_span, + func, + *args, + **kwargs, +): + """Call function with exception handling and report error using node_span""" + try: + return func(*args, **kwargs) + except DiagnosticError as err: + raise err + except Exception as err: # pylint: disable=broad-except + # printing last non-empty row of error message. + error_msg = list(filter(None, str(err).split("\n")))[-1] + report_error(error_msg, node_span) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 1a3eb4806677..feeade24284a 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -106,3 +106,26 @@ def verify_gpu_code(func, constraints): The result of verification. """ return _ffi_api.verify_gpu_code(func, constraints) + + +def get_block_access_region(block, buffer_var_map): + """Auto detect the block read/write region according to body stmt + It will detect the read/write region as an array in order of appearance in AST + + Parameters + ---------- + block: tvm.tir.Block + The block to be detected. 
+ + buffer_var_map : Dict[Var, Buffer] + The outside buffers which may be accessed the block. Mapping from buffer var to the buffer + + Returns + ------- + result : List[List[BufferRegion]] + Array of access regions. There are three arrays of BufferRegion: + - first: read regions + - second: write regions + - third: opaque regions + """ + return _ffi_api.get_block_access_region(block, buffer_var_map) diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 8d5bba5e5bb0..788074073c08 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -476,8 +476,7 @@ inline const char* ForKind2String(ForKind t) { case ForKind::kUnrolled: return "unroll"; case ForKind::kThreadBinding: - LOG(FATAL) << "Loop ThreadBinding is reserved for future used and " - << "not yet supported in TIR"; + return "thread_binding"; } LOG(FATAL) << "Unknown ForKind"; return "Unknown"; diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 86b175e1676c..eb444eccd07f 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -22,6 +22,7 @@ * \brief Printer class to print Tensor IR to python syntax script */ +#include #include #include #include @@ -66,7 +67,10 @@ class TVMScriptPrinter : public StmtFunctor, std::unordered_map func2var_; /*! \brief var collector (var defined by For/Loop/Block) */ std::unordered_set var_not_in_headers; - /*! \brief buffer collector (buffer defined in BufferMap and BufferAllocation)*/ + /*! + * \brief buffer collector + * (buffer defined in BufferMap, BufferAllocation and MatchBufferRegion) + */ std::unordered_set buf_not_in_headers; /*! \brief Map from Var to thread env name */ std::unordered_map var_env_map_; @@ -84,6 +88,8 @@ class TVMScriptPrinter : public StmtFunctor, int num_child_; /*! \brief the number of current node */ int current_num_; + /*! 
\brief loop stack without annotations */ + std::vector loop_stack_; Doc VisitExpr_(const CastNode* op) override; Doc VisitExpr_(const VarNode* op) override; @@ -131,6 +137,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const ForNode* op) override; Doc VisitStmt_(const PrefetchNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; + Doc VisitStmt_(const BlockRealizeNode* op) override; Doc VisitStmtDefault_(const Object* op) override; Doc VisitType_(const PrimTypeNode* node) override; @@ -145,12 +152,24 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); Doc AllocBufferDeclaration(const Buffer& buf); + Doc PrintBufferRegion(const BufferRegionNode* op); + Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); + Doc PrintAnnotations(const Map& annotations); static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); + /*! Helper functions for loop printing. */ + /*! + * \brief Print a single for loop + * \param loop The for loop to be printed + */ + Doc PrintLoop(const For& loop); + /*! \brief Print all simple loops in stack into one line using tir.grid(). */ + Doc PrintLoopStack(); + /*! * \brief Print additional info about expr in comment. * \param expr The expression. 
@@ -308,6 +327,36 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { return val; } +Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { + const Buffer& buf = op->buffer; + buf_not_in_headers.insert(buf.get()); + + Doc doc = Print(op->buffer) << " = tir.match_buffer_region(" << Print(op->source); + if (!buf->strides.empty()) { + doc << ", strides=" << Print(buf->strides); + } + if (buf->offset_factor != 0 && buf->elem_offset->IsInstance()) { + Var elem_offset = Downcast(buf->elem_offset); + if (memo_var_.find(elem_offset) != memo_var_.end()) { + doc << ", elem_offset=" << Print(buf->elem_offset); + } else { + // implicitly define elem_offset + memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset"); + var_not_in_headers.insert(elem_offset.get()); + } + } else { + doc << ", elem_offset=" << Print(buf->elem_offset); + } + if (buf->data_alignment != -1) { + doc << ", align=" << buf->data_alignment; + } + if (buf->offset_factor != 0) { + doc << ", offset_factor=" << buf->offset_factor; + } + doc << ")"; + return doc; +} + Doc TVMScriptPrinter::Print(const ObjectRef& node) { if (!node.defined()) return Doc::Text("None"); if (node->IsInstance()) { @@ -330,6 +379,10 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) { return PrintIterVar(node.as()); } else if (node->IsInstance()) { return PrintRange(node.as()); + } else if (node->IsInstance()) { + return PrintBufferRegion(node.as()); + } else if (node->IsInstance()) { + return PrintMatchBufferRegion(node.as()); } else { meta_collector_.Collect(node); return this->meta_.GetMetaNode(node); @@ -660,9 +713,7 @@ inline const char* ForKind2String(ForKind t) { case ForKind::kUnrolled: return "unroll"; case ForKind::kThreadBinding: - LOG(FATAL) << "Loop ThreadBinding is reserved for future used and " - << "not yet supported in TIR"; - return "threadbinding"; + return "thread_binding"; } LOG(FATAL) << "Unknown ForKind"; return "Unknown"; @@ -671,9 +722,27 @@ inline const char* 
ForKind2String(ForKind t) { Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { Doc doc; var_not_in_headers.insert(op->loop_var.get()); - doc << "for " << Print(op->loop_var) << " in tir." + std::string(ForKind2String(op->kind)) + "(" - << Print(op->min) << ", " << Print(op->min + op->extent) - << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + const auto* body = op->body.as(); + bool simple_loop = op->kind == ForKind::kSerial && op->annotations.empty() && is_zero(op->min); + if (simple_loop) loop_stack_.push_back(GetRef(op)); + // It is a loop that can be compressed, let the loops below print it out + if (simple_loop && body != nullptr) return Print(GetRef(body)); + // It is a loop that can not be compressed + bool print_above = !loop_stack_.empty(); + // print loops above if needed + if (print_above) { + doc << PrintLoopStack(); + loop_stack_.clear(); + } + if (!simple_loop) { + // print current loop if needed + Doc current_loop; + current_loop << PrintLoop(GetRef(op)); + current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + doc << (print_above ? 
Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop); + } else { + doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); + } return doc; } @@ -713,6 +782,88 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { + const auto* block_op = op->block.as(); + // print block name and block vars + Doc doc; + doc << "with tir.block(["; + std::vector block_var_docs; + for (const auto& iter_var : block_op->iter_vars) { + Doc block_var_doc; + if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) { + block_var_doc << Print(iter_var->dom->extent); + } else { + block_var_doc << "tir."; + switch (iter_var->iter_type) { + case kDataPar: + block_var_doc << "range"; + break; + case kCommReduce: + block_var_doc << "reduce_axis"; + break; + case kOrdered: + block_var_doc << "scan_axis"; + break; + case kOpaque: + block_var_doc << "opaque_axis"; + break; + default: + LOG(FATAL) << "Unknown block var iter type"; + break; + } + block_var_doc << "(" << Print(iter_var->dom->min) << ", " + << Print(iter_var->dom->min + iter_var->dom->extent) << ")"; + } + block_var_docs.push_back(block_var_doc); + } + doc << PrintSep(block_var_docs, Doc::Text(", ")) << "], "; + doc << Doc::StrLiteral(block_op->name_hint) << ")"; + std::vector block_var_names; + for (const auto& iter_var : block_op->iter_vars) { + var_not_in_headers.insert(iter_var->var.get()); + block_var_names.push_back(Print(iter_var->var)); + } + if (!block_var_names.empty()) { + doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]"; + } + doc << ":"; + Doc block_attr_doc; + // print predicate, binding, read/write tensor region, annotations + if (!is_one(op->predicate)) { + block_attr_doc << Doc::NewLine() << "tir.where(" << Print(op->predicate) << ")"; + } + for (size_t i = 0; i < block_op->iter_vars.size(); ++i) + block_attr_doc << Doc::NewLine() << "tir.bind(" << Print(block_op->iter_vars[i]->var) << ", 
" + << Print(op->iter_values[i]) << ")"; + block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")"; + block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")"; + if (!block_op->annotations.empty()) { + block_attr_doc << Doc::NewLine() << "tir.block_attr({"; + block_attr_doc << PrintAnnotations(block_op->annotations); + block_attr_doc << "})"; + } + // print body + Doc body; + body << Doc::NewLine(); + for (const auto& alloc_buf : block_op->alloc_buffers) { + buf_not_in_headers.insert(alloc_buf.get()); + body << Print(alloc_buf) << " = tir.alloc_buffer(" << memo_buf_decl_[alloc_buf] << ")" + << Doc::NewLine(); + } + for (const auto& match_buf : block_op->match_buffers) { + body << Print(match_buf) << Doc::NewLine(); + } + if (block_op->init.defined()) { + Doc init_block; + init_block << "with tir.init():"; + init_block << Doc::Indent(4, Doc::NewLine() << PrintBody(block_op->init.value())); + body << init_block << Doc::NewLine(); + } + body << PrintBody(block_op->body); + doc << Doc::Indent(4, block_attr_doc << body); + return doc; +} + Doc TVMScriptPrinter::PrintBody(const Stmt& body) { int memo_num_child, memo_current_num; std::swap(memo_num_child, num_child_); @@ -890,6 +1041,73 @@ Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { return meta_.InMeta(buffer) ? 
meta_.GetMetaNode(buffer) : AllocBuf(buffer); } +Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) { + Doc doc; + doc << Print(op->buffer) << "["; + for (size_t i = 0; i < op->region.size(); ++i) { + if (i != 0) doc << ", "; + const auto& range = op->region[i]; + if (!is_one(range->extent)) { + doc << Print(range->min) << ":" << Print(range->min + range->extent); + } else { + doc << Print(range->min); + } + } + doc << "]"; + return doc; +} + +Doc TVMScriptPrinter::PrintAnnotations(const Map& annotations) { + Doc res; + std::vector> anno_list; + anno_list.reserve(annotations.size()); + for (const auto& pair : annotations) { + anno_list.emplace_back(pair); + } + sort(anno_list.begin(), anno_list.end()); + for (size_t i = 0; i < anno_list.size(); ++i) { + if (i != 0) { + res << ", "; + } + res << "\"" << anno_list[i].first << "\":" << Print(anno_list[i].second); + } + return res; +} + +Doc TVMScriptPrinter::PrintLoop(const For& loop) { + Doc res; + res << "for " << Print(loop->loop_var) + << " in tir." 
+ std::string(ForKind2String(loop->kind)) + "(" << Print(loop->min) << ", " + << Print(loop->min + loop->extent); + if (loop->thread_binding.defined()) { + res << ", thread = "; + res << Print(loop->thread_binding.value()->thread_tag); + } + if (!loop->annotations.empty()) { + res << ", annotation = {"; + res << PrintAnnotations(loop->annotations); + res << "}"; + } + res << "):"; + return res; +} + +Doc TVMScriptPrinter::PrintLoopStack() { + Doc res; + if (loop_stack_.size() == 1) { + res << PrintLoop(loop_stack_[0]); + } else if (loop_stack_.size() > 1) { + std::vector vars, extents; + for (const auto& loop : loop_stack_) { + vars.push_back(Print(loop->loop_var)); + extents.push_back(Print(loop->extent)); + } + res << "for " << PrintSep(vars, Doc::Text(", ")) << " in tir.grid(" + << PrintSep(extents, Doc::Text(", ")) << "):"; + } + return res; +} + TVM_REGISTER_GLOBAL("script.AsTVMScript") .set_body_typed([](const ObjectRef& functions, bool show_meta) { diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc new file mode 100644 index 000000000000..24ed4e7d29b1 --- /dev/null +++ b/src/tir/analysis/block_access_region_detector.cc @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/block_access_region_detector.cc + * \brief Detect block read/write regions by visiting its body + */ + +#include +#include +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Auto detect the block read write region + * It will detect the read/write region as an array in order of appearance in AST + * \note This detector only accepts to visit a block and will not visit child blocks recursively + */ +class BlockReadWriteDetector : public StmtExprVisitor { + public: + explicit BlockReadWriteDetector(const Map& buffer_var_map) + : buffer_var_map_(buffer_var_map) {} + + /*! \brief Return read regions of the block */ + Array CollectReads(); + /*! \brief Return write regions of the block */ + Array CollectWrites(); + /*! + * \brief Return opaque buffer regions of the block + * \note The buffer accessed by load/store or call with buffer.data will + * be marked as opaque. + */ + Array CollectOpaques(); + /*! \brief overload operator() to make sure it accepts a block node */ + void operator()(const Stmt& stmt); + + private: + /*! \brief Iteration range for loop_vars */ + std::unordered_map dom_map_; + /*! \brief The buffers that the current block reads */ + std::vector read_buffers_; + /*! \brief The buffers that the current block writes */ + std::vector writes_buffers_; + /*! \brief The opaque buffers which are accessed by buffer.data */ + std::vector opaque_buffers_; + /*! \brief The read regions of the current block */ + std::vector> read_regions_; + /*! \brief The write regions of the current block */ + std::vector> write_regions_; + /*! \brief The outside buffer data mapping to its buffer */ + Map buffer_var_map_; + /*! \brief The analyzer for simplifying*/ + arith::Analyzer analyzer_; + + /*!
+ * \brief Update read/write buffers and regions with provided buffer and region + * \param buffers The buffers to be updated + * \param regions The access regions to be updated + * \param buffer The provided buffer + * \param region The provided region + */ + void Update(std::vector* buffers, std::vector>* regions, + const Buffer& buffer, const std::vector& region); + + /*! \brief Helper function to collect access regions. */ + Array CollectRegions(const std::vector& buffers, + const std::vector>& regions); + + /*! \brief Helper function to add an opaque buffer. */ + void AddOpaque(const Var& buffer_var); + + void VisitStmt_(const ForNode* op) override; + void VisitStmt_(const BlockRealizeNode* op) override; + void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const StoreNode* op) override; + void VisitExpr_(const BufferLoadNode* op) override; + void VisitExpr_(const LoadNode* op) override; + void VisitExpr_(const VarNode* op) override; +}; + +void BlockReadWriteDetector::operator()(const Stmt& stmt) { + ICHECK(stmt.as() != nullptr) << "Only allow to visit a block"; + StmtExprVisitor::operator()(stmt); +} + +Array BlockReadWriteDetector::CollectReads() { + return CollectRegions(read_buffers_, read_regions_); +} + +Array BlockReadWriteDetector::CollectWrites() { + return CollectRegions(writes_buffers_, write_regions_); +} + +Array BlockReadWriteDetector::CollectOpaques() { + Array res; + res.reserve(opaque_buffers_.size()); + for (const Buffer& buffer : opaque_buffers_) { + res.push_back(BufferRegion::FullRegion(buffer)); + } + return res; +} + +void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { AddOpaque(GetRef(op)); } + +void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) { + AddOpaque(op->buffer_var); + ExprVisitor::VisitExpr_(op); +} + +void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { + std::vector relaxed_region; + for (const PrimExpr& index : op->indices) { +
relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + } + Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); + ExprVisitor::VisitExpr_(op); +} + +void BlockReadWriteDetector::VisitStmt_(const ForNode* op) { + Range range = Range::FromMinExtent(op->min, op->extent); + dom_map_[op->loop_var.get()] = arith::IntSet::FromRange(range); + StmtVisitor::VisitStmt_(op); + dom_map_.erase(op->loop_var.get()); +} + +void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) { + AddOpaque(op->buffer_var); + StmtVisitor::VisitStmt_(op); +} + +void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { + std::vector relaxed_region; + for (const PrimExpr& index : op->indices) { + relaxed_region.push_back(arith::EvalSet(index, dom_map_)); + } + Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); + StmtVisitor::VisitStmt_(op); +} + +void BlockReadWriteDetector::VisitStmt_(const BlockRealizeNode* op) { + /*! \note detector will not visit child block recursively, so that it will stop here */ + std::unordered_map vmap; + for (size_t i = 0; i < op->block->iter_vars.size(); ++i) { + vmap[op->block->iter_vars[i]->var.get()] = op->iter_values[i]; + } + for (const auto& read : op->block->reads) { + std::vector relaxed_region; + for (const auto& range : read->region) { + relaxed_region.push_back( + arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent( + Substitute(range->min, vmap), Substitute(range->extent, vmap))), + dom_map_)); + } + Update(&read_buffers_, &read_regions_, read->buffer, relaxed_region); + } + for (const auto& write : op->block->writes) { + std::vector relaxed_region; + for (const auto& range : write->region) { + relaxed_region.push_back( + arith::EvalSet(arith::IntSet::FromRange(Range::FromMinExtent( + Substitute(range->min, vmap), Substitute(range->extent, vmap))), + dom_map_)); + } + Update(&writes_buffers_, &write_regions_, write->buffer, relaxed_region); + } +} + +void 
BlockReadWriteDetector::Update(std::vector* buffers, + std::vector>* regions, + const Buffer& buffer, + const std::vector& region) { + if (buffer_var_map_.find(buffer->data) == buffer_var_map_.end()) return; + ICHECK_EQ(buffers->size(), regions->size()) + << " Expect the buffer and regions to have the same size "; + for (size_t i = 0; i < regions->size(); ++i) { + if ((*buffers)[i].same_as(buffer)) { + ICHECK_EQ((*regions)[i].size(), region.size()) << "Inconsistent buffer dimension"; + for (size_t j = 0; j < region.size(); ++j) { + (*regions)[i][j] = arith::Union({(*regions)[i][j], region[j]}); + } + return; + } + } + buffers->push_back(buffer); + regions->push_back(region); +} + +Array BlockReadWriteDetector::CollectRegions( + const std::vector& buffers, + const std::vector>& regions) { + ICHECK_EQ(buffers.size(), regions.size()); + Array res; + res.reserve(buffers.size()); + for (size_t i = 0; i < regions.size(); ++i) { + Array region; + region.reserve(regions[i].size()); + for (size_t j = 0; j < regions[i].size(); j++) { + tvm::arith::IntSet range = regions[i][j]; + region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + } + res.push_back(BufferRegion(buffers[i], region)); + } + return res; +} + +void BlockReadWriteDetector::AddOpaque(const Var& buffer_var) { + auto it = buffer_var_map_.find(buffer_var); + if (it != buffer_var_map_.end()) { + const Buffer& buffer = (*it).second; + for (const Buffer& opaque_buffer : opaque_buffers_) { + if (buffer.same_as(opaque_buffer)) return; + } + opaque_buffers_.push_back(buffer); + } +} + +Array> GetBlockAccessRegion(const Block& block, + const Map& buffer_var_map) { + BlockReadWriteDetector detector(buffer_var_map); + detector(block); + return {detector.CollectReads(), detector.CollectWrites(), detector.CollectOpaques()}; +} + +TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion); + +} // namespace tir +} // namespace tvm diff --git 
a/src/tir/ir/script/script_complete.cc b/src/tir/ir/script/script_complete.cc new file mode 100644 index 000000000000..7c9fff724e33 --- /dev/null +++ b/src/tir/ir/script/script_complete.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/ir/script/script_complete.cc + * \brief Used by TVM Script parser to expand incomplete TIR input + */ + +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +/*! \brief Generate surrounding loops automatically */ +class ScriptCompleter : public StmtMutator { + public: + explicit ScriptCompleter(Map* buffer_var_map) : buffer_var_map_(buffer_var_map) {} + /*! \brief Whether the stmt contains at least one block. 
*/ + bool contains_block = false; + + private: + Map* buffer_var_map_; + Stmt VisitStmt_(const BlockRealizeNode* op) override { + contains_block = true; + Stmt body = StmtMutator::VisitStmt_(op); + if (!op->iter_values.empty() && !op->iter_values[0].dtype().is_int()) { + auto block_with_binding = CopyOnWrite(Downcast(body).get()); + std::vector bindings; + for (size_t i = 0; i < op->iter_values.size(); ++i) { + bindings.push_back(Var("i" + std::to_string(i))); + } + block_with_binding->iter_values = bindings; + body = BlockRealize(block_with_binding); + for (int i = op->iter_values.size() - 1; i >= 0; --i) { + body = For(Downcast(bindings[i]), op->block->iter_vars[i]->dom->min, + op->block->iter_vars[i]->dom->extent, {}, body); + } + } + return body; + } + + Stmt VisitStmt_(const BlockNode* op) override { + // Buffers allocated in the block can be accessed by its body. + for (const auto& alloc_buffer : op->alloc_buffers) { + buffer_var_map_->Set(alloc_buffer->data, alloc_buffer); + } + Block block = Downcast(StmtMutator::VisitStmt_(op)); + // Remove buffers allocated inside block to detect its access region + for (const auto& alloc_buffer : op->alloc_buffers) { + buffer_var_map_->erase(alloc_buffer->data); + } + if (block->reads.empty() || block->writes.empty()) { + auto access_region = GetBlockAccessRegion(block, *buffer_var_map_); + const Array& reads = access_region[0]; + const Array& writes = access_region[1]; + const Array& opaque = access_region[2]; + CHECK(opaque.empty()) + << "ValueError: Can not auto detect buffer access region from tir.Load, tir.Store or " + "direct access by buffer data. 
Please annotation the access region manually"; + auto n = CopyOnWrite(block.operator->()); + if (!n->reads.defined()) n->reads = reads; + if (!n->writes.defined()) n->writes = writes; + return Block(n); + } else { + return std::move(block); + } + } +}; + +PrimFunc ScriptComplete(PrimFunc func, const Array& root_allocates) { + Map buffer_var_map; + for (const auto& pair : func->buffer_map) { + const Buffer& buffer = pair.second; + buffer_var_map.Set(buffer->data, buffer); + } + for (const auto& alloc : root_allocates) { + buffer_var_map.Set(alloc->data, alloc); + } + ScriptCompleter script_completer(&buffer_var_map); + // generate surrounding loops automatically + Stmt res = script_completer(func->body); + // generate root block automatically + if (script_completer.contains_block && + (!res->IsInstance() || !root_allocates.empty())) { + res = Block({}, {}, {}, "root", res, NullOpt, root_allocates); + res = BlockRealize({}, Bool(true), Downcast(res)); + } + if (func->body.same_as(res)) { + return func; + } else { + auto fptr = func.CopyOnWrite(); + fptr->body = res; + return func; + } +} + +TVM_REGISTER_GLOBAL("script.Complete").set_body_typed(ScriptComplete); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py new file mode 100644 index 000000000000..7e4d7d87c1e1 --- /dev/null +++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir, script +from tvm.ir import Range + + +@tvm.script.tir +def func() -> None: + A = tir.alloc_buffer((128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + C = tir.alloc_buffer((128, 128), "float32") + D = tir.alloc_buffer((128, 128), "float32") + with tir.block([]): + # Need add read/write region manually to avoid triggering block access region detector + tir.reads([B[0, 0], C[0:16, 0:16], A[4:12, 4:12]]) + tir.writes([A[0:12, 0:12]]) + for i, j in tir.grid(8, 8): + A[i, j] = B[0, 0] + C[0, 0] + with tir.block([2, 2]) as [vi, vj]: + tir.reads([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8], C[12:16, 12:16]]) + tir.writes([A[vi * 4 + 4 : vi * 4 + 8, vj * 4 + 4 : vj * 4 + 8]]) + for i, j in tir.grid(4, 4): + A[vi * 4 + 4 + i, vj * 4 + 4 + j] += C[i + 12, j + 12] + tir.evaluate(D.data) + + +def test_block_access_region_detector(): + block = func.body.block.body.block + alloc_buffers = func.body.block.alloc_buffers + buffer_var_map = {buf.data: buf for buf in alloc_buffers} + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + + tvm.ir.assert_structural_equal(block.reads, ret[0]) + tvm.ir.assert_structural_equal(block.writes, ret[1]) + D = alloc_buffers[-1] + tvm.ir.assert_structural_equal( + [tvm.tir.BufferRegion(D, [Range(0, 128), Range(0, 128)])], ret[2] + ) + + +if __name__ == "__main__": + test_block_access_region_detector() diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 048a9544d6df..052217b32cb5 100644 --- 
a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -144,6 +144,197 @@ def test_no_body(): check_error(no_body, 3) +def allocate_with_buffers() -> None: + with tir.allocate([1], "float32", "") as [A, B]: # error + tir.evaluate(1.0) + + +def test_allocate_with_buffers(): + check_error(allocate_with_buffers, 2) + + +def inconsistent_binding() -> None: + with tir.block([128, 128]) as [vi]: # error + tir.evaluate(1.0) + + +def test_inconsistent_binding(): + check_error(inconsistent_binding, 2) + + +def invalid_block_axes(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + with tir.block([A]) as [vi]: # error + tir.evaluate(1.0) + + +def test_invalid_block_axes(): + check_error(invalid_block_axes, 3) + + +def miss_block_bind() -> None: + with tir.block([16, 16]) as [vi, vj]: # error + tir.bind(vi, 1) + tir.evaluate(1.0) + + +def test_miss_block_bind(): + check_error(miss_block_bind, 2) + + +def invalid_loop_var() -> None: + for i, j in range(0, 16): # error + tir.evaluate(1.0) + + +def test_invalid_loop_var(): + check_error(invalid_loop_var, 2) + + +def inconsistent_grid() -> None: + for i in tir.grid(16, 16): # error + tir.evaluate(1.0) + + +def test_inconsistent_grid(): + check_error(inconsistent_grid, 2) + + +def invalid_match_buffer_region() -> None: + with tir.block([16, 16]) as [vi, vj]: + A = tir.match_buffer_region(vi) # error + tir.evaluate(1.0) + + +def test_invalid_match_buffer_region(): + check_error(invalid_match_buffer_region, 3) + + +def duplicate_buffer() -> None: + A = tir.alloc_buffer((128, 128), "float32") + with tir.block([16, 16]) as [vi, vj]: + A = tir.alloc_buffer((128, 128), "float32") # error + tir.evaluate(1.0) + + +def test_duplicate_buffer(): + check_error(duplicate_buffer, 4) + + +def duplicate_reads() -> None: + A = tir.alloc_buffer((128, 128), "float32") + with tir.block([16, 16]) as [vi, vj]: + tir.reads(A[0:8, 0:8]) + tir.reads(A[0:16, 0:16]) # error + 
tir.evaluate(1.0) + + +def duplicate_writes() -> None: + A = tir.alloc_buffer((128, 128), "float32") + with tir.block([16, 16]) as [vi, vj]: + tir.writes(A[0:8, 0:8]) + tir.writes(A[0:16, 0:16]) # error + tir.evaluate(1.0) + + +def duplicate_predicate() -> None: + with tir.block([16, 16]) as [vi, vj]: + tir.where(1) + tir.where(0) # error + + +def duplicate_annotations() -> None: + with tir.block([16, 16]) as [vi, vj]: + tir.block_attr({}) + tir.block_attr({}) # error + + +def duplicate_init() -> None: + with tir.block([16, 16]) as [vi, vj]: + with tir.init(): + tir.evaluate(1.0) + with tir.init(): # error + tir.evaluate(1.0) + + +def test_duplicate_block_signature(): + check_error(duplicate_reads, 5) + check_error(duplicate_writes, 5) + check_error(duplicate_predicate, 4) + check_error(duplicate_annotations, 4) + check_error(duplicate_init, 5) + + +def opaque_access_during_complete(a: ty.handle) -> None: # error + A = tir.match_buffer(a, (16, 16), "float32") + with tir.block([16, 16]) as [vi, vj]: + tir.evaluate(tir.load("float32", A.data, vi * 16 + vj)) + + +def test_opaque_access_during_complete(): + check_error(opaque_access_during_complete, 1) + + +def convert_slice_to_bufferload() -> None: + A = tir.alloc_buffer((128, 128), "float32") + with tir.block([16, 16]) as [vi, vj]: + A[vi, vj] = A[vi : vi + 2, vj] + 1 # error + + +def test_convert_slice_to_bufferload(): + check_error(convert_slice_to_bufferload, 4) + + +def error_index_type() -> None: + A = tir.alloc_buffer((128, 128), "float32") + with tir.block([16, 16]) as [vi, vj]: + A[vi, vj] = A[vi, 0.0] + 1 # error + + +def test_error_index_type(): + check_error(error_index_type, 4) + + +def mismatch_args() -> None: + A = tir.alloc_buffer((128, 128), "float32") + with tir.block([16, 16]) as [vi, vj]: + tir.reads(A[0, 0], A[1, 1]) # error + tir.evaluate(1.0) + + +def test_mismatch_args(): + check_error(mismatch_args, 4) + + +def special_stmt_except() -> None: + A = tir.alloc_buffer("(128, 128)", "float32") # 
error + with tir.block([16, 16]) as [vi, vj]: + tir.evaluate(1.0) + + +def scope_handler_except() -> None: + for i in tir.serial("1", "1"): # error + tir.evaluate(1) + + +def intrin_except_unassign(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + tir.evaluate(A) # error + + +def intrin_except_assign(a: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + A[0, 0] = tir.load(A, A, A) # error + + +def test_tvm_exception_catch(): + # test catching c++ side exception + check_error(special_stmt_except, 2) + check_error(scope_handler_except, 2) + check_error(intrin_except_unassign, 3) + check_error(intrin_except_assign, 3) + + def check_error(module, rel_lineno): # Override the default renderer to accumulate errors _, start_line = inspect.getsourcelines(module) @@ -180,3 +371,17 @@ def render(e): test_return_not_allowed() test_tir_assert() test_no_body() + test_allocate_with_buffers() + test_inconsistent_binding() + test_invalid_block_axes() + test_miss_block_bind() + test_invalid_loop_var() + test_inconsistent_grid() + test_invalid_match_buffer_region() + test_duplicate_buffer() + test_duplicate_block_signature() + test_opaque_access_during_complete() + test_convert_slice_to_bufferload() + test_error_index_type() + test_mismatch_args() + test_tvm_exception_catch() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index c7a38cccda49..a295908afa6a 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2662,6 +2662,169 @@ def test_opt_conv_tensorcore_mod_host(): tvm.ir.assert_structural_equal(mod, rt_mod, True) +@tvm.script.tir +def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + with tir.init(): + 
C[vi, vj] = tir.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def matmul_original(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "init") as [vi, vj]: + C[vi, vj] = tir.float32(0) + + for k in range(0, 128): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@tvm.script.tir +def element_wise(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128), "float32") + C = tir.match_buffer(c, (128, 128), "float32") + B = tir.alloc_buffer((128, 128), "float32") + + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * tir.float32(2) + + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + tir.float32(1) + + +@tvm.script.tir +def predicate(b: ty.handle, c: ty.handle) -> None: + B = tir.match_buffer(b, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + + for i, jo, ji in tir.grid(16, 4, 5): + with tir.block([16, 16], "update") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, jo * 4 + ji) + tir.where(jo * 4 + ji < 16) + C[vi, vj] = B[vi, vj] + tir.float32(1) + + +def test_module_define(): + func1 = tvm.script.create_module({"matmul": matmul})["matmul"] + func2 = tvm.script.create_module({"element_wise": element_wise})["element_wise"] + func3 = tvm.script.create_module({"predicate": predicate})["predicate"] + mod1 = tvm.script.create_module({"func1": func1, "func2": func2, "func3": func3}) + mod2 = tvm.script.create_module({"func1": matmul, "func2": element_wise, "func3": predicate}) + tvm.ir.assert_structural_equal(mod1, mod2) + + +def test_matmul(): + func = matmul + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + +def test_matmul_original(): + 
func = matmul_original + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + assert isinstance(rt_func.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body, tir.stmt.For) + assert isinstance(rt_func.body.block.body.body, tir.stmt.For) + assert isinstance(rt_func.body.block.body.body.body, tir.stmt.SeqStmt) + assert isinstance(rt_func.body.block.body.body.body[0].block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body.body.body[1], tir.stmt.For) + assert isinstance(rt_func.body.block.body.body.body[1].body.block, tir.stmt.Block) + + +def test_element_wise(): + func = element_wise + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + assert isinstance(rt_func.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body, tir.stmt.SeqStmt) + assert isinstance(rt_func.body.block.body[0], tir.stmt.For) + assert isinstance(rt_func.body.block.body[0].body, tir.stmt.For) + assert isinstance(rt_func.body.block.body[0].body.body.block, tir.stmt.Block) + + assert isinstance(rt_func.body.block.body[1], tir.stmt.For) + assert isinstance(rt_func.body.block.body[1].body, tir.stmt.For) + assert isinstance(rt_func.body.block.body[1].body.body.block, tir.stmt.Block) + + +def test_predicate(): + func = predicate + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + assert isinstance(rt_func.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body, tir.stmt.For) + assert isinstance(rt_func.body.block.body.body, tir.stmt.For) + assert isinstance(rt_func.body.block.body.body.body, tir.stmt.For) + assert isinstance(rt_func.body.block.body.body.body.body.block, tir.stmt.Block) + + +@tvm.script.tir +def for_thread_binding(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (16, 16), 
"float32") + + for i in tir.thread_binding(0, 16, thread="threadIdx.x"): + for j in tir.thread_binding(0, 16, thread="threadIdx.y"): + A[i, j] = B[i, j] + tir.float32(1) + + +def test_for_thread_binding(): + func = for_thread_binding + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + assert isinstance(rt_func.body, tir.stmt.For) + assert rt_func.body.kind == 4 + assert rt_func.body.thread_binding.thread_tag == "threadIdx.x" + assert isinstance(rt_func.body.body, tir.stmt.For) + assert rt_func.body.body.kind == 4 + assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y" + + +@tvm.script.tir +def block_elements(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + B = tir.match_buffer(b, (1, 1), "float32") + + with tir.block([1], "update") as [vi]: + tir.bind(vi, 0) + tir.where(True) + tir.reads(A[0:16, 0:16]) + tir.writes(B[0, 0]) + tir.block_attr({"attr_key": "attr_value"}) + C = tir.alloc_buffer((4, 4), dtype="float32") + D = tir.match_buffer_region(A[0:4, 0]) + with tir.init(): + B[0, 0] = tir.float32(0) + B[0, 0] = A[0, 0] + B[0, 0] + C[1, 1] + D[2, 0] + + +def test_block_elements(): + func = block_elements + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func) + + assert isinstance(rt_func.body.block, tir.stmt.Block) + assert isinstance(rt_func.body.block.body, tir.stmt.BufferStore) + assert isinstance(rt_func.body.block.init, tir.stmt.BufferStore) + assert len(rt_func.body.block.annotations) == 1 + assert rt_func.body.block.annotations["attr_key"] == "attr_value" + + if __name__ == "__main__": test_opt_gemm_normalize() test_opt_gemm_mod_host() @@ -2669,3 +2832,10 @@ def test_opt_conv_tensorcore_mod_host(): test_opt_conv_tensorcore_normalize() test_opt_conv_tensorcore_lower() test_opt_conv_tensorcore_mod_host() + test_module_define() + test_matmul() + test_matmul_original() + 
test_element_wise() + test_predicate() + test_for_thread_binding() + test_block_elements() diff --git a/tests/scripts/task_ci_python_setup.sh b/tests/scripts/task_ci_python_setup.sh index f48ed49a2266..b880cb9d6457 100755 --- a/tests/scripts/task_ci_python_setup.sh +++ b/tests/scripts/task_ci_python_setup.sh @@ -30,4 +30,4 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.2.1 +python3 -m pip install --user tlcpack-sphinx-addon==0.1.4 synr==0.3.0