Skip to content

Commit

Permalink
[TensorIR] TVMScript Parser/Printer (#317)
Browse files Browse the repository at this point in the history
[TensorIR] TVMScript Parser/Printer

Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
7 people committed Mar 10, 2021
1 parent 3a0e3a5 commit 99a62d1
Show file tree
Hide file tree
Showing 17 changed files with 2,256 additions and 194 deletions.
15 changes: 15 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Auto detect the block read/write region according to body stmt
* It will detect the read/write region as an array in order of appearance in AST
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed the block.
* It is a map from buffer var to the buffer.
* \return Array of access regions.
* There are three arrays of BufferRegion:
* - first: read regions
* - second: write regions
* - third: opaque regions
*/
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
Expand Down
115 changes: 97 additions & 18 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,138 @@
# under the License.
"""TVM Script Context Maintainer for TIR"""

from tvm.te import schedule
from typing import List, Mapping, Union, Optional, Dict, Callable
import synr


import tvm
from tvm.ir import Span
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from .node import BufferSlice


class BlockInfo:
"""Information for block and block_realize signature"""

alloc_buffers: List[Buffer]
match_buffers: List[MatchBufferRegion]
iter_bindings: Mapping[Var, PrimExpr]
reads: Optional[List[BufferSlice]]
writes: Optional[List[BufferSlice]]
annotations: Optional[Mapping[str, Object]]
predicate: Optional[PrimExpr]
init: Optional[Stmt]

def __init__(self):
self.alloc_buffers = []
self.match_buffers = []
self.iter_bindings = {}
self.reads = None
self.writes = None
self.annotations = None
self.predicate = None
self.init = None


class ContextMaintainer:
"""Maintain all the necessary context info"""

def __init__(self, parser):
# scope context
node_stack: List[List[synr.ast.Node]]
block_info_stack: List[BlockInfo]
loop_stack: List[List[Var]]
symbols: List[Dict[str, Union[Var, Buffer]]]
# function context
func_params: List[Var]
func_buffer_map: Mapping[Var, Buffer]
func_dict_attr: Mapping[str, Object]
func_var_env_dict: Mapping[Var, str]
# parser and analyzer
_report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
analyzer: tvm.arith.Analyzer

def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], None]):
# scope context
self.node_stack = [] # AST nodes of scopes
self.block_info_stack = [] # Block info of scopes
self.loop_stack = [] # stack of loop vars
self.symbols = [] # symbols of scopes
# function context
self.func_params = [] # parameter list of function
self.func_buffer_map = {} # buffer_map of function
self.func_dict_attr = {} # func_attr of function
self.func_var_env_dict = {} # map from var to env_name
# parser
self.parser = parser
# parser and analyzer
self._report_error = _report_error
self.analyzer = tvm.arith.Analyzer()

