Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Handle axis_separators during FlattenBuffer #12652

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 115 additions & 8 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file flatten_buffer.cc
*/

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

Expand Down Expand Up @@ -53,6 +54,34 @@ class BufferFlattener : public StmtExprMutator {
}
}

Stmt VisitStmt_(const BlockNode* op) final {
ICHECK_EQ(op->match_buffers.size(), 0)
<< "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. "
<< "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer.";

Block block = GetRef<Block>(op);

Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); });
if (!alloc_buffers.same_as(op->alloc_buffers)) {
block.CopyOnWrite()->alloc_buffers = alloc_buffers;
}

Array<BufferRegion> reads = op->reads;
reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
if (!reads.same_as(op->reads)) {
block.CopyOnWrite()->reads = reads;
}

Array<BufferRegion> writes = op->writes;
writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
if (!writes.same_as(op->writes)) {
block.CopyOnWrite()->writes = writes;
}

return StmtExprMutator::VisitStmt_(block.get());
}

Stmt VisitStmt_(const AllocateNode* op) final {
Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
// TODO(Lunderberg): Move the handling of boolean into a
Expand All @@ -61,18 +90,70 @@ class BufferFlattener : public StmtExprMutator {
auto writer = alloc.CopyOnWrite();
writer->dtype = DataType::Int(8);
}
// Handle multi-dimension allocations

if (alloc->extents.size() == 1) {
return std::move(alloc);
} else {
Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);
for (size_t i = 0; i < alloc->extents.size(); i++) {
flat_extent.Set(0, flat_extent[0] * alloc->extents[i]);
// No flattening required for buffers that are already flat

// TODO(rfc-70): Keep the DeclBuffer node as-is. Stripping it
// out in the current implementation as not all lowering passes
// support DeclBuffer.
if (auto* decl_buffer = alloc->body.as<DeclBufferNode>()) {
alloc.CopyOnWrite()->body = std::move(decl_buffer->body);
}
auto n = alloc.CopyOnWrite();
n->extents = flat_extent;

return std::move(alloc);
}

if (auto* decl_buffer = alloc->body.as<DeclBufferNode>();
decl_buffer && decl_buffer->buffer->data.same_as(alloc->buffer_var)) {
// N-d buffer, use the DeclBuffer inside to determine how it
// should be flattened.
auto& buffer = decl_buffer->buffer;
bool matching_buffer = [&]() {
if (alloc->dtype != buffer->dtype) {
return false;
}
if (alloc->extents.size() != buffer->shape.size()) {
return false;
}
ExprDeepEqual expr_equal;
for (size_t i = 0; i < alloc->extents.size(); i++) {
if (!expr_equal(alloc->extents[i], buffer->shape[i])) {
return false;
}
}
return true;
}();

if (matching_buffer) {
Buffer flattened = GetFlattenedBuffer(buffer);

auto n = alloc.CopyOnWrite();
// TODO(rfc-70): Update the DeclBuffer node instead of
// stripping it out. Stripping it out in the current
// implementation as not all lowering passes support
// DeclBuffer.
//
// n->body = DeclBuffer(flattened, std::move(decl_buffer->body));
n->body = std::move(decl_buffer->body);
n->extents = flattened->shape;
return std::move(alloc);
} else {
ICHECK(decl_buffer->buffer->axis_separators.empty())
<< "DeclBuffer node doesn't match Allocate extents, but also shouldn't be "
"flattened to 1-d physical memory";
}
}

// Fallback, this is an allocation without a matching DeclBuffer
PrimExpr flat_extent = 1;
for (const auto& dim : alloc->extents) {
flat_extent *= dim;
}

auto n = alloc.CopyOnWrite();
n->extents = {flat_extent};
return std::move(alloc);
}

Buffer GetFlattenedBuffer(Buffer buf) {
Expand Down Expand Up @@ -141,6 +222,32 @@ class BufferFlattener : public StmtExprMutator {
return node;
}

BufferRegion MutateBufferRegion(BufferRegion region) {
Buffer orig_buf = region->buffer;
Buffer flattened_buf = GetFlattenedBuffer(orig_buf);
if (flattened_buf.same_as(orig_buf)) {
return region;
}

Array<PrimExpr> min_values;
Array<PrimExpr> max_values;
for (const auto& range : region->region) {
min_values.push_back(range->min);
max_values.push_back(range->min + range->extent - 1);
}

Array<PrimExpr> flattened_min = orig_buf->ElemOffset(min_values);
Array<PrimExpr> flattened_max = orig_buf->ElemOffset(max_values);

Array<Range> flattened_ranges;
ICHECK_EQ(flattened_min.size(), flattened_max.size());
for (size_t i = 0; i < flattened_min.size(); i++) {
flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1));
}

return BufferRegion(flattened_buf, flattened_ranges);
}

/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

Expand Down
1 change: 1 addition & 0 deletions src/tir/transforms/lower_opaque_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class OpaqueBlockLower : public StmtExprMutator {
new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
}
}
body = DeclBuffer(buffer, std::move(body));
body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body));
}
// Step 4. Handle annotations, block annotations are not preserved by default.
Expand Down
Loading