Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][M1a] TVMScript Parser/Printer #7630

Merged
merged 22 commits into from
Mar 20, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Auto detect the block read/write region according to body stmt
* It will detect the read/write region as an array in order of appearance in AST
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed the block.
* It is a map from buffer var to the buffer.
* \return Array of access regions.
* There are three arrays of BufferRegion:
* - first: read regions
* - second: write regions
* - third: opaque regions
*/
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
122 changes: 104 additions & 18 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,145 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""

from tvm.te import schedule
from typing import List, Mapping, Union, Optional, Dict, Callable
import synr


import tvm
from tvm.ir import Span
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from .node import BufferSlice


class BlockInfo:
"""Information for block and block_realize signature"""

alloc_buffers: List[Buffer]
match_buffers: List[MatchBufferRegion]
iter_bindings: Mapping[Var, PrimExpr]
reads: Optional[List[BufferSlice]]
writes: Optional[List[BufferSlice]]
annotations: Optional[Mapping[str, Object]]
predicate: Optional[PrimExpr]
init: Optional[Stmt]
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self):
self.alloc_buffers = []
self.match_buffers = []
self.iter_bindings = {}
self.reads = None
self.writes = None
self.annotations = None
self.predicate = None
self.init = None


class ContextMaintainer:
"""Maintain all the necessary context info"""

def __init__(self, parser):
# scope context
# ast_node inside a scope
node_stack: List[List[synr.ast.Node]]
# loop stacks inside a block
block_info_stack: List[BlockInfo]
# loop stacks inside a block
loop_stack: List[List[Var]]
symbols: List[Dict[str, Union[Var, Buffer]]]

# function context
func_params: List[Var]
func_buffer_map: Mapping[Var, Buffer]
func_dict_attr: Mapping[str, Object]
func_var_env_dict: Mapping[Var, str]

# parser and analyzer
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
analyzer: tvm.arith.Analyzer

def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
# scope context
self.node_stack = [] # AST nodes of scopes
self.block_info_stack = [] # Block info of scopes
self.loop_stack = [] # stack of loop vars
self.symbols = [] # symbols of scopes
# function context
self.func_params = [] # parameter list of function
self.func_buffer_map = {} # buffer_map of function
self.func_dict_attr = {} # func_attr of function
self.func_var_env_dict = {} # map from var to env_name
# parser
self.parser = parser
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()

def pop_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()
def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
"""Creating a new scope
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved

def new_scope(self, nodes=None):
"""Creating a new scope"""
Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
if nodes is None:
nodes = []
self.node_stack.append(list(reversed(nodes)))
self.symbols.append(dict())

def update_symbol(self, name, symbol):
def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the difference between a regular and block scope from the user perspective. When should a should enter_scope be used and when should enter_block_scope be used.

"""Creating a new block scope, the function will call `enter_scope` implicitly
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
Besides behaviors of normal `enter_scope`, it will update loop_stack and block_info_stack
for block info maintaining.
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
self.enter_scope(nodes)
# Create a new loop stack for the new block
self.loop_stack.append([])
# Create a new BlockInfo for the new block
self.block_info_stack.append(BlockInfo())

def exit_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()

def exit_block_scope(self):
"""Pop the inner most block scope, the function will call `exit_scope` implicitly"""
self.exit_scope()
# Pop loop stack
self.loop_stack.pop()
# Pop block_info
self.block_info_stack.pop()

def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
"""Append a symbol into current scope"""
if isinstance(symbol, schedule.Buffer):
if isinstance(symbol, Buffer):
if name in self.symbols[0]:
self.parser.report_error("Duplicate Buffer name")
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
self.symbols[0][name] = symbol
else:
self.symbols[-1][name] = symbol

def remove_symbol(self, name):
def remove_symbol(self, name: str):
"""Remove a symbol"""
for symbols in reversed(self.symbols):
if name in symbols:
symbols.pop(name)
return
raise RuntimeError("Internal error of tvm script parser: no symbol named" + name)
raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)

def lookup_symbol(self, name):
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
"""Look up symbol by name"""
for symbols in reversed(self.symbols):
if name in symbols:
return symbols[name]
return None

def report_error(self, message, span):
self.parser.report_error(message, span)
def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
self._report_error(message, span)

def current_block_scope(self) -> BlockInfo:
return self.block_info_stack[-1]
20 changes: 16 additions & 4 deletions python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
from typing import List, Any

