Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][M1a] TVMScript Parser/Printer #7630

Merged
merged 22 commits into from
Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Auto detect the block read/write region according to body stmt
* It will detect the read/write region as an array in order of appearance in AST
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed by the block.
* It is a map from buffer var to the buffer.
* \return Array of access regions.
* There are three arrays of BufferRegion:
* - first: read regions
* - second: write regions
* - third: opaque regions
*/
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
172 changes: 146 additions & 26 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,179 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""

from tvm.te import schedule
from typing import List, Mapping, Union, Optional, Dict, Callable
import synr


import tvm
from tvm.ir import Span
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from .node import BufferSlice


class BlockInfo:
    """Per-block record of the pieces of a block / block_realize signature.

    Every instance starts with empty containers for the allocated buffers,
    matched buffers and iter bindings, and with ``None`` in each field that
    has not been visited yet (reads, writes, annotations, predicate, init).
    """

    alloc_buffers: List[Buffer] = []
    """List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
    match_buffers: List[MatchBufferRegion] = []
    """List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature"""
    iter_bindings: Mapping[Var, PrimExpr] = {}
    """Mapping[Var, PrimExpr]: map of block iter var to its values"""
    reads: Optional[List[BufferSlice]] = None
    """Optional[List[BufferSlice]]:
    list of tir.reads statements in the block signature, None for not-visited"""
    writes: Optional[List[BufferSlice]] = None
    """Optional[List[BufferSlice]]:
    list of tir.writes statements in the block signature, None for not-visited"""
    annotations: Optional[Mapping[str, Object]] = None
    """Optional[Mapping[str, Object]]:
    list of tir.block_attr statements in the block signature, None for not-visited"""
    predicate: Optional[PrimExpr] = None
    """Optional[PrimExpr]: block realize predicate, None for not-visited"""
    init: Optional[Stmt] = None
    """Optional[Stmt]: init part of the block, None for not-visited"""

    def __init__(self):
        # Fresh containers per instance, so nothing is shared through the
        # class-level defaults declared above.
        self.alloc_buffers, self.match_buffers = [], []
        self.iter_bindings = {}
        # None marks a signature field that has not been visited yet.
        self.reads = self.writes = self.annotations = None
        self.predicate = self.init = None


class ContextMaintainer:
"""Maintain all the necessary context info"""
"""Maintain all the necessary context info
Parameters
----------
_report_error : Callable[[str, Union[Span, synr.ast.Span]], None]
The report error function handle
"""

def __init__(self, parser):
# scope context
node_stack: List[List[synr.ast.Node]] = []
"""List[List[synr.ast.Node]]: The ast nodes inside the current scope"""
block_info_stack: List[BlockInfo] = []
"""List[BlockInfo]: The block info for the current block scope"""
loop_stack: List[List[Var]] = []
"""List[List[Var]]: List of loop vars inside the current block scope"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""

# function context
func_params: List[Var] = []
"""List[Var]: The function parameters"""
func_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map"""
func_dict_attr: Mapping[str, Object] = {}
"""Mapping[str, Object]: The function attrs"""
func_var_env_dict: Mapping[Var, str] = {}
"""Mapping[Var, str]: The map from var to env thread"""

# parser and analyzer
analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer()
"""tvm.arith.Analyzer: The analyzer for simplifying"""
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
"""Callable[[str, Union[Span, synr.ast.Span]], None]: The report error function handle"""

def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
# scope context
self.node_stack = [] # AST nodes of scopes
self.symbols = [] # symbols of scopes
self.node_stack = []
self.block_info_stack = []
self.loop_stack = []
self.symbols = []
# function context
self.func_params = [] # parameter list of function
self.func_buffer_map = {} # buffer_map of function
self.func_dict_attr = {} # func_attr of function
self.func_var_env_dict = {} # map from var to env_name
# parser
self.parser = parser

def pop_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()
self.func_params = []
self.func_buffer_map = {}
self.func_dict_attr = {}
self.func_var_env_dict = {}
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()

def new_scope(self, nodes=None):
"""Creating a new scope"""
def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
    """Open a new regular scope.

    The given nodes are pushed onto ``node_stack`` in reversed order and an
    empty symbol table is pushed onto ``symbols``.

    Note
    ----
    This function is used for normal scopes that do not involve
    a `with block` scope. Use `enter_block_scope`
    for block scope cases.

    Parameters
    ----------
    nodes : Optional[List[synr.ast.Node]]
        The synr AST nodes in new scope; ``None`` means no nodes.
    """
    scope_nodes = list(reversed([] if nodes is None else nodes))
    self.node_stack.append(scope_nodes)
    self.symbols.append({})

def update_symbol(self, name, symbol):
def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the difference between a regular scope and a block scope from the user's perspective. When should enter_scope be used, and when should enter_block_scope be used?

"""Creates a new block scope, the function will call `enter_scope` implicitly
Besides the behaviors of `enter_scope`, it will update loop_stack and block_info_stack
to maintain block info.

