Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][M1a] TVMScript Parser/Printer #7630

Merged
merged 22 commits into from
Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Auto detect the block read/write region according to body stmt
* It will detect the read/write region as an array in order of appearance in AST
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed the block.
* It is a map from buffer var to the buffer.
* \return Array of access regions.
* There are three arrays of BufferRegion:
* - first: read regions
* - second: write regions
* - third: opaque regions
*/
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
172 changes: 146 additions & 26 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,179 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""

from tvm.te import schedule
from typing import List, Mapping, Union, Optional, Dict, Callable
import synr


import tvm
from tvm.ir import Span
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from .node import BufferSlice


class BlockInfo:
"""Information for block and block_realize signature"""

alloc_buffers: List[Buffer] = []
"""List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
reads: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]:
list of tir.reads statements in the block signature, None for not-visited"""
writes: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]:
list of tir.writes statements in the block signature, None for not-visited"""
annotations: Optional[Mapping[str, Object]] = None
"""Optional[Mapping[str, Object]]:
list of tir.block_attr statements in the block signature, None for not-visited"""
predicate: Optional[PrimExpr] = None
"""Optional[PrimExpr]: block realize predicate, None for not-visited"""
init: Optional[Stmt] = None
"""Optional[Stmt]: init part of the block, None for not-visited"""

def __init__(self):
self.alloc_buffers = []
self.match_buffers = []
self.iter_bindings = {}
self.reads = None
self.writes = None
self.annotations = None
self.predicate = None
self.init = None


class ContextMaintainer:
"""Maintain all the necessary context info"""
"""Maintain all the necessary context info
Parameters
----------
_report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
The report error function handle
"""

def __init__(self, parser):
# scope context
node_stack: List[List[synr.ast.Node]] = []
"""List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
block_info_stack: List[BlockInfo] = []
"""List[BlockInfo]: The block info for the current block scope"""
loop_stack: List[List[Var]] = []
"""List[List[Var]]: List of loop vars inside the current block scope"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""

# function context
func_params: List[Var] = []
"""List[Var]: The function parameters"""
func_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map"""
func_dict_attr: Mapping[str, Object] = {}
"""Mapping[str, Object]: The function attrs"""
func_var_env_dict: Mapping[Var, str] = {}
"""Mapping[Var, str]: The map from var to env thread"""

# parser and analyzer
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
"""tvm.arith.Analyzer: The analyzer for simplifying"""
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
"""Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""

def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
# scope context
self.node_stack = [] # AST nodes of scopes
self.symbols = [] # symbols of scopes
self.node_stack = []
self.block_info_stack = []
self.loop_stack = []
self.symbols = []
# function context
self.func_params = [] # parameter list of function
self.func_buffer_map = {} # buffer_map of function
self.func_dict_attr = {} # func_attr of function
self.func_var_env_dict = {} # map from var to env_name
# parser
self.parser = parser

def pop_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()
self.func_params = []
self.func_buffer_map = {}
self.func_dict_attr = {}
self.func_var_env_dict = {}
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()

def new_scope(self, nodes=None):
"""Creating a new scope"""
def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
"""Creates a new scope

Note
------
This function is used for normal scopes that does not
involve a `with block` scope. Use `enter_block_scope`
for the other case.

Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
if nodes is None:
nodes = []
self.node_stack.append(list(reversed(nodes)))
self.symbols.append(dict())

def update_symbol(self, name, symbol):
def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the difference between a regular and block scope from the user perspective. When should a should enter_scope be used and when should enter_block_scope be used.

"""Creates a new block scope, the function will call `enter_scope` implicitly
Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
to maintain block info.

Note
------
This function should be used for the block scope,
aka the blocks that involves a `with block` scope.

Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
self.enter_scope(nodes)
# Create a new loop stack for the new block
self.loop_stack.append([])
# Create a new BlockInfo for the new block
self.block_info_stack.append(BlockInfo())

def exit_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()

def exit_block_scope(self):
"""Pop the inner most block scope, the function will call `exit_scope` implicitly"""
self.exit_scope()
# Pop loop stack
self.loop_stack.pop()
# Pop block_info
self.block_info_stack.pop()

def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
"""Append a symbol into current scope"""
if isinstance(symbol, schedule.Buffer):
if isinstance(symbol, Buffer):
if name in self.symbols[0]:
self.parser.report_error("Duplicate Buffer name")
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
self.symbols[0][name] = symbol
else:
self.symbols[-1][name] = symbol

def remove_symbol(self, name):
def remove_symbol(self, name: str):
"""Remove a symbol"""
for symbols in reversed(self.symbols):
if name in symbols:
symbols.pop(name)
return
raise RuntimeError("Internal error of tvm script parser: no symbol named" + name)
raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)

def lookup_symbol(self, name):
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
"""Look up symbol by name"""
for symbols in reversed(self.symbols):
if name in symbols:
return symbols[name]
return None

def report_error(self, message, span):
self.parser.report_error(message, span)
def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
self._report_error(message, span)

def current_block_scope(self) -> BlockInfo:
return self.block_info_stack[-1]
20 changes: 16 additions & 4 deletions python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
from typing import List, Any

import tvm.tir
from .registry import register
from .utils import get_param_list, from_synr_span
from .utils import get_param_list, tvm_span_from_synr


class Intrin:
Expand All @@ -29,8 +31,8 @@ def __init__(self, intrin, stmt=False):
def signature(self):
return "tir." + self.intrin.__name__, get_param_list(self.intrin)

def handle(self, arg_list, span):
return self.intrin(*arg_list, span=from_synr_span(span))
def handle(self, arg_list: List[Any], span: tvm.ir.Span):
return self.intrin(*arg_list, span=tvm_span_from_synr(span))


@register
Expand Down Expand Up @@ -98,6 +100,16 @@ def float64(imm, span):
return tvm.tir.Cast("float64", imm, span)


@register
def min_value(dtype, span):
return tvm.tir.min_value(dtype, span)


@register
def max_value(dtype, span):
return tvm.tir.max_value(dtype, span)


@register
def floordiv(x, y, span):
return tvm.tir.floordiv(x, y, span)
Expand Down Expand Up @@ -145,7 +157,7 @@ def get_axis(begin, end, iter_type, span):
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)

iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span)
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)


@register
Expand Down
Loading