Skip to content

Commit

Permalink
[TVMScript] Script namespace changes (#9115)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
Co-authored-by: Tristan Konolige <tristan.konolige@gmail.com>
  • Loading branch information
4 people authored Oct 1, 2021
1 parent 4f6b2a1 commit e7af601
Show file tree
Hide file tree
Showing 80 changed files with 6,021 additions and 5,899 deletions.
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
synr==0.4.0 \
synr==0.4.1 \
six \
tornado
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_sphinx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ pip3 install \
matplotlib \
sphinx \
sphinx_autodoc_annotation \
sphinx-gallery==0.4.0 \
sphinx-gallery==0.4.1 \
sphinx_rtd_theme
20 changes: 10 additions & 10 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,12 +195,12 @@ class LinkedParam : public ObjectRef {
* \note We can define a Meta TIR function with symbolic shape:
*
* \code
* @tvm.script.tir
* def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
* A = tir.match_buffer(a, (m, n), "float32")
* B = tir.match_buffer(b, (m, n), "float32")
* @T.prim_func
* def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
* A = T.match_buffer(a, (m, n), "float32")
* B = T.match_buffer(b, (m, n), "float32")
*
* with tir.block([m, n], "") as [vi, vj]:
* with T.block([m, n], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*
Expand All @@ -214,12 +214,12 @@ class LinkedParam : public ObjectRef {
* \endcode
*
* \code {.language-id}
* @tvm.script.tir
* def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
* A = tir.match_buffer(a, (16, 16), "float32")
* B = tir.match_buffer(b, (16, 16), "float32")
* @T.prim_func
* def mem_copy_16_16(a: T.handle, b: T.handle) -> None:
* A = T.match_buffer(a, (16, 16), "float32")
* B = T.match_buffer(b, (16, 16), "float32")
*
* with tir.block([16, 16], "") as [vi, vj]:
* with T.block([16, 16], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*/
Expand Down
20 changes: 10 additions & 10 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1068,17 +1068,17 @@ class MatchBufferRegion : public ObjectRef {
* \note Block's body is parameterized by iter vars.
* \code
*
* with tir.block([extent0, extent1, ...], name) as [v0, v1, ...]:
* tir.bind(v0, value0)
* tir.bind(v1, value1)
* with T.block([extent0, extent1, ...], name) as [v0, v1, ...]:
* T.bind(v0, value0)
* T.bind(v1, value1)
* ...
* tir.reads([buffer0[start:end, ...], ...])
* tir.writes([buffer1[start:end, ...], ...])
* tir.where(predicate)
* buffer2 = tir.alloc_buffer(shape, dtype)
* buffer3 = tir.match_buffer(source_buffer[start:end, ...])
* tir.attr({attr_key: attr_value, ...})
* with tir.init():
* T.reads([buffer0[start:end, ...], ...])
* T.writes([buffer1[start:end, ...], ...])
* T.where(predicate)
* buffer2 = T.alloc_buffer(shape, dtype)
* buffer3 = T.match_buffer(source_buffer[start:end, ...])
* T.attr({attr_key: attr_value, ...})
* with T.init():
* // init body
* // body
*
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ TVM_DLL Pass ConvertBlocksToOpaque();
* \code
*
* for i in range(0, 16):
* with tir.block([]):
* B = tir.alloc_buffer(16, 16)
* with T.block([]):
* B = T.alloc_buffer(16, 16)
* for j in range(0, 16):
* B[i, j] = A[i, j] + 1
* for j in range(0, 16):
Expand All @@ -404,8 +404,8 @@ TVM_DLL Pass ConvertBlocksToOpaque();
* \code
*
* for i in range(0, 16):
* with tir.block([]):
* B = tir.alloc_buffer(1, 16)
* with T.block([]):
* B = T.alloc_buffer(1, 16)
* for j in range(0, 16):
* B[0, j] = A[i, j] + 1
* for j in range(0, 16):
Expand Down
2 changes: 1 addition & 1 deletion python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
("synr", "==0.4.0"),
("synr", "==0.4.1"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,23 @@ def __str__(self):

def __repr__(self):
return self.astext()

def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
"""Print IRModule into TVMScript
Parameters
----------
tir_prefix : str
The tir namespace prefix
show_meta : bool
Whether to show meta information
Returns
-------
script : str
The TVM Script of the IRModule
"""
return tvm._ffi.get_global_func("script.AsTVMScript")(
self, tir_prefix, show_meta
) # type: ignore
4 changes: 3 additions & 1 deletion python/tvm/script/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@
# under the License.
"""TVM Script APIs of TVM Python Package, aimed to support TIR"""

from .parser import from_source, create_module, asscript, tir, module
from . import tir

from .parser import ir_module, from_source
48 changes: 24 additions & 24 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm.ir import Span
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from .node import BufferSlice
from .tir.node import BufferSlice


class BlockInfo:
Expand All @@ -34,55 +34,55 @@ class BlockInfo:
----------
.. code-block:: python
@tvm.script.tir
def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
B = tir.match_buffer(b, (16, 16), "float32")
C = tir.match_buffer(a, (16, 16), "float32")
@T.prim_func
def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (16, 16), "float32")
B = T.match_buffer(b, (16, 16), "float32")
C = T.match_buffer(a, (16, 16), "float32")
for i, j, k in tir.grid(16, 16, 16):
with tir.block([16, 16, tir.reduce_axis(16)], "matmul") as [vi, vj, vk]:
tir.bind(vi, i)
tir.bind(vj, j)
tir.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k}
for i, j, k in T.grid(16, 16, 16):
with T.block([16, 16, T.reduce_axis(16)], "matmul") as [vi, vj, vk]:
T.bind(vi, i)
T.bind(vj, j)
T.bind(vk, k) # iter_bindings = {vj: i, vj: j, vk: k}
tir.where(True) # predicate of the block_realize
T.where(True) # predicate of the block_realize
tir.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block
tir.writes(C[0: 16, 0: 16]) # writes region of the block
tir.block_attr({"attr_key": "attr_value"}) # block annotations
T.reads(A[0:16, 0:16], B[0: 16, 0: 16]) # reads region of the block
T.writes(C[0: 16, 0: 16]) # writes region of the block
T.block_attr({"attr_key": "attr_value"}) # block annotations
# alloc_buffers inside the block
CC = tir.alloc_buffer((1, 1), dtype="float32")
CC = T.alloc_buffer((1, 1), dtype="float32")
# match_buffers of the block,
# which bind a sub-region of source buffer into a new buffer
D = tir.match_buffer(C[vi, vj], ())
D = T.match_buffer(C[vi, vj], ())
# init part of the block, executed when all reduce axes are the beginning value
with tir.init():
C[vi, vj] = tir.float32(0)
with T.init():
C[vi, vj] = T.float32(0)
# block body
CC[0, 0] = A[vi, vk] * B[vj, vk]
D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
"""

alloc_buffers: List[Buffer] = []
"""List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature"""
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
reads: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]:
list of tir.reads statements in the block signature, None for not-visited"""
list of T.reads statements in the block signature, None for not-visited"""
writes: Optional[List[BufferSlice]] = None
"""Optional[List[BufferSlice]]:
list of tir.writes statements in the block signature, None for not-visited"""
list of T.writes statements in the block signature, None for not-visited"""
annotations: Optional[Mapping[str, Object]] = None
"""Optional[Mapping[str, Object]]:
list of tir.block_attr statements in the block signature, None for not-visited"""
list of T.block_attr statements in the block signature, None for not-visited"""
predicate: Optional[PrimExpr] = None
"""Optional[PrimExpr]: block realize predicate, None for not-visited"""
init: Optional[Stmt] = None
Expand Down
Loading

0 comments on commit e7af601

Please sign in to comment.