Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][M1a] TVMScript Parser/Printer #7630

Merged
merged 22 commits into from
Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Auto detect the block read/write region according to body stmt
* It will detect the read/write region as an array in order of appearance in AST
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed the block.
* It is a map from buffer var to the buffer.
* \return Array of access regions.
* There are three arrays of BufferRegion:
* - first: read regions
* - second: write regions
* - third: opaque regions
*/
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
159 changes: 133 additions & 26 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,166 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""

from tvm.te import schedule
from typing import List, Mapping, Union, Optional, Dict, Callable
import synr


import tvm
from tvm.ir import Span
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from .node import BufferSlice


class BlockInfo:
"""Information for block and block_realize signature"""

alloc_buffers: List[Buffer] = []
"""List[Buffer]: alloc_buffers list for the block"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: match_buffer_region list for the block"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document this more instead of just restating what the variable is called. You've done the same thing on the other variables.

Maybe it would be helpful to provide an example of a block and show what each variable corresponds to.

iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: block iter var and its values"""
reads: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]: block read buffer regions, None for not-visited"""
writes: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]: block write buffer regions, None for not-visited"""
annotations: Optional[Mapping[str, Object]] = None
"""Optional[Mapping[str, Object]]: block annotations, None for not-visited"""
predicate: Optional[PrimExpr] = None
"""Optional[PrimExpr]: block realize predicate, None for not-visited"""
init: Optional[Stmt] = None
"""Optional[Stmt]: init part of the block, None for not-visited"""

def __init__(self):
self.alloc_buffers = []
self.match_buffers = []
self.iter_bindings = {}
self.reads = None
self.writes = None
self.annotations = None
self.predicate = None
self.init = None


class ContextMaintainer:
"""Maintain all the necessary context info"""
"""Maintain all the necessary context info
Parameters
----------
_report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
The report error function handle
"""

def __init__(self, parser):
# scope context
node_stack: List[List[synr.ast.Node]] = []
"""List[List[synr.ast.Node]]: The ast nodes insides the current scope"""
block_info_stack: List[BlockInfo] = []
"""List[BlockInfo]: The block info for the current block scope"""
loop_stack: List[List[Var]] = []
"""List[List[Var]]: List of loop vars inside the current block scope"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""

# function context
func_params: List[Var] = []
"""List[Var]: The function parameters"""
func_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map"""
func_dict_attr: Mapping[str, Object] = {}
"""Mapping[str, Object]: The function attrs"""
func_var_env_dict: Mapping[Var, str] = {}
"""Mapping[Var, str]: The map from var to env thread"""

# parser and analyzer
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
"""tvm.arith.Analyzer: The analyzer for simplifying"""
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
"""Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""

def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
# scope context
self.node_stack = [] # AST nodes of scopes
self.symbols = [] # symbols of scopes
self.node_stack = []
self.block_info_stack = []
self.loop_stack = []
self.symbols = []
# function context
self.func_params = [] # parameter list of function
self.func_buffer_map = {} # buffer_map of function
self.func_dict_attr = {} # func_attr of function
self.func_var_env_dict = {} # map from var to env_name
# parser
self.parser = parser

def pop_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()
self.func_params = []
self.func_buffer_map = {}
self.func_dict_attr = {}
self.func_var_env_dict = {}
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()

def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
"""Creates a new scope

def new_scope(self, nodes=None):
"""Creating a new scope"""
Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
if nodes is None:
nodes = []
self.node_stack.append(list(reversed(nodes)))
self.symbols.append(dict())

def update_symbol(self, name, symbol):
def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the difference between a regular and block scope from the user perspective. When should a should enter_scope be used and when should enter_block_scope be used.

"""Creates a new block scope, the function will call `enter_scope` implicitly
Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
to maintain block info.
It should be used when the scope is a block (or likely to be a block)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this only for things like

with tir.block([]):
    ...

?


Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
self.enter_scope(nodes)
# Create a new loop stack for the new block
self.loop_stack.append([])
# Create a new BlockInfo for the new block
self.block_info_stack.append(BlockInfo())

def exit_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()

def exit_block_scope(self):
"""Pop the inner most block scope, the function will call `exit_scope` implicitly"""
self.exit_scope()
# Pop loop stack
self.loop_stack.pop()
# Pop block_info
self.block_info_stack.pop()

def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
"""Append a symbol into current scope"""
if isinstance(symbol, schedule.Buffer):
if isinstance(symbol, Buffer):
if name in self.symbols[0]:
self.parser.report_error("Duplicate Buffer name")
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
self.symbols[0][name] = symbol
else:
self.symbols[-1][name] = symbol

def remove_symbol(self, name):
def remove_symbol(self, name: str):
"""Remove a symbol"""
for symbols in reversed(self.symbols):
if name in symbols:
symbols.pop(name)
return
raise RuntimeError("Internal error of tvm script parser: no symbol named" + name)
raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)

def lookup_symbol(self, name):
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
"""Look up symbol by name"""
for symbols in reversed(self.symbols):
if name in symbols:
return symbols[name]
return None

def report_error(self, message, span):
self.parser.report_error(message, span)
def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
self._report_error(message, span)

def current_block_scope(self) -> BlockInfo:
return self.block_info_stack[-1]
20 changes: 16 additions & 4 deletions python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
from typing import List, Any

import tvm.tir
from .registry import register
from .utils import get_param_list, from_synr_span
from .utils import get_param_list, tvm_span_from_synr


class Intrin:
Expand All @@ -29,8 +31,8 @@ def __init__(self, intrin, stmt=False):
def signature(self):
return "tir." + self.intrin.__name__, get_param_list(self.intrin)

def handle(self, arg_list, span):
return self.intrin(*arg_list, span=from_synr_span(span))
def handle(self, arg_list: List[Any], span: tvm.ir.Span):
return self.intrin(*arg_list, span=tvm_span_from_synr(span))


@register
Expand Down Expand Up @@ -98,6 +100,16 @@ def float64(imm, span):
return tvm.tir.Cast("float64", imm, span)


@register
def min_value(dtype, span):
return tvm.tir.min_value(dtype, span)


@register
def max_value(dtype, span):
return tvm.tir.max_value(dtype, span)


@register
def floordiv(x, y, span):
return tvm.tir.floordiv(x, y, span)
Expand Down Expand Up @@ -145,7 +157,7 @@ def get_axis(begin, end, iter_type, span):
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)

iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span)
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)


@register
Expand Down
Loading