Note
----
This function should be used to handle a block scope,
aka the blocks that involve a `with block` scope.
Comment on lines +177 to +178
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still confusing. What is a block? How does it differ from a regular scope?

Copy link
Member

@tqchen tqchen Mar 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The block structure is clearly described in the RFC https://discuss.tvm.apache.org/t/rfc-tensorir-a-schedulable-ir-for-tvm/7872

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you put that in the codebase somewhere then? It'll be hard for people to understand if they have to go to Discuss to get the full docs.

Copy link
Member

@tqchen tqchen Mar 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There will be docs about the TensorIR TVMScript language, but that should come as a separate PR. Additionally, this PR already contains test cases that cover the cases needed. As in our previous parser code, there is less of a description of the language itself.

While I agree some examples would be helpful, they may not be necessary, assuming the maintainers have a good understanding of the Block structure itself and of the future docs of the language.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with Tristan on this. I think that assuming that a maintainer has a good understanding of block structure already is not a good assumption to make. Having examples in the codebase makes code easily understandable and accessible to anyone who wants to read it, not just people who are familiar with the code. Since the project is rapidly growing and getting new contributors, it's important to make code understandable and accessible. Scaling the number of developers in TVM isn't sustainable without good documentation -- and good documentation includes having good comments. Ideally, the comments and the formal documentation would even be a little bit redundant.

Copy link
Member

@tqchen tqchen Mar 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @electriclilies. I agree with all you said, in particular with respect to code readability. We already followed the principle of enforcing heavy documentation in the case of user-facing code and making sure the overall logic flows well.

Code readability also goes beyond the comments; a lot of effort needs to be spent on API naming, intuitive calling conventions, and error handling. This PR does a lot of that, for example:

  • E0: Clear API naming: see convert_to_int function in the diff
  • E1: Good error handling: call_with_error_reporting that allows accurate error reporting in most of the cases.
  • E2: Type annotations: adding type annotations in a lot of places (while our current convention does not require it) to increase readability
  • E3: Documentation of most functions

There is of course a tradeoff between the time we spend on the amount of comments to be added and the time we spend on other efforts. On one hand we certainly want to add as many comments as possible. On the other hand, commenting every code block may not be the best way of investing time -- we could spend more time on overall architectural correctness, the scaffolds (APIs and components) and other elements that make the code more readable and maintainable.

Coming back to a related example (e.g. reviewing the quantization code): it is certainly helpful to add examples about the network patterns that occur during the quantization process, the values involved, and so on. But that may not be the most important thing for now, since we can focus on more important issues of readability and maintainability -- e.g. clarifying the key APIs, making sure they compose well, and so on. Examples can then be added to places that contain subtle logic to help clarify things.

Right now we are prioritizing adding examples to developer code paths that are more subtle, like the arithmetic modules and so on. In this particular case, an example suggested and checked in as a follow-up PR would be more than welcome.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think any of us disagree with the extreme importance of comprehensive docs, but we need to think a bit deeper about what kind of docs really help.

I would like to further specifically reiterate my understanding of categorization of helpful docs:

  • D1. User-facing. This helps a user, who is possibly not familiar with the implementation, to understand how to use some functionalities of TVM;
  • D2. Design doc. This provides the high-level overview of how the codebase is structured and why it is designed like this;
  • D3. Private API doc. Document the private APIs that are not directly user-facing.

D1 directly helps a user to better understand how to make things work and how it works. It is clear that D1 is desirable.

D2 helps developers to understand the design philosophy, how the codebase is structured, etc., so that more people can help better maintain the codebase. Without D2, people are unable to understand the key concepts - that is why Tianqi redirects us to the RFC, so that we can all understand what a "Block" is, etc. Without D2, no matter how many words developers use to document private APIs, it is still painful to understand some data structures.

On D3, we always insist that APIs be at least somewhat documented, so that maintainers can get a brief sense of what will happen when an API is called. Assuming we have good D1 and D2, we could substantially lower the steep learning curve and make understanding D3 much easier - the prerequisite is that maintainers should read D1 and D2 first; in our specific case, the RFC and related materials.

