Skip to content

Commit

Permalink
[Backport] MatchBuffer, BufferLocator & GetBlockReadWriteRegion (#460)
Browse files Browse the repository at this point in the history
* [TensorIR] Support for match_buffer from subregion (#8585)

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
# Conflicts:
#	python/tvm/script/special_stmt.py
#	python/tvm/tir/transform/transform.py
#	src/tir/analysis/block_access_region_detector.cc
#	src/tir/analysis/buffer_access_lca_detector.cc
#	src/tir/transforms/lower_match_buffer.cc
#	tests/python/integration/test_lower.py
#	tests/python/unittest/test_tir_analysis_detect_buffer_access_lca.py
#	tests/python/unittest/test_tir_analysis_get_block_access_region.py
#	tests/python/unittest/test_tir_lower_match_buffer.py
#	tests/python/unittest/test_tir_transform_compact_buffer_region.py
#	tests/python/unittest/test_tvmscript_error_report.py

* [TIR] Fix opaque access in buffer locator pass and match_buffer in region detector (#8855)

* init

* fix

* Update src/tir/transforms/plan_update_buffer_allocation_location.cc

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* Update src/tir/transforms/plan_update_buffer_allocation_location.cc

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* address

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* [TIR] GetBlockReadWriteRegion (#8875)

* [TIR] GetBlockReadWriteRegion

* Fix black issue

* Use constant reference for the interface

* Fix lint issue

* Catch the correct error class in logical layout test

Co-authored-by: Siyuan Feng <hzfengsy@vip.qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
3 people authored Aug 30, 2021
1 parent 5cd4706 commit 0ef2897
Show file tree
Hide file tree
Showing 27 changed files with 576 additions and 139 deletions.
19 changes: 15 additions & 4 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
* \brief Auto detect the block read/write region according to body stmt
* It will detect the read/write region as an array in order of appearance in AST
* \brief Auto detect the block access region according to its body stmt
* It will detect the access region as an array in order of appearance in AST
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed by the block.
* It is a map from buffer var to the buffer.
Expand All @@ -167,8 +167,19 @@ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constrain
* - second: write regions
* - third: opaque regions
*/
Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);
TVM_DLL Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

/*!
* \brief Auto detect the block read/write region according to its body stmt. An opaque access will
* be counted as both a read and a write access
* \param block The block to be detected
* \param buffer_var_map The outside buffers which may be accessed by the block.
* It is a map from buffer var to the buffer
* \return An array only consisting of the read regions and write regions of the input block
*/
TVM_DLL Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
const Map<Var, Buffer>& buffer_var_map);

/*!
* \brief Calculate the expression complexity based on number of symbols it contains.
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ def example_func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
# match_buffers of the block,
# which bind a sub-region of source buffer into a new buffer
D = tir.match_buffer_region(C[vi, vj])
D = tir.match_buffer(C[vi, vj], ())
# init part of the block, executed when all reduce axes are the beginning value
with tir.init():
C[vi, vj] = tir.float32(0)
# block body
CC[0, 0] = A[vi, vk] * B[vj, vk]
D[0, 0] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
D[()] += CC[0, 0] # The same as C[vi, vj] += CC[0, 0]
"""

alloc_buffers: List[Buffer] = []
"""List[Buffer]: list of tir.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of tir.match_buffer_region statements in the block signature"""
"""List[MatchBufferRegion]: list of tir.match_buffer statements in the block signature"""
iter_bindings: Mapping[Var, PrimExpr] = {}
"""Mapping[Var, PrimExpr]: map of block iter var to its values"""
reads: Optional[List[BufferSlice]] = None
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class MatchBuffer(SpecialStmt):
Match buffer from Buffer subregion
.. code-block:: python
A = tir.match_buffer(, (128, 128), dtype="float32")
A = tir.match_buffer(B[0:128, i * 128 : i * 128 + 128], (128, 128), dtype="float32")
"""

def __init__(self):
Expand Down
24 changes: 23 additions & 1 deletion python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,29 @@ def get_block_access_region(
- second: write regions
- third: opaque regions
"""
return _ffi_api.get_block_access_region(block, buffer_var_map) # type: ignore
return _ffi_api.GetBlockAccessRegion(block, buffer_var_map) # type: ignore


def get_block_read_write_region(
block: Block, buffer_var_map: Dict[Var, Buffer]
) -> List[List[BufferRegion]]:
"""Auto detect the block read/write region according to its body stmt.

An opaque access will be counted as both a read and a write access.
This is a thin wrapper that forwards both arguments to the
``GetBlockReadWriteRegion`` FFI function.

Parameters
----------
block : tvm.tir.Block
The block in which we are detecting read/write regions.

buffer_var_map : Dict[Var, Buffer]
The outside buffers which may be accessed by the block.
Mapping from buffer var to the buffer.

Returns
-------
result : List[List[BufferRegion]]
An array only consisting of the read regions and write regions of the input block.
"""
return _ffi_api.GetBlockReadWriteRegion(block, buffer_var_map) # type: ignore


def calculate_workspace_bytes(func: PrimFunc, workspace_byte_alignment: int) -> int:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def LowerMatchBuffer():
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerMatchBuffer()
return _ffi_api.LowerMatchBuffer() # type: ignore


def FlattenBuffer():
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -581,8 +581,8 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
<< Print(alloc_buf->shape) << ")" << Doc::NewLine();
}
for (const auto& match_buf : block_op->match_buffers) {
body << AllocBuf(match_buf->buffer) << " = match_buffer_region(" << Print(match_buf->source)
<< ")" << Doc::NewLine();
body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
<< Doc::NewLine();
}
if (block_op->init.defined()) {
Doc init_block;
Expand Down
44 changes: 40 additions & 4 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

#include "../transforms/ir_utils.h"
namespace tvm {
namespace tir {

Expand Down Expand Up @@ -109,8 +110,11 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) {
ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey();
for (const MatchBufferRegion& match_buffer : block->match_buffers) {
const Var& target_var = match_buffer->buffer->data;
match_buffers_[target_var.get()] = match_buffer;
buffer_var_map_.Set(target_var, match_buffer->buffer);
const Var& source_var = match_buffer->source->buffer->data;
if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) {
match_buffers_[target_var.get()] = match_buffer;
buffer_var_map_.Set(target_var, match_buffer->buffer);
}
}
StmtExprVisitor::operator()(stmt);
}
Expand Down Expand Up @@ -204,7 +208,7 @@ std::vector<arith::IntSet> BlockReadWriteDetector::ConvertMatchedRegion(
region.push_back(int_set.CoverRange(Range::FromMinExtent(0, buffer->shape[i])));
}

region = match_buffer.ConvertRegion(region);
region = ConvertRegion(match_buffer, region);

std::vector<arith::IntSet> result;
result.reserve(region.size());
Expand Down Expand Up @@ -281,7 +285,39 @@ Array<Array<BufferRegion>> GetBlockAccessRegion(const Block& block,
return {detector.CollectReads(), detector.CollectWrites(), detector.CollectOpaques()};
}

TVM_REGISTER_GLOBAL("tir.analysis.get_block_access_region").set_body_typed(GetBlockAccessRegion);
Array<Array<BufferRegion>> GetBlockReadWriteRegion(const Block& block,
                                                   const Map<Var, Buffer>& buffer_var_map) {
  // Detect the raw accesses first: index 0 holds reads, 1 holds writes, 2 holds opaque accesses.
  Array<Array<BufferRegion>> detected = GetBlockAccessRegion(block, buffer_var_map);
  const Array<BufferRegion> reads = detected[0];
  const Array<BufferRegion> writes = detected[1];
  const Array<BufferRegion> opaques = detected[2];

  // Remember which buffers are touched opaquely; their explicit read/write regions are dropped
  // because the opaque region is reported in their place.
  std::unordered_set<const BufferNode*> opaque_buffers;
  for (const BufferRegion& opaque : opaques) {
    opaque_buffers.insert(opaque->buffer.get());
  }

  // Keep the explicit regions of buffers that are not opaquely accessed (in their original
  // order), then append every opaque region at the end.
  auto merge_with_opaques = [&opaque_buffers, &opaques](const Array<BufferRegion>& explicit_regions) {
    Array<BufferRegion> merged;
    merged.reserve(explicit_regions.size() + opaques.size());
    for (const BufferRegion& candidate : explicit_regions) {
      if (opaque_buffers.find(candidate->buffer.get()) == opaque_buffers.end()) {
        merged.push_back(candidate);
      }
    }
    for (const BufferRegion& opaque : opaques) {
      merged.push_back(opaque);
    }
    return merged;
  };

  Array<BufferRegion> merged_reads = merge_with_opaques(reads);
  Array<BufferRegion> merged_writes = merge_with_opaques(writes);
  return {merged_reads, merged_writes};
}

TVM_REGISTER_GLOBAL("tir.analysis.GetBlockAccessRegion").set_body_typed(GetBlockAccessRegion);
TVM_REGISTER_GLOBAL("tir.analysis.GetBlockReadWriteRegion").set_body_typed(GetBlockReadWriteRegion);

} // namespace tir
} // namespace tvm
30 changes: 11 additions & 19 deletions src/tir/analysis/buffer_access_lca_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,16 @@ class LCADetector : public StmtExprVisitor {
buffer_var_map_.emplace(buf->data.get(), buf.get());
}

const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);

ancestor_scopes_.push_back(current_scope);
// Update match_buffers
for (const MatchBufferRegion& match_buffer : op->match_buffers) {
const Buffer& target_buffer = match_buffer->buffer;
buffer_var_map_.emplace(target_buffer->data.get(), target_buffer.get());

const Buffer& source_buffer = match_buffer->source->buffer;
auto it = match_buffers_.find(source_buffer.get());
if (it != match_buffers_.end()) {
match_buffers_[target_buffer.get()] = it->second;
} else {
match_buffers_[target_buffer.get()] = source_buffer.get();
}
UpdateBufferLCA(match_buffer->source->buffer.get());
match_buffers_.insert(match_buffer->buffer.get());
}

const ScopeInfo* parent_scope = ancestor_scopes_.back();
auto* current_scope = arena_.make<ScopeInfo>(parent_scope, op, n);
ancestor_scopes_.push_back(current_scope);
StmtExprVisitor::VisitStmt_(op);
ancestor_scopes_.pop_back();
}
Expand Down Expand Up @@ -144,12 +137,11 @@ class LCADetector : public StmtExprVisitor {
}

void UpdateBufferLCA(const BufferNode* buffer) {
auto it = match_buffers_.find(buffer);
if (it != match_buffers_.end()) {
buffer = it->second;
if (match_buffers_.find(buffer) == match_buffers_.end()) {
// Ignore buffer created by block match_buffer
const ScopeInfo*& lca = buffer_lca_[buffer];
lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
}
const ScopeInfo*& lca = buffer_lca_[buffer];
lca = LowestCommonAncestor(lca, ancestor_scopes_.back());
}

static const ScopeInfo* LowestCommonAncestor(const ScopeInfo* lhs, const ScopeInfo* rhs) {
Expand Down Expand Up @@ -184,7 +176,7 @@ class LCADetector : public StmtExprVisitor {
/*! \brief The map from Buffer data to the Buffer. */
std::unordered_map<const VarNode*, const BufferNode*> buffer_var_map_ = {};
/*! \brief The match buffers inside blocks. */
std::unordered_map<const BufferNode*, const BufferNode*> match_buffers_ = {};
std::unordered_set<const BufferNode*> match_buffers_ = {};
/*! \brief Internal arena. */
support::Arena arena_;
};
Expand Down
24 changes: 10 additions & 14 deletions src/tir/ir/script/script_complete.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,12 @@ namespace tir {
/*! \brief Generate surrounding loops automatically */
class ScriptCompleter : public StmtMutator {
public:
explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map, bool contain_root)
: buffer_var_map_(buffer_var_map), contain_root_(contain_root) {}
explicit ScriptCompleter(Map<Var, Buffer>* buffer_var_map) : buffer_var_map_(buffer_var_map) {}
/*! \brief Whether the stmt contains at least one block. */
bool contains_block = false;

private:
Map<Var, Buffer>* buffer_var_map_;
bool contain_root_;
bool visited_root_ = false;
Stmt VisitStmt_(const BlockRealizeNode* op) override {
contains_block = true;
Stmt body = StmtMutator::VisitStmt_(op);
Expand All @@ -65,17 +62,23 @@ class ScriptCompleter : public StmtMutator {
}

Stmt VisitStmt_(const BlockNode* op) override {
bool is_root_block = contain_root_ && !visited_root_;
visited_root_ = true;
// Buffers allocated in the block can be accessed by its body.
for (const auto& alloc_buffer : op->alloc_buffers) {
buffer_var_map_->Set(alloc_buffer->data, alloc_buffer);
}
for (const auto& match_buffer : op->match_buffers) {
const Buffer& target_buffer = match_buffer->buffer;
buffer_var_map_->Set(target_buffer->data, target_buffer);
}
Block block = Downcast<Block>(StmtMutator::VisitStmt_(op));
// Remove buffers allocated inside block to detect its access region
for (const auto& alloc_buffer : op->alloc_buffers) {
buffer_var_map_->erase(alloc_buffer->data);
}
for (const auto& match_buffer : op->match_buffers) {
const Buffer& target_buffer = match_buffer->buffer;
buffer_var_map_->erase(target_buffer->data);
}
// Get access detection mask
// 0 for provided region, 1 and 3 for need detect read, 2 and 3 for need detect write
int mask = 0;
Expand All @@ -85,13 +88,6 @@ class ScriptCompleter : public StmtMutator {
}
// ignore root block or blocks which already has reads/writes regions
if (mask != 0) {
if (op->iter_vars.empty()) {
// non-root opaque block is not allowed
CHECK(is_root_block)
<< "ValueError: Can not auto detect buffer access region for an opaque block. Please "
"annotate the access region manually.";
return std::move(block);
}
auto access_region = GetBlockAccessRegion(block, *buffer_var_map_);
const Array<BufferRegion>& reads = access_region[0];
const Array<BufferRegion>& writes = access_region[1];
Expand Down Expand Up @@ -122,7 +118,7 @@ PrimFunc ScriptComplete(PrimFunc func, const Array<Buffer>& root_allocates) {
}
bool contain_root = root_allocates.empty() && func->body->IsInstance<BlockRealizeNode>() &&
Downcast<BlockRealize>(func->body)->block->iter_vars.empty();
ScriptCompleter script_completer(&buffer_var_map, contain_root);
ScriptCompleter script_completer(&buffer_var_map);
// generate surrounding loops automatically
Stmt res = script_completer(func->body);
// generate root block automatically
Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
p->PrintIndent();
p->stream << op->buffer->name << " = match_buffer_region(";
p->stream << op->buffer->name << " = match_buffer(";
p->Print(op->source);
p->stream << ")\n";
});
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class BaseInliner : public StmtExprMutator {
Array<BufferRegion> reads = std::move(block->reads);
Array<BufferRegion> writes = std::move(block->writes);
if (!is_scope_root) {
Array<Array<BufferRegion>> inspected = GetBlockAccessRegion(block, buffer_var_map_);
Array<Array<BufferRegion>> inspected = GetBlockReadWriteRegion(block, buffer_var_map_);
reads = std::move(inspected[0]);
writes = std::move(inspected[1]);
}
Expand Down
47 changes: 47 additions & 0 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include "ir_utils.h"

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
Expand Down Expand Up @@ -210,5 +211,51 @@ String GetPtrStorageScope(Var buffer_var) {
return ptr_type->storage_scope;
}

/*!
 * \brief Translate indices of a match_buffer target buffer into indices of its source buffer.
 *
 * The source region may have more dimensions than the target buffer. The extra leading
 * dimensions must have extent 1 and contribute their `min` directly; every remaining dimension
 * contributes `min + index`.
 *
 * \param match_buffer The match_buffer binding whose source region supplies the offsets.
 * \param indices The indices into the target buffer; one per target buffer dimension.
 * \return The indices into the source buffer; one per dimension of the source region.
 */
Array<PrimExpr> ConvertIndices(const MatchBufferRegion& match_buffer,
                               const Array<PrimExpr>& indices) {
  const Buffer& target = match_buffer->buffer;
  const BufferRegion& source = match_buffer->source;
  ICHECK_EQ(indices.size(), target->shape.size());
  // The subtraction below is unsigned: without this guard a target of higher rank than the
  // source region would wrap around to a huge offset instead of reporting the real problem.
  ICHECK(source->region.size() >= indices.size())
      << "The source region of match_buffer has " << source->region.size()
      << " dimension(s), fewer than the " << indices.size() << " indices of the target buffer";

  arith::Analyzer analyzer;
  Array<PrimExpr> result;
  result.reserve(source->region.size());
  size_t offset = source->region.size() - indices.size();
  // Leading source dimensions that the target buffer does not have: each must be a single point.
  for (size_t i = 0; i < offset; ++i) {
    const Range& range = source->region[i];
    ICHECK(analyzer.CanProve(range->extent == 1));
    result.push_back(range->min);
  }
  // Dimensions shared with the target buffer: shift the index by the start of the source range.
  for (size_t i = 0; i < indices.size(); ++i) {
    const Range& range = source->region[i + offset];
    const PrimExpr& index = indices[i];
    result.push_back(range->min + index);
  }
  return result;
}

/*!
 * \brief Translate a region of a match_buffer target buffer into a region of its source buffer.
 *
 * The source region may have more dimensions than the target buffer. The extra leading
 * dimensions must have extent 1 and are emitted as single-point ranges; every remaining
 * dimension keeps the extent of the target range and has its `min` shifted by the `min` of the
 * corresponding source range.
 *
 * \param match_buffer The match_buffer binding whose source region supplies the offsets.
 * \param region The region over the target buffer; one range per target buffer dimension.
 * \return The region over the source buffer; one range per dimension of the source region.
 */
Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region) {
  const Buffer& target = match_buffer->buffer;
  const BufferRegion& source = match_buffer->source;
  ICHECK_EQ(region.size(), target->shape.size());
  // The subtraction below is unsigned: without this guard a target of higher rank than the
  // source region would wrap around to a huge offset instead of reporting the real problem.
  ICHECK(source->region.size() >= region.size())
      << "The source region of match_buffer has " << source->region.size()
      << " dimension(s), fewer than the " << region.size() << " ranges of the target region";

  arith::Analyzer analyzer;
  Region result;
  result.reserve(source->region.size());
  size_t offset = source->region.size() - region.size();
  // Leading source dimensions that the target buffer does not have: each must be a single point.
  for (size_t i = 0; i < offset; ++i) {
    const Range& source_range = source->region[i];
    ICHECK(analyzer.CanProve(source_range->extent == 1));
    result.push_back(Range::FromMinExtent(source_range->min, 1));
  }
  // Dimensions shared with the target buffer: shift the start, keep the target extent.
  for (size_t i = 0; i < region.size(); ++i) {
    const Range& source_range = source->region[i + offset];
    const Range& target_range = region[i];
    result.push_back(
        Range::FromMinExtent(source_range->min + target_range->min, target_range->extent));
  }
  return result;
}

} // namespace tir
} // namespace tvm
Loading

0 comments on commit 0ef2897

Please sign in to comment.