def pop_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()
def enter_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
"""Creating a new scope
def new_scope(self, nodes=None):
"""Creating a new scope"""
Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
if nodes is None:
nodes = []
self.node_stack.append(list(reversed(nodes)))
self.symbols.append(dict())

def update_symbol(self, name, symbol):
def enter_block_scope(self, nodes: Optional[List[synr.ast.Node]] = None):
"""Creating a new block scope, the function will call `enter_scope` implicitly
Parameters
----------
nodes : Optional[List[synr.ast.Node]]
The synr AST nodes in new scope
"""
self.enter_scope(nodes)
# Create a new loop stack for the new block
self.loop_stack.append([])
# Create a new BlockInfo for the new block
self.block_info_stack.append(BlockInfo())

def exit_scope(self):
"""Pop the inner most scope"""
self.symbols.pop()
self.node_stack.pop()

def exit_block_scope(self):
"""Pop the inner most block scope, the function will call `exit_scope` implicitly"""
self.exit_scope()
# Pop loop stack
self.loop_stack.pop()
# Pop block_info
self.block_info_stack.pop()

def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
"""Append a symbol into current scope"""
if isinstance(symbol, schedule.Buffer):
if isinstance(symbol, Buffer):
if name in self.symbols[0]:
self.parser.report_error("Duplicate Buffer name")
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
self.symbols[0][name] = symbol
else:
self.symbols[-1][name] = symbol

def remove_symbol(self, name):
def remove_symbol(self, name: str):
"""Remove a symbol"""
for symbols in reversed(self.symbols):
if name in symbols:
symbols.pop(name)
return
raise RuntimeError("Internal error of tvm script parser: no symbol named" + name)
raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)

def lookup_symbol(self, name):
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
"""Look up symbol by name"""
for symbols in reversed(self.symbols):
if name in symbols:
return symbols[name]
return None

def report_error(self, message, span):
self.parser.report_error(message, span)
def report_error(self, message: str, span: Union[Span, synr.ast.Span]):
self._report_error(message, span)

def current_block_scope(self) -> BlockInfo:
return self.block_info_stack[-1]
20 changes: 16 additions & 4 deletions python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
"""TVM Script Parser Intrinsic Classes"""
# pylint: disable=redefined-builtin, relative-beyond-top-level
from typing import List, Any

import tvm.tir
from .registry import register
from .utils import get_param_list, from_synr_span
from .utils import get_param_list, tvm_span_from_synr


class Intrin:
Expand All @@ -29,8 +31,8 @@ def __init__(self, intrin, stmt=False):
def signature(self):
return "tir." + self.intrin.__name__, get_param_list(self.intrin)

def handle(self, arg_list, span):
return self.intrin(*arg_list, span=from_synr_span(span))
def handle(self, arg_list: List[Any], span: tvm.ir.Span):
return self.intrin(*arg_list, span=tvm_span_from_synr(span))


@register
Expand Down Expand Up @@ -98,6 +100,16 @@ def float64(imm, span):
return tvm.tir.Cast("float64", imm, span)


@register
def min_value(dtype, span):
return tvm.tir.min_value(dtype, span)


@register
def max_value(dtype, span):
return tvm.tir.max_value(dtype, span)


@register
def floordiv(x, y, span):
return tvm.tir.floordiv(x, y, span)
Expand Down Expand Up @@ -145,7 +157,7 @@ def get_axis(begin, end, iter_type, span):
block_var_dom = tvm.ir.Range.from_min_extent(begin, extent)

iter_type_dict = {"data_par": 0, "reduce": 2, "scan": 3, "opaque": 4}
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span)
return tvm.tir.IterVar(block_var_dom, "bv", iter_type_dict[iter_type], span=span)


@register
Expand Down
150 changes: 150 additions & 0 deletions python/tvm/script/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-builtin
"""TVM Script nodes."""

from typing import Optional, Union, List, Callable
import synr

from tvm.runtime import ObjectGeneric
from tvm.tir import PrimExpr, Buffer, BufferLoad
from tvm.ir import Span


class Slice:
"""A helper class to present slice information for BufferSlice
Parameters
----------
start : Union[PrimExpr, int]
The start index.
stop : Optional[Union[PrimExpr, int]]
The stop index, None means the Slice is an element-wise index
span : Optional[Span]
The location of the slice in the source.
"""

start: Union[PrimExpr, int]
stop: Optional[Union[PrimExpr, int]]
span: Optional[Span]

def __init__(
self,
start: Union[PrimExpr, int],
stop: Optional[Union[PrimExpr, int]] = None,
span: Optional[Span] = None,
):
self.start = start
self.stop = stop
self.span = span


class BufferSlice(ObjectGeneric):
"""A generic object for representing general buffer access. Following cases are supported:
- element wise access buffer[i, j], which can be convert to BufferLoad if necessary
- slice access buffer[i: i + 1, j : j + 2]
- union of element and slice buffer[i, j: j + 2]
This node is used in TVMScript to parse BufferLoad, BufferRegion and Realize
Parameters
----------
buffer : Buffer
The buffer.
indices : List[Union[Slice, PrimExpr, int]]
The access indexes can be slice, PrimExpr or int.
report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
The error report func
span : Optional[Span]
The location of the buffer access in the source.
"""

buffer: Buffer
slices: List[Slice]
report_error: Callable[[str, Union[Span, synr.ast.Span]], None]
span: Optional[Span]

def __init__(
self,
buffer: Buffer,
indices: List[Union[Slice, PrimExpr, int]],
report_error: Callable[[str, Union[Span, synr.ast.Span]], None],
span: Optional[Span] = None,
):
def check_index(index: Union[int, PrimExpr]):
""" Check input index is non-negative integer or PrimExpr"""
if isinstance(index, int):
if index < 0:
report_error("Negative index is not allowed during buffer access", span)
elif isinstance(index, PrimExpr):
if index.dtype != "int32":
report_error(
"index expects an int32 type PrimExpr but gets " + str(index.dtype),
index.span,
)
else:
report_error(
"Unsupported index type, expects int or tvm.tir.PrimExpr, but gets "
+ str(type(index)),
span,
)

slices: List[Slice] = []
for index in indices:
if isinstance(index, Slice):
check_index(index.start)
check_index(index.stop)
slices.append(index)
elif isinstance(index, (PrimExpr, int)):
check_index(index)
slices.append(Slice(index))
else:
report_error(
"Unsupported index type for BufferSlice, "
+ "expects int, tvm.tir.PrimExpr, tvm.tir.Slice, but gets "
+ str(type(index)),
span,
)

self.buffer = buffer
self.slices = slices
self.report_error = report_error
self.span = span

def __str__(self):
regions: List[str] = []
for s in self.slices:
if s.stop is None:
regions.append(str(s.start))
else:
regions.append(str(s.start) + ": " + str(s.stop))

return self.buffer.name + "[" + ", ".join(regions) + "]"

def asobject(self) -> BufferLoad:
"""Convert object."""
for s in self.slices:
if s.stop is not None:
self.report_error("BufferLoad only accepts elementwise access", self.span)

indices = [s.start for s in self.slices]
return BufferLoad(self.buffer, indices, span=self.span)
Loading

0 comments on commit 99a62d1

Please sign in to comment.