Skip to content

Commit

Permalink
[TVMScript] IRBuilder methods for Axis (#12808)
Browse files Browse the repository at this point in the history
This PR introduces remaining IRBuilder methods for `Axis`.

Co-authored-by: yongwww <yongcale@gmail.com>
  • Loading branch information
cyx-6 and yongwww authored Sep 16, 2022
1 parent 77d0a28 commit c0d2734
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 1 deletion.
49 changes: 49 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,55 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
*/
BlockFrame Block(String name, bool no_realize = false);

namespace axis {

/*!
* \brief The spatial block axis defining function.
* \param dom The domain of the iteration variable.
* \param binding The binding value of the iteration variable.
* \param dtype The data type of the iteration variable.
* \return The iteration variable.
*/
Var Spatial(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));

/*!
* \brief The reduced block axis defining function.
* \param dom The domain of the iteration variable.
* \param binding The binding value of the iteration variable.
* \param dtype The data type of the iteration variable.
* \return The iteration variable.
*/
Var Reduce(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));

/*!
* \brief The scanning block axis defining function.
* \param dom The domain of the iteration variable.
* \param binding The binding value of the iteration variable.
* \param dtype The data type of the iteration variable.
* \return The iteration variable.
*/
Var Scan(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));

/*!
* \brief The opaque block axis defining function.
* \param dom The domain of the iteration variable.
* \param binding The binding value of the iteration variable.
* \param dtype The data type of the iteration variable.
* \return The iteration variable.
*/
Var Opaque(Range dom, PrimExpr binding, DataType dtype = DataType::Int(32));

/*!
* \brief The block axis remapping function.
* \param kinds The types of the iteration variables.
* \param bindings The binding values of the iteration variables.
* \param dtype The data types of the iteration variables.
* \return The iteration variables.
*/
Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype = DataType::Int(32));

} // namespace axis

/*!
* \brief The serial For statement.
* \param start The minimum value of iteration.
Expand Down
157 changes: 156 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from numbers import Integral
from typing import Any, Dict, List, Optional, Union, Tuple

from tvm.ir import Type
from tvm.ir import Range, Type
from tvm.tir import (
Buffer,
BufferLoad,
Expand Down Expand Up @@ -344,6 +344,160 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore


def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range:
"""The range constructor.
Parameters
----------
dom : Union[Range, List[PrimExpr]]
The domain.
Returns
-------
res : Range
The Range.
"""
if isinstance(dom, Range):
return dom
if isinstance(dom, (list, tuple)):
return Range(dom[0], dom[1])
return Range(0, dom)


class axis: # pylint: disable=invalid-name
@staticmethod
def spatial(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
) -> Var:
"""The spatial block axis defining function.
Parameters
----------
dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
The domain of the iteration variable.
binding : PrimExpr
The binding value of the iteration variable.
dtype : str
The data type of the iteration variable.
Returns
-------
res : Var
The iteration variable.
"""
return _ffi_api.AxisSpatial( # pylint: disable=no-member # type: ignore
_as_range(dom), binding, dtype
)

@staticmethod
def reduce(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
) -> Var:
"""The reduced block axis defining function.
Parameters
----------
dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
The domain of the iteration variable.
binding : PrimExpr
The binding value of the iteration variable.
dtype : str
The data type of the iteration variable.
Returns
-------
res : Var
The iteration variable.
"""
return _ffi_api.AxisReduce( # pylint: disable=no-member # type: ignore
_as_range(dom), binding, dtype
)

@staticmethod
def scan(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
) -> Var:
"""The scanning block axis defining function.
Parameters
----------
dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
The domain of the iteration variable.
binding : PrimExpr
The binding value of the iteration variable.
dtype : str
The data type of the iteration variable.
Returns
-------
res : Var
The iteration variable.
"""
return _ffi_api.AxisScan( # pylint: disable=no-member # type: ignore
_as_range(dom), binding, dtype
)

@staticmethod
def opaque(
dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32"
) -> Var:
"""The opaque block axis defining function.
Parameters
----------
dom : Union[Range, List[PrimExpr], Tuple[PrimExpr]]
The domain of the iteration variable.
binding : PrimExpr
The binding value of the iteration variable.
dtype : str
The data type of the iteration variable.
Returns
-------
res : Var
The iteration variable.
"""
return _ffi_api.AxisOpaque( # pylint: disable=no-member # type: ignore
_as_range(dom), binding, dtype
)

@staticmethod
def remap(kinds: str, bindings: List[PrimExpr], dtype: str = "int32") -> Union[List[Var], Var]:
"""The block axis remapping function.
Parameters
----------
kinds : str
The types of the iteration variables.
bindings : List[PrimExpr]
The binding values of the iteration variables.
dtype : str
The data types of the iteration variables.
Returns
-------
res : Var
The iteration variables.
"""
iter_vars = _ffi_api.AxisRemap( # pylint: disable=no-member # type: ignore
kinds, bindings, dtype
)
return iter_vars[0] if len(iter_vars) == 1 else iter_vars

S = spatial # pylint: disable=invalid-name
R = reduce # pylint: disable=invalid-name


def serial(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
) -> frame.ForFrame:
Expand Down Expand Up @@ -843,6 +997,7 @@ def var(dtype, name="") -> Var:
"match_buffer",
"preflattened_buffer",
"block",
"axis",
"serial",
"parallel",
"vectorized",
Expand Down
86 changes: 86 additions & 0 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,86 @@ BlockFrame Block(String name, bool no_realize) {
return BlockFrame(n);
}

namespace axis {

IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) {
if (Optional<BlockFrame> opt_frame = IRBuilder::Current()->GetLastFrame<BlockFrame>()) {
BlockFrame frame = opt_frame.value();
frame->iter_vars.push_back(iter_var);
frame->iter_values.push_back(binding);
} else {
LOG(FATAL) << "TypeError: The last frame is not BlockFrame";
}
return iter_var;
}

#define TVM_TIR_IR_BUILDER_AXIS(Method, Kind, Name) \
Var Method(Range dom, PrimExpr binding, DataType dtype) { \
ICHECK(dom.defined()) << Name << " axis must have a domain"; \
int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); \
return PushBlockVar(IterVar(/*dom=*/dom, /*var=*/Var("", dtype.with_bits(bits)), \
/*iter_type=*/Kind, /*thread_tag=*/""), \
binding) \
->var; \
}
TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial");
TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction");
TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan");
TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque");
#undef TVM_TIR_IR_BUILDER_AXIS