I think everybody totally agrees with Lily's words, and trust me, nobody wants to bring trouble to future maintainers :-) With D1 and D2 in place (after M1s and M2s are merged), it will be much easier to understand the design philosophy. This is what we are doing in Ansor upstreaming too - we upstream many tutorials after the codebase is fully functioning.

It is totally understandable that reviewers feel frustrated when they do not understand the design philosophy, and that is what RFCs are for, aren't they :-) Next time, we would love to see such frustration converted into clear questions and answers. For design philosophy-related questions, I would love to redirect everybody to the RFC from the very beginning, and we shouldn't debate outside the RFC about topics like "why two scopes".

Then should we include the RFC text into the inlined docs? I think it is debatable, and I prefer not to replicate. The basic reason is that we will have D1 and D2 in the end, and maintaining two copies is quite error-prone. An easier solution for maintenance is to add links to D1 and D2 on critical data structures.


Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
self.enter_scope(nodes)
# Create a new loop stack for the new block
self.loop_stack.append([])
# Create a new BlockInfo for the new block
self.block_info_stack.append(BlockInfo())

def exit_scope(self):
    """Leave the innermost scope.

    Drops the top entry of ``symbols`` and then the top entry of
    ``node_stack``.
    """
    # Order matters only for which stack fails first when both are empty.
    for stack in (self.symbols, self.node_stack):
        stack.pop()

def exit_block_scope(self):
    """Leave the innermost block scope.

    Calls `exit_scope` first, then drops the top entry of ``loop_stack``
    and the top entry of ``block_info_stack``.
    """
    self.exit_scope()
    # Discard the loop-var list and the block info of the block being left.
    for stack in (self.loop_stack, self.block_info_stack):
        stack.pop()

def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
"""Append a symbol into current scope"""
if isinstance(symbol, schedule.Buffer):
if isinstance(symbol, Buffer):
if name in self.symbols[0]:
self.parser.report_error("Duplicate Buffer name")
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
self.symbols[0][name] = symbol
else:
self.symbols[-1][name] = symbol

def remove_symbol(self, name):
def remove_symbol(self, name: str):
"""Remove a symbol"""
for symbols in reversed(self.symbols):
if name in symbols:
symbols.pop(name)
return
raise RuntimeError("Internal error of tvm script parser: no symbol named" + name)
raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)

def lookup_symbol(self, name):
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
    """Find the object bound to ``name``.

    The symbol tables in ``symbols`` are searched from the last one pushed
    back to the first; the first table containing ``name`` wins. Returns
    ``None`` when no table contains it.
    """
    matches = (table[name] for table in reversed(self.symbols) if name in table)
    return next(matches, None)

def report_error(self, message, span):
self.parser.report_error(message, span)
def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
    """Forward ``message`` and ``span`` to the stored `_report_error` callable."""
    self._report_error(message, span)

def current_block_scope(self) -> BlockInfo:
    """Return the last entry of ``block_info_stack``, i.e. the innermost block's info."""
    return self.block_info_stack[-1]
20 changes: 16 additions & 4 deletions python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
from typing import List, Any

import tvm.tir
from .registry import register
from .utils import get_param_list, from_synr_span
from .utils import get_param_list, tvm_span_from_synr


class Intrin:
Expand All @@ -29,8 +31,8 @@ def __init__(self, intrin, stmt=False):
def signature(self):
return "tir." + self.intrin.__name__, get_param_list(self.intrin)

def handle(self, arg_list, span):
return self.intrin(*arg_list, span=from_synr_span(span))
def handle(self, arg_list: List[Any], span: tvm.ir.Span):
return self.intrin(*arg_list, span=tvm_span_from_synr(span))


@register
Expand Down Expand Up @@ -98,6 +100,16 @@ def float64(imm, span):
return tvm.tir.Cast("float64", imm, span)


@register
def min_value(dtype, span):
    """Registered intrinsic that forwards ``dtype`` and ``span`` to `tvm.tir.min_value`."""
    return tvm.tir.min_value(dtype, span)


@register
def max_value(dtype, span):
    """Registered intrinsic that forwards ``dtype`` and ``span`` to `tvm.tir.max_value`."""
    return tvm.tir.max_value(dtype, span)


@register
def floordiv(x, y, span):
return tvm.tir.floordiv(x, y, span)
Expand Down Expand Up @@ -145,7 +157,7 @@ def get_axis(begin, end, iter_type, span):
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)

iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span)
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)


@register
Expand Down
Loading