Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR] Support for match_buffer from subregion #8585

Merged
merged 6 commits into from
Jul 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,12 @@ TVM_DLL Pass CompactBufferAllocation();
*/
TVM_DLL Pass LegalizePackedCalls();

/*!
* \brief Lower match buffers: remove the match buffers declared inside blocks
* and validate their bindings along the way.
* \return The pass.
*/
TVM_DLL Pass LowerMatchBuffer();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore
* to single dimensional Load/Store. Also remove Block to
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:

# match_buffers of the block,
# which bind a sub-region of source buffer into a new buffer
D = tir.match_buffer_region(C[vi, vj])
D = tir.match_buffer(C[vi, vj], ())

# init part of the block, executed when all reduce axes are the beginning value
with tir.init():
C[vi, vj] = tir.float32(0)

# block body
CC[0, 0] = A[vi, vk] * B[vj, vk]
D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
"""

alloc_buffers: List[Buffer] = []
"""List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature"""
"""List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
reads: Optional[List[BufferSlice]] = None
Expand Down
22 changes: 21 additions & 1 deletion python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,10 +784,11 @@ def transform_Slice(self, node):
def transform_Subscript(self, node):
"""Array access visitor.

By now only 2 types of Subscript are supported:
By now only 3 types of Subscript are supported:
1. Buffer[index, index, ...], Buffer element access(BufferLoad & BufferStore)
Var[index] Buffer element access()
2. Buffer[start: stop, start: stop, ...], BufferRealize(realize(buffer[...]))
3. Array[index], Buffer element access
"""

symbol = self.transform(node.params[0])
Expand All @@ -812,6 +813,25 @@ def transform_Subscript(self, node):
return BufferSlice(
symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span)
)
elif isinstance(symbol, tvm.container.Array):
if len(indexes) > 1:
self.report_error(
"Array access should be one-dimension access, but the indices are "
+ str(indexes),
node.span,
)
index = indexes[0]
if not isinstance(index, (int, tvm.tir.expr.IntImm)):
    # type(index) is a class object, not a str: concatenating it directly to the
    # message would raise TypeError before the diagnostic is reported, so it is
    # converted with str() first.
    self.report_error(
        "Array access index expected int or IntImm, but got " + str(type(index)),
        node.span,
    )