Array<Var> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
using namespace tvm::tir;
Array<Var> results;
ICHECK_EQ(kinds.size(), bindings.size());
int n = bindings.size();
results.reserve(n);
for (int i = 0; i < n; ++i) {
char c = kinds.c_str()[i];
PrimExpr e = bindings[i];
const VarNode* v = e.as<VarNode>();
ICHECK(v) << "TypeError: Only Var is supported in T.axis.remap";
Range dom{nullptr};
for (const auto& frame : IRBuilder::Current()->frames) {
if (const auto* for_frame = frame.as<ForFrameNode>()) {
ICHECK_EQ(for_frame->doms.size(), for_frame->vars.size());
int n = for_frame->doms.size();
for (int i = 0; i < n; ++i) {
if (for_frame->vars[i].get() == v) {
dom = for_frame->doms[i];
break;
}
}
if (dom.defined()) {
break;
}
}
}
ICHECK(dom.defined()) << "TypeError: Variable is not in the loop: " << GetRef<Var>(v);
DataType dtype = v->dtype;
if (c == 'S') {
results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
/*var=*/Var("", dtype),
/*iter_type=*/IterVarType::kDataPar,
/*thread_tag=*/""),
e)
->var);
} else if (c == 'R') {
results.push_back(PushBlockVar(IterVar(/*dom=*/dom,
/*var=*/Var("", dtype),
/*iter_type=*/IterVarType::kCommReduce,
/*thread_tag=*/""),
e)
->var);
} else {
LOG(FATAL) << "Unknown axis kind: " << c;
}
}
return results;
}

} // namespace axis

#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \
ForFrame Method(PrimExpr start, PrimExpr stop, Optional<Map<String, ObjectRef>> annotations) { \
PrimExpr min = start; \
Expand Down Expand Up @@ -304,6 +384,12 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.PreflattenedBuffer").set_body_typed(P

TVM_REGISTER_GLOBAL("script.ir_builder.tir.Block").set_body_typed(Block);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisSpatial").set_body_typed(axis::Spatial);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisReduce").set_body_typed(axis::Reduce);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisScan").set_body_typed(axis::Scan);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisOpaque").set_body_typed(axis::Opaque);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AxisRemap").set_body_typed(axis::Remap);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.Serial").set_body_typed(Serial);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Parallel").set_body_typed(Parallel);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Vectorized").set_body_typed(Vectorized);
Expand Down
43 changes: 43 additions & 0 deletions tests/python/unittest/test_tvmscript_ir_builder_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,49 @@ def test_ir_builder_tir_block():
assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)


def test_ir_builder_tir_axis():
with IRBuilder() as ib:
a = T.var("int32", "a")
b = T.var("int32", "b")
c = T.var("int32", "c")
d = T.var("int32", "d")
with T.block("block"):
T.axis.spatial(8, a)
T.axis.reduce(16, b)
T.axis.scan(32, c)
T.axis.opaque(64, d)
T.evaluate(0)

# the block generated by IRBuilder
block_realize_actual = ib.get()

# the expected block
var_a = tir.Var("a", "int32")
var_b = tir.Var("b", "int32")
var_c = tir.Var("c", "int32")
var_d = tir.Var("d", "int32")
block_expected = tir.Block(
iter_vars=[
tir.IterVar((0, 8), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar),
tir.IterVar((0, 16), tir.Var("", "int32"), iter_type=tir.IterVar.CommReduce),
tir.IterVar((0, 32), tir.Var("", "int32"), iter_type=tir.IterVar.Ordered),
tir.IterVar((0, 64), tir.Var("", "int32"), iter_type=tir.IterVar.DimInfo),
],
reads=[],
writes=[],
name_hint="block",
body=tir.Evaluate(0),
annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)},
)
block_realize_expected = tir.BlockRealize(
iter_values=[var_a, var_b, var_c, var_d],
predicate=True,
block=block_expected,
)
# Check if the generated ir is expected
assert_structural_equal(block_realize_actual, block_realize_expected, map_free_vars=True)


def test_ir_builder_tir_for():
with IRBuilder() as ib:
with T.serial(128) as a:
Expand Down

0 comments on commit c0d2734

Please sign in to comment.