import tvm.tir
from .registry import register
from .utils import get_param_list, from_synr_span
from .utils import get_param_list, tvm_span_from_synr


class Intrin:
Expand All @@ -29,8 +31,8 @@ def __init__(self, intrin, stmt=False):
def signature(self):
return "tir." + self.intrin.__name__, get_param_list(self.intrin)

def handle(self, arg_list, span):
return self.intrin(*arg_list, span=from_synr_span(span))
def handle(self, arg_list: List[Any], span: tvm.ir.Span):
return self.intrin(*arg_list, span=tvm_span_from_synr(span))


@register
Expand Down Expand Up @@ -98,6 +100,16 @@ def float64(imm, span):
return tvm.tir.Cast("float64", imm, span)


@register
def min_value(dtype, span):
return tvm.tir.min_value(dtype, span)


@register
def max_value(dtype, span):
return tvm.tir.max_value(dtype, span)


@register
def floordiv(x, y, span):
return tvm.tir.floordiv(x, y, span)
Expand Down Expand Up @@ -145,7 +157,7 @@ def get_axis(begin, end, iter_type, span):
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)

iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span)
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)


@register
Expand Down
150 changes: 150 additions & 0 deletions python/tvm/script/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
"""TVM Script nodes."""

from typing import Optional, Union, List, Callable
import synr

from tvm.runtime import ObjectGeneric
from tvm.tir import PrimExpr, Buffer, BufferLoad
from tvm.ir import Span


class Slice:
"""A helper class to present slice information for BufferSlice

Parameters
----------
start : Union[PrimExpr, int]
The start index.

stop : Optional[Union[PrimExpr, int]]
The stop index, None means the Slice is an element-wise index

span : Optional[Span]
The location of the slice in the source.
"""

start: Union[PrimExpr, int]
stop: Optional[Union[PrimExpr, int]]
span: Optional[Span]

def __init__(
self,
start: Union[PrimExpr, int],
stop: Optional[Union[PrimExpr, int]] = None,
span: Optional[Span] = None,
):
self.start = start
self.stop = stop
self.span = span


class BufferSlice(ObjectGeneric):
"""A generic object for representing general buffer access. Following cases are supported:
- element wise access buffer[i, j], which can be convert to BufferLoad if necessary
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
- slice access buffer[i: i + 1, j : j + 2]
- union of element and slice buffer[i, j: j + 2]

This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize

Parameters
----------
buffer : Buffer
The buffer.

indices : List[Union[Slice, PrimExpr, int]]
The access indexes can be slice, PrimExpr or int.

report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
The error report func

span : Optional[Span]
The location of the buffer access in the source.
"""

buffer: Buffer
slices: List[Slice]
report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
span: Optional[Span]

def __init__(
self,
buffer: Buffer,
indices: List[Union[Slice, PrimExpr, int]],
report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
span: Optional[Span] = None,
):
def check_index(index: Union[int, PrimExpr]):
""" Check input index is non-negative integer or PrimExpr"""
if isinstance(index, int):
if index < 0:
report_error("Negative index is not allowed during buffer access", span)
elif isinstance(index, PrimExpr):
if index.dtype != "int32":
report_error(
"index expected an int32 type PrimExpr but got " + str(index.dtype),
index.span,
)
else:
report_error(
"Unsupported index type, expected int or tvm.tir.PrimExpr, but got "
+ str(type(index)),
span,
)

slices: List[Slice] = []
for index in indices:
if isinstance(index, Slice):
check_index(index.start)
check_index(index.stop)
slices.append(index)
elif isinstance(index, (PrimExpr, int)):
check_index(index)
slices.append(Slice(index))
else:
report_error(
"Unsupported index type for BufferSlice, "
+ "expected int, tvm.tir.PrimExpr, tvm.tir.Slice, but got "
+ str(type(index)),
span,
)

self.buffer = buffer
self.slices = slices
self.report_error = report_error
self.span = span

def __str__(self):
regions: List[str] = []
for s in self.slices:
if s.stop is None:
regions.append(str(s.start))
else:
regions.append(str(s.start) + ": " + str(s.stop))

return self.buffer.name + "[" + ", ".join(regions) + "]"

def asobject(self) -> BufferLoad:
"""Convert object."""
for s in self.slices:
if s.stop is not None:
self.report_error("BufferLoad only accepts elementwise access", self.span)

indices = [s.start for s in self.slices]
return BufferLoad(self.buffer, indices, span=self.span)
Loading