if int(index) >= len(symbol):
self.report_error(
f"Array access out of bound, size: {len(symbol)}, got index {index}.",
node.span,
)
return symbol[int(index)]
else:
self.report_error(
f"Cannot subscript from a {type(symbol).__name__}. Only variables and "
Expand Down
98 changes: 30 additions & 68 deletions python/tvm/script/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,24 @@ def handle(

@register
class MatchBuffer(SpecialStmt):
"""Special Stmt match_buffer(var, shape, dtype, data, strides, elem_offset, scope, align,
"""Special Stmt match_buffer(param, shape, dtype, data, strides, elem_offset, scope, align,
offset_factor, buffer_type)

Note
----
This Special Stmt will perform different behavior depends on the type of param.
If the param is a var in function parameter, it will create a buffer from DLTensor.
Else if the param is a subregion of other buffers, then create a subregion match inside a block.

Example
-------
Match buffer from function parameter
.. code-block:: python
A = tir.match_buffer(a, (128, 128), dtype="float32")

Match buffer from Buffer subregion
.. code-block:: python
A = tir.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
"""

def __init__(self):
Expand All @@ -123,10 +135,6 @@ def match_buffer(
"match_buffer must be assigned to a buffer, e.g. A = match_buffer(...)",
self.node.span,
)
if param not in self.context.func_params:
self.context.report_error(
"Can not bind non-input param to buffer", self.node.rhs.params[0].span
)
if strides is None:
strides = []
align = convert_to_int(align, "align", self.context.report_error, self.node.span)
Expand All @@ -146,7 +154,23 @@ def match_buffer(
buffer_type,
span=span,
)
self.context.func_buffer_map[param] = buffer
if isinstance(param, tvm.tir.Var):
if param not in self.context.func_params:
self.context.report_error(
"Can not bind non-input param to buffer", self.node.rhs.params[0].span
)
self.context.func_buffer_map[param] = buffer
elif isinstance(param, BufferSlice):
buffer_region = buffer_slice_to_region(param)
self.context.current_block_scope().match_buffers.append(
tvm.tir.MatchBufferRegion(buffer, buffer_region)
)
else:
self.context.report_error(
"The source of match_buffer expected Var or BufferSlice, but got "
+ str(type(param)),
self.node.rhs.params[0].span,
)
self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)

super().__init__(match_buffer, def_symbol=True)
Expand Down Expand Up @@ -414,68 +438,6 @@ def where(predicate, span=None):
super().__init__(where, def_symbol=False)


@register
class BlockMatchBufferRegion(SpecialStmt):
"""Special function match_buffer_region(source, strides, elem_offset, align, offset_factor)

Declares a new buffer whose shape is taken from the extents of a sub-region of
a source buffer, and records the (buffer, region) pair in the match_buffers of
the current block scope.

Example
-------
.. code-block:: python

B = tir.match_buffer_region(A[0: 4])
"""

def __init__(self):
def match_buffer_region(
source,
strides=None,
elem_offset=None,
align=-1,
offset_factor=0,
span=None,
):
assert self.context, "call 'exit_scope' before 'enter_scope'"
# The new buffer is named after the assignment target, so the statement
# has to be an assignment.
if not isinstance(self.node, ast.Assign):
self.context.report_error(
"match_buffer_region must be assigned to a buffer, "
+ "e.g. A = match_buffer_region(...)",
self.node.span,
)

# Replace the None default with an empty list of strides.
if strides is None:
strides = []
# Convert align and offset_factor to int via the convert_to_int helper.
align = convert_to_int(align, "align", self.context.report_error, self.node.span)
offset_factor = convert_to_int(
offset_factor, "offset_factor", self.context.report_error, self.node.span
)

# NOTE(review): the conversion below runs unconditionally after this check,
# so this presumably relies on report_error not returning - confirm.
if not isinstance(source, BufferSlice):
self.context.report_error(
"match_buffer_region needs a buffer region as source",
span=span,
)
buffer_region = buffer_slice_to_region(source)
# The matched buffer takes one dimension per range of the source region,
# sized by that range's extent.
shape = [r.extent for r in buffer_region.region]
# dtype and scope are copied from the source buffer; the buffer name is
# the name on the left-hand side of the assignment.
buffer = tvm.tir.decl_buffer(
shape,
buffer_region.buffer.dtype,
self.node.lhs.id.name,
data=None,
strides=strides,
elem_offset=elem_offset,
scope=buffer_region.buffer.scope(),
data_alignment=align,
offset_factor=offset_factor,
span=span,
)
# Record the match on the enclosing block scope and bind the new buffer
# to the assigned name.
self.context.current_block_scope().match_buffers.append(
tvm.tir.MatchBufferRegion(buffer, buffer_region)
)
self.context.update_symbol(self.node.lhs.id.name, buffer, self.node)

super().__init__(match_buffer_region, def_symbol=True)


@register
class VarDef(SpecialStmt):
"""Special function for defining a Var"""
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def decl_buffer(
dtype = "float32" if dtype is None else dtype
strides = () if strides is None else strides
if offset_factor != 0 and elem_offset is None:
shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32"
shape_dtype = shape[0].dtype if shape and hasattr(shape[0], "dtype") else "int32"
elem_offset = Var("%s_elem_offset" % name, shape_dtype)
if data is None:
# Bool is represented as uint1 in the IR, but stored as int8
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,17 @@ def CompactBufferAllocation():
return _ffi_api.CompactBufferAllocation() # type: ignore


def LowerMatchBuffer():
    """Build the pass that removes match buffers inside the block.

    The pass also validates the binding.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    lower_match_buffer_pass = _ffi_api.LowerMatchBuffer()  # type: ignore
    return lower_match_buffer_pass


def FlattenBuffer():
"""Flatten the multi-dimensional BufferLoad and BufferStore
to single dimensional Load/Store. Also remove Block to
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
}
pass_list.push_back(tir::transform::BF16Legalize());
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
<< Print(alloc_buf->shape) << ")" << Doc::NewLine();
}
for (const auto& match_buf : block_op->match_buffers) {
body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source)
<< ")" << Doc::NewLine();
body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
<< Doc::NewLine();
}
if (block_op->init.defined()) {
Doc init_block;
Expand Down
25 changes: 2 additions & 23 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -337,29 +337,8 @@ Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
const Buffer& buf = op->buffer;
buf_not_in_headers.insert(buf.get());

Doc doc = Print(op->buffer) << " = tir.match_buffer_region(" << Print(op->source);
if (!buf->strides.empty()) {
doc << ", strides=" << Print(buf->strides);
}
if (buf->offset_factor != 0 && buf->elem_offset->IsInstance<VarNode>()) {
Var elem_offset = Downcast<Var>(buf->elem_offset);
if (memo_var_.find(elem_offset) != memo_var_.end()) {
doc << ", elem_offset=" << Print(buf->elem_offset);
} else {
// implicitly define elem_offset
memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset");
var_not_in_headers.insert(elem_offset.get());
}
} else {
doc << ", elem_offset=" << Print(buf->elem_offset);
}
if (buf->data_alignment != -1) {
doc << ", align=" << buf->data_alignment;
}
if (buf->offset_factor != 0) {
doc << ", offset_factor=" << buf->offset_factor;
}
doc << ")";
Doc doc = Print(op->buffer) << " = tir.match_buffer(" << Print(op->source) << ", "
<< memo_buf_decl_[op->buffer] << ")";
return doc;
}

Expand Down
Loading