From e2afbbcaf69c874256a24df6b23c0ad0de9c5113 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Sep 2021 09:53:09 -0500 Subject: [PATCH 01/21] [TE] Improved flexibility of ArgBinder::BindDLTensor Allowed a compact DLTensor to bind to a Buffer object that defines strides, if the strides defined correspond to a compact layout. --- src/tir/transforms/arg_binder.cc | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 293c990d2745..d3ab32cbd7f9 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -204,7 +204,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, def_handle_dtype_.Set(v_strides, tir::TypeAnnotation(tvm_shape_type)); init_nest_.emplace_back( LetStmt(v_strides, TVMArrayGet(DataType::Handle(), handle, builtin::kArrStrides), nop)); - PrimExpr is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}); + PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), {v_strides}); if (buffer->strides.size() == 0) { // Assert the buffer is compact DataType stype = buffer->DefaultIndexType(); @@ -226,7 +226,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); }, const_true(1), conds), stride_msg, Evaluate(0)); - check = IfThenElse(Not(is_null), check, Stmt()); + check = IfThenElse(Not(v_strides_is_null), check, Stmt()); asserts_.emplace_back(SeqStmt({check, Evaluate(0)})); } } else if (buffer->buffer_type == kAutoBroadcast) { @@ -239,24 +239,29 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, PrimExpr value = cast(buffer->shape[k].dtype(), Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); - value = tvm::if_then_else(is_null, stride, value); + value = 
tvm::if_then_else(v_strides_is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); stride = analyzer_.Simplify(stride * buffer->shape[k]); } } else { - std::ostringstream stride_null_err_msg; - stride_null_err_msg << arg_name << ".strides: expected non-null strides."; - asserts_.emplace_back( - AssertStmt(Not(is_null), tvm::tir::StringImm(stride_null_err_msg.str()), nop)); + PrimExpr stride_from_shape = 1; - for (size_t k = 0; k < buffer->strides.size(); ++k) { + for (int k = buffer->strides.size() - 1; k >= 0; k--) { std::ostringstream field_name; field_name << v_strides->name_hint << '[' << k << ']'; + + PrimExpr explicit_stride = + cast(buffer->shape[k].dtype(), + Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))); + Bind_(buffer->strides[k], - cast(buffer->shape[k].dtype(), - Load(tvm_shape_type, v_strides, IntImm(DataType::Int(32), k), const_true(1))), + tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride), field_name.str(), true); + + stride_from_shape *= + cast(buffer->shape[k].dtype(), + Load(tvm_shape_type, v_shape, IntImm(DataType::Int(32), k), const_true(1))); } } // Byte_offset field. From 806eb131687e77c2dfbba2123d614d805d64b2d1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Sep 2021 10:27:40 -0500 Subject: [PATCH 02/21] [TIR] Exposed ElemOffset as a member function of BufferNode. --- include/tvm/tir/buffer.h | 8 ++++++++ src/tir/ir/buffer.cc | 24 ++++++++++++------------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 28d202cb50a9..f04209d0b061 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -121,6 +121,14 @@ class BufferNode : public Object { return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32); } + /*! \brief Determine the offset in the buffer of the given index. 
+ * + * Returns the buffer offset, in number of elements of type dtype, + * without adjusting for number of lanes. (e.g. The number of + * float16x4 elements in a buffer of type float16x4.) + */ + PrimExpr ElemOffset(Array index) const; + static constexpr const char* _type_key = "tir.Buffer"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 8253fa3c0a36..24aacc3c04f7 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -246,41 +246,41 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { // The buffer offset in convention of number of elements of // original data ignoring number of lanes. // We also perform optimization to simplify the indexing expression. -inline PrimExpr ElemOffset(const BufferNode* n, Array index) { - PrimExpr base = n->elem_offset; +PrimExpr BufferNode::ElemOffset(Array index) const { + PrimExpr base = this->elem_offset; arith::Analyzer ana; - if (n->strides.size() == 0) { + if (this->strides.size() == 0) { // Scalar case - if (n->shape.size() == 0 && index.size() == 1) { + if (this->shape.size() == 0 && index.size() == 1) { auto is_int = index[0].as(); ICHECK(is_int && is_int->value == 0); base = base + index[0]; } else { - ICHECK_EQ(n->shape.size(), index.size()); + ICHECK_EQ(this->shape.size(), index.size()); if (index.size() > 0) { PrimExpr offset = index[0]; for (size_t i = 1; i < index.size(); ++i) { - offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]); + offset = MergeMulMod(&ana, offset * this->shape[i] + index[i]); } base = base + offset; } } } else { - ICHECK_EQ(n->strides.size(), index.size()); + ICHECK_EQ(this->strides.size(), index.size()); if (is_zero(base)) { - base = MergeMulMod(&ana, index[0] * n->strides[0]); + base = MergeMulMod(&ana, index[0] * this->strides[0]); } else { - base = MergeMulMod(&ana, base + index[0] * n->strides[0]); + 
base = MergeMulMod(&ana, base + index[0] * this->strides[0]); } for (size_t i = 1; i < index.size(); ++i) { - base = MergeMulMod(&ana, base + index[i] * n->strides[i]); + base = MergeMulMod(&ana, base + index[i] * this->strides[i]); } } return base; } inline PrimExpr BufferOffset(const BufferNode* n, Array index, DataType dtype) { - PrimExpr offset = ElemOffset(n, index); + PrimExpr offset = n->ElemOffset(index); if (n->dtype.lanes() != 1) { offset = offset * make_const(offset.dtype(), dtype.lanes()); } @@ -353,7 +353,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const ICHECK(n != nullptr); arith::Analyzer ana; begins = SimplifyArray(&ana, begins); - PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins)); + PrimExpr elem_offset = ana.Simplify(n->ElemOffset(begins)); Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; From 7c680852d01b41dc3c5beb706db7acb100e815ef Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 21 Sep 2021 10:34:26 -0500 Subject: [PATCH 03/21] [TE] Pulled shape determination out of StorageFlattener Previously, StorageFlattener would determine the shape of a physical buffer based on the extents of the BufferRealizeNode. Pulled these out into a separate BufferShapeLegalize pass. After this pass, all buffers have a shape that matches the buffer realization extents. --- src/tir/transforms/storage_flatten.cc | 236 +++++++++++++++++++++++++- 1 file changed, 232 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 2c32cc7f0883..438217583d66 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -50,6 +50,229 @@ using runtime::StorageRank; using runtime::StorageScope; using runtime::ThreadScope; +/* Make buffer realize extents and buffer shapes consistent + * + * For external buffers, verify that the extents of BufferRealize + * nodes match the shape of the external buffer. 
For internal + * buffers, rewrite the shape of the Buffer objects to match the + * extent of the BufferRealize, and rewrite indices of + * BufferLoad/BufferStore nodes to match. + */ +class BufferShapeLegalize : public StmtExprMutator { + public: + explicit BufferShapeLegalize(const Map& extern_buffer_map, + IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer) { + for (auto kv : extern_buffer_map) { + extern_buffers_.insert(kv.second); + } + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + // External buffers should not be changed. + if (extern_buffers_.count(op->buffer)) { + ICHECK_EQ(op->buffer->shape.size(), op->bounds.size()) + << "External buffer realize has mismatched dimension"; + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op); + + for (size_t i = 0; i < op->bounds.size(); i++) { + PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent); + std::ostringstream ss; + ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape " + << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent; + if (auto eq_int = eq.as()) { + ICHECK(eq_int->value) << ss.str(); + } else { + stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt); + } + } + return stmt; + } + + // Compute the new buffer shape, new realization bounds, and the + // offsets to be applied to buffer access. 
+ Array realized_shape; + Array realized_begins; + Array new_bounds; + for (size_t i = 0; i < op->bounds.size(); i++) { + const Range& bound = op->bounds[i]; + realized_shape.push_back(bound->extent); + realized_begins.push_back(bound->min); + new_bounds.push_back({0, bound->extent}); + } + + Buffer key = op->buffer; + + Buffer buf = op->buffer; + auto write_ptr = buf.CopyOnWrite(); + write_ptr->shape = realized_shape; + + { + InternalBufferRemap remap; + remap.remap_to = buf; + remap.realized_begins = realized_begins; + remap.in_scope = true; + internal_buf_map_[key] = remap; + } + + Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span); + + internal_buf_map_.at(key).in_scope = false; + + return stmt; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op); + + auto it = internal_buf_map_.find(op->buffer); + if (it != internal_buf_map_.end()) { + const InternalBufferRemap& entry = it->second; + ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer"; + ICHECK_EQ(entry.realized_begins.size(), op->indices.size()) + << "Inconsistent dimensions for buffer " << op->buffer->name; + + Array new_indices; + for (size_t i = 0; i < entry.realized_begins.size(); i++) { + new_indices.push_back(op->indices[i] - entry.realized_begins[i]); + } + + BufferStore updated = GetRef(op); + auto write_ptr = updated.CopyOnWrite(); + write_ptr->indices = new_indices; + write_ptr->buffer = entry.remap_to; + stmt = updated; + } + + return stmt; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op); + + auto it = internal_buf_map_.find(op->buffer); + if (it != internal_buf_map_.end()) { + const InternalBufferRemap& entry = it->second; + ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer"; + ICHECK_EQ(entry.realized_begins.size(), op->indices.size()) + << 
"Inconsistent dimensions for buffer " << op->buffer->name; + + Array new_indices; + for (size_t i = 0; i < entry.realized_begins.size(); i++) { + new_indices.push_back(op->indices[i] - entry.realized_begins[i]); + } + + BufferLoad updated = GetRef(op); + auto write_ptr = updated.CopyOnWrite(); + write_ptr->indices = new_indices; + write_ptr->buffer = entry.remap_to; + expr = updated; + } + + return expr; + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { + return HandleDoubleBuffer(op); + } else if (op->attr_key == attr::buffer_bind_scope) { + return HandleBufferBindScope(op); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + private: + Stmt HandleDoubleBuffer(const AttrStmtNode* op) { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op); + + Buffer buffer = Downcast(op->node); + auto it = internal_buf_map_.find(buffer); + if (it != internal_buf_map_.end()) { + return AttrStmt(it->second.remap_to->data, op->attr_key, op->value, op->body); + } else { + return stmt; + } + } + + Stmt HandleBufferBindScope(const AttrStmtNode* op) { + Array arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2U); + Buffer buffer = Downcast(arr[0]); + ICHECK(buffer.defined()); + Buffer target = Downcast(arr[1]); + ICHECK(target.defined()); + + auto it = internal_buf_map_.find(target); + if (it == internal_buf_map_.end()) { + return StmtExprMutator::VisitStmt_(op); + } + const InternalBufferRemap& target_remap = it->second; + + ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name + << " to the out-of-scope buffer " << target_remap.remap_to->name; + + Call tuple = Downcast(op->value); + ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple())); + + Array new_tuple_args; + Array realized_begins; + Array realized_shape; + ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2); + for (size_t i = 0; i < target_remap.realized_begins.size(); 
i++) { + PrimExpr parent_begin = tuple->args[2 * i]; + PrimExpr view_extent = tuple->args[2 * i + 1]; + // Offset the begin of the buffer view by the offset of the target buffer. + new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]); + // Keep the extent of the buffer view the same. + new_tuple_args.push_back(view_extent); + // Use the extent of the buffer view to define the buffer view's shape. + realized_shape.push_back(view_extent); + // Within the buffer view, indices start at 0. + realized_begins.push_back(0); + } + + Buffer key = buffer; + + auto write_ptr = buffer.CopyOnWrite(); + write_ptr->shape = realized_shape; + + { + InternalBufferRemap remap; + remap.realized_begins = realized_begins; + remap.remap_to = buffer; + remap.in_scope = true; + internal_buf_map_[key] = remap; + } + + Stmt stmt = AttrStmt(Array{buffer, target_remap.remap_to}, op->attr_key, + Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span), + this->VisitStmt(op->body)); + internal_buf_map_.at(key).in_scope = false; + return stmt; + } + + std::unordered_set extern_buffers_; + + struct InternalBufferRemap { + Buffer remap_to; + Array realized_begins; + bool in_scope; + }; + + std::unordered_map internal_buf_map_; + + IRVisitorWithAnalyzer* bound_analyzer_; +}; + class StorageFlattener : public StmtExprMutator { public: explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, @@ -147,10 +370,14 @@ class StorageFlattener : public StmtExprMutator { // create a buffer entry BufferEntry e; e.bounds = op->bounds; - Array shape; - for (auto r : e.bounds) { - shape.push_back(r->extent); - } + + ICHECK(op->buffer->shape.size()) << "StorageFlattener expects buffer shapes to be defined. " + << "Please run through BufferShapeLegalize first."; + + ICHECK_EQ(op->buffer->shape.size(), op->bounds.size()) + << "Inconsistent buffer shape and realization shape for " << op->buffer; + + Array shape = op->buffer->shape; // deduce current storage scope. 
StorageScope skey; std::string strkey = GetPtrStorageScope(op->buffer->data); @@ -507,6 +734,7 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at IRVisitorWithAnalyzer bound_analyzer; bound_analyzer(fptr->body); + fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer)(std::move(fptr->body)); return func; From 6599278392daf62d419757418ee4902a847de42a Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Sep 2021 09:52:05 -0500 Subject: [PATCH 04/21] [TE] Refactor stride calculation out of StorageFlattener Previously, StorageFlattener would handle any attr::dim_align annotations. Now, this is pulled out into a separate BufferStrideLegalize pass. --- src/tir/transforms/storage_flatten.cc | 233 ++++++++++++++++++++++---- 1 file changed, 200 insertions(+), 33 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 438217583d66..cdf5ab052291 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -273,6 +273,193 @@ class BufferShapeLegalize : public StmtExprMutator { IRVisitorWithAnalyzer* bound_analyzer_; }; +/* Apply dimension alignment restrictions + * + * Buffers annotated with attr::buffer_dim_align may need to have + * strides defined such that they are no longer in a compact shape. + * After this pass, buffers have stride definitions to include these + * alignment restrictions, and attr::buffer_dim_align annotations have + * been removed. 
+ */ +class BufferStrideLegalize : public StmtExprMutator { + public: + explicit BufferStrideLegalize(const Map& extern_buffer_map, + IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer) { + for (auto kv : extern_buffer_map) { + Buffer buf = kv.second; + Buffer with_strides = WithStrides(buf, true); + { + BufferEntry entry; + entry.remap_to = with_strides; + entry.in_scope = true; + entry.is_external = true; + buf_map_[buf] = entry; + } + updated_extern_buffer_map_.Set(kv.first, with_strides); + } + } + + Map UpdatedExternBufferMap() const { return updated_extern_buffer_map_; } + + Buffer WithStrides(Buffer buf, bool allow_scalar = false) { + auto it = buf_map_.find(buf); + if (it != buf_map_.end()) { + const BufferEntry& entry = it->second; + ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer"; + return entry.remap_to; + } + + if (buf->strides.size()) { + ICHECK_EQ(buf->strides.size(), buf->shape.size()); + return buf; + } + + Array shape = buf->shape; + + if (shape.size() == 0) { + // This is only allowed for buffers that point to external + // buffers. These are treated as 1-d pointers of unknown size, + // with stride of 1. + ICHECK(allow_scalar) + << "Buffer " << buf << " does not have a valid shape. " + << "BufferStrideLegalize requires all internal buffers to have a valid shape. " + << "Please run BufferShapeLegalize first"; + return buf; + } + + // Keeping this to have matched behavior to previous version. + // There are many parts of the codebase that assume that a strided + // array cannot be compact. 
+ if (dim_align_.count(buf) == 0) { + return buf; + } + + std::vector rstrides; + const std::vector& avec = dim_align_[buf]; + int first_dim = 0; + PrimExpr stride = make_const(shape[first_dim].dtype(), 1); + for (size_t i = shape.size(); i != 0; --i) { + size_t dim = i - 1; + if (dim < avec.size() && avec[dim].align_factor != 0) { + PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); + PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); + stride = bound_analyzer_->Simplify(stride); + } + rstrides.push_back(stride); + stride = stride * shape[dim]; + } + + auto ptr = buf.CopyOnWrite(); + ptr->strides = Array(rstrides.rbegin(), rstrides.rend()); + + return buf; + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::buffer_dim_align) { + auto buffer = Downcast(op->node); + const CallNode* tuple = op->value.as(); + ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); + auto& vinfo = dim_align_[buffer]; + int dim = tuple->args[0].as()->value; + if (static_cast(dim) >= vinfo.size()) { + vinfo.resize(dim + 1); + } + vinfo[dim].align_factor = tuple->args[1].as()->value; + vinfo[dim].align_offset = tuple->args[2].as()->value; + return this->VisitStmt(op->body); + } else if (op->attr_key == attr::buffer_bind_scope) { + Array arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2U); + Buffer source = Downcast(arr[0]); + Buffer target_with_strides = WithStrides(Downcast(arr[1])); + Buffer source_with_strides = WithStrides(source); + + { + BufferEntry entry; + entry.remap_to = source_with_strides; + entry.in_scope = true; + entry.is_external = false; + buf_map_[source] = entry; + } + + Stmt body = this->VisitStmt(op->body); + + return AttrStmt(Array{source_with_strides, target_with_strides}, op->attr_key, + op->value, body, op->span); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const 
BufferRealizeNode* op) final { + Buffer key = op->buffer; + Buffer with_strides = WithStrides(op->buffer); + { + BufferEntry entry; + entry.remap_to = with_strides; + entry.in_scope = true; + entry.is_external = false; + buf_map_[key] = entry; + } + + Stmt stmt = StmtExprMutator::VisitStmt_(op); + + buf_map_[key].in_scope = false; + op = stmt.as(); + ICHECK(op); + + return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + + auto it = buf_map_.find(op->buffer); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; + const BufferEntry& e = it->second; + ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope"; + + return BufferLoad(e.remap_to, op->indices, op->span); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + auto it = buf_map_.find(op->buffer); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; + const BufferEntry& e = it->second; + ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope"; + + return BufferStore(e.remap_to, op->value, op->indices, op->span); + } + + private: + Map updated_extern_buffer_map_; + + struct DimAlignInfo { + int align_factor{0}; + int align_offset{0}; + }; + + // Dimension alignment + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dim_align_; + + struct BufferEntry { + Buffer remap_to; + bool in_scope; + bool is_external; + }; + + std::unordered_map buf_map_; + + IRVisitorWithAnalyzer* bound_analyzer_; +}; + class StorageFlattener : public StmtExprMutator { public: explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, @@ -301,6 +488,11 @@ class StorageFlattener : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { + ICHECK_NE(op->attr_key, 
attr::buffer_dim_align) + << "StorageFlattener assumes that all buffers have accurate strides, " + << "and all buffer_dim_align annotations are removed. " + << "Please run BufferStrideLegalize first."; + if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); @@ -317,18 +509,6 @@ class StorageFlattener : public StmtExprMutator { return stmt; } else if (op->attr_key == attr::buffer_bind_scope) { return HandleBufferBindScope(op); - } else if (op->attr_key == attr::buffer_dim_align) { - auto buffer = Downcast(op->node); - const CallNode* tuple = op->value.as(); - ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); - auto& vinfo = dim_align_[buffer]; - int dim = tuple->args[0].as()->value; - if (static_cast(dim) >= vinfo.size()) { - vinfo.resize(dim + 1); - } - vinfo[dim].align_factor = tuple->args[1].as()->value; - vinfo[dim].align_offset = tuple->args[2].as()->value; - return this->VisitStmt(op->body); } return StmtExprMutator::VisitStmt_(op); } @@ -400,25 +580,7 @@ class StorageFlattener : public StmtExprMutator { << "Allocation exceed bound of memory tag " << skey.to_string(); } } - Array strides; - if (dim_align_.count(key) != 0 && shape.size() != 0) { - std::vector rstrides; - const std::vector& avec = dim_align_[key]; - int first_dim = 0; - PrimExpr stride = make_const(shape[first_dim].dtype(), 1); - for (size_t i = shape.size(); i != 0; --i) { - size_t dim = i - 1; - if (dim < avec.size() && avec[dim].align_factor != 0) { - PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); - PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = bound_analyzer_->Simplify(stride); - } - rstrides.push_back(stride); - stride = stride * shape[dim]; - } - strides = Array(rstrides.rbegin(), rstrides.rend()); - } + Array strides = op->buffer->strides; auto* 
ptr_type = op->buffer->data->type_annotation.as(); ICHECK(ptr_type); @@ -711,8 +873,6 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map var_remap_; // Buffer map std::unordered_map buf_map_; - // Dimension alignment - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dim_align_; // The current thread scope. std::vector curr_thread_scope_; // Collects shapes. @@ -734,9 +894,16 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at IRVisitorWithAnalyzer bound_analyzer; bound_analyzer(fptr->body); + fptr->body = BufferShapeLegalize(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + + auto stride_legalize = BufferStrideLegalize(fptr->buffer_map, &bound_analyzer); + fptr->body = stride_legalize(std::move(fptr->body)); + fptr->buffer_map = stride_legalize.UpdatedExternBufferMap(); + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer)(std::move(fptr->body)); + return func; } else { return func; From 0716e0125c11b786885731d0fd9abaebb438a883 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Sep 2021 14:50:36 -0500 Subject: [PATCH 05/21] [TE] Refactor thread scope propagation out of StorageFlattener. Previously, StorageFlattener would use the scope in IterVar to assign a scope to allocated buffers, where not otherwise defined. This has been pulled out into a separate ThreadScopePropagate pass. --- src/tir/transforms/storage_flatten.cc | 175 ++++++++++++++++++++++---- 1 file changed, 150 insertions(+), 25 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index cdf5ab052291..6d5abc843a9a 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -460,6 +460,150 @@ class BufferStrideLegalize : public StmtExprMutator { IRVisitorWithAnalyzer* bound_analyzer_; }; +/* Use the scope of IterVar to determine storage scope. 
+ * + * For buffers that do not have an explicit storage scope defined, a + * reasonable storage scope may be defined based on the thread scope + * that contains the buffer's allocation. + */ +class ThreadScopePropagate : public StmtExprMutator { + public: + explicit ThreadScopePropagate(const Map& extern_buffer_map) { + // External buffers shouldn't be overwritten, even if they have a + // BufferRealizeNode. + for (auto kv : extern_buffer_map) { + external_buffers_.insert(kv.second); + } + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = buf_remap_.find(GetRef(op)); + if (it != buf_remap_.end()) { + return it->second->data; + } else { + return GetRef(op); + } + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + ICHECK_NE(op->attr_key, attr::buffer_dim_align) + << "StorageFlattener assumes that all buffers have accurate strides, " + << "and all buffer_dim_align annotations are removed. " + << "Please run BufferStrideLegalize first."; + + if (op->attr_key == attr::thread_extent) { + IterVar iv = Downcast(op->node); + ThreadScope ts = ThreadScope::Create(iv->thread_tag); + curr_thread_scope_.push_back(ts); + Stmt stmt = StmtExprMutator::VisitStmt_(op); + curr_thread_scope_.pop_back(); + return stmt; + } else if (op->attr_key == attr::buffer_bind_scope) { + return HandleBufferBindScope(op); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + Var old_var = op->buffer->data; + + // Don't remap buffers that already have an explicit scope, + // external buffers, or buffers outside of a thread scope. 
+ std::string str_scope = GetPtrStorageScope(old_var); + if ((str_scope.length() > 0) || external_buffers_.count(op->buffer) || + (curr_thread_scope_.size() == 0)) { + return StmtExprMutator::VisitStmt_(op); + } + + ICHECK_EQ(buf_remap_.count(old_var), 0) + << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes"; + + StorageScope skey; + skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); + + auto ptr_type = old_var->type_annotation.as(); + ICHECK(ptr_type); + Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()), + old_var->span); + + Buffer buf = op->buffer; + buf.CopyOnWrite()->data = new_var; + + buf_remap_[old_var] = buf; + + Stmt body = this->VisitStmt(op->body); + return BufferRealize(buf, op->bounds, op->condition, body, op->span); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op); + + auto it = buf_remap_.find(op->buffer->data); + if (it != buf_remap_.end()) { + return BufferLoad(it->second, op->indices, op->span); + } else { + return expr; + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op); + + auto it = buf_remap_.find(op->buffer->data); + if (it != buf_remap_.end()) { + return BufferStore(it->second, op->value, op->indices, op->span); + } else { + return stmt; + } + } + + private: + Stmt HandleBufferBindScope(const AttrStmtNode* op) { + Array arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2U); + Buffer buffer = Downcast(arr[0]); + ICHECK(buffer.defined()); + Buffer target = Downcast(arr[1]); + ICHECK(target.defined()); + + bool needs_rewrite = false; + + { + auto it = buf_remap_.find(buffer->data); + if (it != buf_remap_.end()) { + needs_rewrite = true; + buffer = it->second; + } + } + + { + auto it = buf_remap_.find(target->data); + if (it != buf_remap_.end()) { + needs_rewrite = true; + 
target = it->second; + } + } + + if (needs_rewrite) { + Stmt body = this->VisitStmt(op->body); + return AttrStmt(Array{buffer, target}, op->attr_key, op->value, body); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + std::unordered_map buf_remap_; + std::unordered_set external_buffers_; + + // The current thread scope. + std::vector curr_thread_scope_; +}; + class StorageFlattener : public StmtExprMutator { public: explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, @@ -500,13 +644,6 @@ class StorageFlattener : public StmtExprMutator { ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body)); return body; - } else if (op->attr_key == attr::thread_extent) { - IterVar iv = Downcast(op->node); - ThreadScope ts = ThreadScope::Create(iv->thread_tag); - curr_thread_scope_.push_back(ts); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - curr_thread_scope_.pop_back(); - return stmt; } else if (op->attr_key == attr::buffer_bind_scope) { return HandleBufferBindScope(op); } @@ -558,16 +695,8 @@ class StorageFlattener : public StmtExprMutator { << "Inconsistent buffer shape and realization shape for " << op->buffer; Array shape = op->buffer->shape; - // deduce current storage scope. 
- StorageScope skey; - std::string strkey = GetPtrStorageScope(op->buffer->data); - if (strkey.length() == 0) { - if (curr_thread_scope_.size() != 0) { - skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); - } - } else { - skey = StorageScope::Create(strkey); - } + StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + // use small alignment for small arrays auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); @@ -582,12 +711,8 @@ class StorageFlattener : public StmtExprMutator { } Array strides = op->buffer->strides; - auto* ptr_type = op->buffer->data->type_annotation.as(); - ICHECK(ptr_type); - auto new_var = - Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); - e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, - align, 0, kDefault); + e.buffer = Buffer(op->buffer->data, op->buffer->dtype, shape, strides, PrimExpr(), + op->buffer->name, align, 0, kDefault); buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); @@ -873,8 +998,6 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map var_remap_; // Buffer map std::unordered_map buf_map_; - // The current thread scope. - std::vector curr_thread_scope_; // Collects shapes. std::vector>> shape_collector_; // bounds populator. We really need the analyzer from it. 
@@ -901,6 +1024,8 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at fptr->body = stride_legalize(std::move(fptr->body)); fptr->buffer_map = stride_legalize.UpdatedExternBufferMap(); + fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body)); + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer)(std::move(fptr->body)); From 9fa935ee56acb9240a85df3e6e1f9b5a9a0ed1bc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 22 Sep 2021 15:42:49 -0500 Subject: [PATCH 06/21] [TE] Refactor buffer bind mapping out of StorageFlattener. Previously, StorageFlattener would look for `attr::buffer_bind_scope` to determine if a Buffer object is a view into another buffer, and would apply that mapping while making the Allocate/Store/Load nodes. Now, the mapping of buffer binds is pulled out into a separate BufferStrideUnwrapper pass. This also resolves an issue in which BufferLoad/BufferStore nodes that refer to a Buffer defined through `attr::buffer_bind_scope` would generate Load/Store nodes that point to the linked buffer, rather than the actual buffer. --- src/tir/transforms/storage_flatten.cc | 495 +++++++++++++++++++++----- 1 file changed, 404 insertions(+), 91 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 6d5abc843a9a..6a476ed06323 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -604,6 +604,364 @@ class ThreadScopePropagate : public StmtExprMutator { std::vector curr_thread_scope_; }; +/* Map buffer binds to their source buffer + * + * Buffers defined using an attr::buffer_bind_scope annotation are + * views into some linked buffer, potentially into some restricted + * subregion of that buffer. This pass identifies such buffers, then + * rewrites all access of the bound buffers to be access into the + * linked buffer. 
+ */ +class BufferBindUnwrapper : public StmtExprMutator { + public: + explicit BufferBindUnwrapper(const Map& extern_buffer_map, + IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer) { + for (auto kv : extern_buffer_map) { + BufferEntry e; + e.buffer = kv.second; + e.external = true; + buf_map_[kv.second.get()] = std::move(e); + } + } + + Stmt VisitStmt_(const StoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + auto it = var_remap_.find(op->buffer_var.get()); + if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { + // TODO(Lunderberg): Change from warning to error once all mixed + // use of physical/logical layouts is removed. + DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), " + << "but is accessed as a pointer (StoreNode)."; + + ICHECK(it->second.as()); + Var new_buf_var = Downcast(it->second); + return Store(new_buf_var, op->value, op->index, op->predicate); + } else { + return stmt; + } + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + auto it = var_remap_.find(op->buffer_var.get()); + if (it != var_remap_.end() && !it->second.same_as(op->buffer_var)) { + // TODO(Lunderberg): Change from warning to error once all mixed + // use of physical/logical layouts is removed. + DLOG(WARNING) << op->buffer_var << " was declared as buffer (buffer_bind_scope), " + << "but is accessed as a pointer (LoadNode)."; + + ICHECK(it->second.as()); + Var new_buf_var = Downcast(it->second); + return Load(op->dtype, new_buf_var, op->index, op->predicate); + } else { + return expr; + } + } + + Stmt VisitStmt_(const AttrStmtNode* op) final { + ICHECK_NE(op->attr_key, attr::buffer_dim_align) + << "BufferBindUnwrapper assumes that all buffers have accurate strides, " + << "and all buffer_dim_align annotations are removed. 
" + << "Please run BufferStrideLegalize first."; + + if (op->attr_key == attr::buffer_bind_scope) { + return HandleBufferBindScope(op); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = var_remap_.find(op); + if (it != var_remap_.end()) { + return it->second; + } else { + return GetRef(op); + } + } + + Array remap_indices(Array indices, Array begins, + Array extents) { + ICHECK_EQ(begins.size(), extents.size()); + + if (begins.size() == 0) { + return indices; + } + + ICHECK_EQ(begins.size(), indices.size()); + + Array out; + for (size_t i = 0; i < begins.size(); i++) { + out.push_back(begins[i] + indices[i]); + } + return out; + } + + Array remap_bounds(Array bounds, Array begins, Array extents) { + ICHECK_EQ(begins.size(), extents.size()); + + if (begins.size() == 0) { + return bounds; + } + + ICHECK_EQ(begins.size(), bounds.size()); + + Array out; + for (size_t i = 0; i < begins.size(); i++) { + out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent)); + } + return out; + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + + auto it = buf_map_.find(op->buffer.get()); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; + const BufferEntry& e = it->second; + ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope"; + + if (e.remap) { + return BufferLoad(e.remap->target, + remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + } else { + return expr; + } + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + + auto it = buf_map_.find(op->buffer.get()); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; + const BufferEntry& e = it->second; + ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope"; 
+ + if (e.remap) { + return BufferStore(e.remap->target, op->value, + remap_indices(op->indices, e.remap->begins, e.remap->extents), op->span); + } else { + return stmt; + } + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + const auto& key = op->buffer.get(); + + bool is_external = false; + + if (buf_map_.count(key)) { + ICHECK(buf_map_.at(key).external) + << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times."; + + is_external = true; + } else { + BufferEntry e; + e.bounds = op->bounds; + e.buffer = op->buffer; + buf_map_[key] = std::move(e); + } + + Stmt stmt = StmtExprMutator::VisitStmt_(op); + + if (is_external) { + buf_map_[key].in_scope = false; + } + + return stmt; + } + + Stmt VisitStmt_(const PrefetchNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + + const auto& key = op->buffer.get(); + auto it = buf_map_.find(key); + ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; + const BufferEntry& e = it->second; + + ICHECK(e.in_scope) << "Read a buffer that is already out of scope"; + ICHECK_EQ(e.buffer->shape.size(), op->bounds.size()) + << "Prefetch dim should be the same as buffer dim"; + + if (e.remap) { + return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents), + op->span); + } else { + return stmt; + } + } + + private: + // The specific tensor data layout is not determined before + // StorageFlatten pass. We use buffer_bind_scope + // to specify before hand we want to bind a subregion + // of tensor to a symbolic buffer, which get used in extern. + // + // Example: + // + // realize A in range [i*4, extent=10) { + // bind Ab to A in [i*4+1, extent=4) { + // call_func(Ab.ptr, Ab.shape[0]) + // } + // } + // + // After StorageFlatten + // + // alloc A[10] + // call(A + 1, 4) + // + // Buffer is a protocol to declare specific + // data layout and shape we expect. 
+ // So this function need to check: + // - If the bind range is within the realize range + // - If we can match the requirement of buffer + // - Remap variables such as Ab.ptr to the actual value. + // + // Here are a few possible failure cases: + // - Buffer is declared to have constant shape, + // but we try to bind it to a different one. + // - Buffer is declared to be compact(no strides) + // but this binded region is a subregion of + // a matrix(tensor), which means it requires strides. + // + // We do support a few relaxed case, such as bindingx + // region with shape [1, 1, n, m] to buffer with shape [n, m] + Stmt HandleBufferBindScope(const AttrStmtNode* op) { + // Unpack information from Attribute node + RemapInfo remap; + + Array arr = Downcast>(op->node); + ICHECK_EQ(arr.size(), 2U); + const Buffer source = Downcast(arr[0]); + ICHECK(source.defined()); + remap.target = Downcast(arr[1]); + ICHECK(remap.target.defined()); + const CallNode* tuple = op->value.as(); + ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); + + for (size_t i = 0; i < tuple->args.size(); i += 2) { + remap.begins.push_back(tuple->args[i]); + remap.extents.push_back(tuple->args[i + 1]); + } + + // Determine bounds in the target buffer + auto it = buf_map_.find(remap.target.get()); + ICHECK(it != buf_map_.end()) << "Cannot find buffer " << remap.target << " @ " + << remap.target.get(); + const BufferEntry& target_info = it->second; + ICHECK(target_info.in_scope) << "Cannot bind to a buffer that is out of scope"; + ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size()) + << "Incorrect number of arguments in buffer_bind_scope attribute. 
" + << "Expected (min_0, extent_0, min_1, extent_0, ..., min_N, extent_N)."; + + if (target_info.bounds.size() > 0) { + Array mapped_begins; + for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) { + mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min); + } + remap.begins = std::move(mapped_begins); + } + + ICHECK(target_info.remap == nullptr) << "Indirect remapping not handled"; + + for (size_t i = 0; i < remap.begins.size(); i++) { + remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i])); + remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i])); + } + + // Add a buffer remap entry + { + BufferEntry source_info; + source_info.buffer = source; + source_info.remap = std::make_unique(remap); + + buf_map_[source.get()] = std::move(source_info); + } + + // Generate slice that represents the source's view into the + // target buffer. + ICHECK_EQ(source->strides.size(), 0) << "Buffer view cannot have strides defined."; + + // Define remappings of any remaining Variables (e.g. Store/Load nodes). + ArgBinder binder(&var_remap_); + + binder.Bind(source->data, remap.target->data, source->name + ".data"); + binder.Bind(source->elem_offset, remap.target->ElemOffset(remap.begins), + source->name + ".elem_offset"); + + // Apply the remaps + Stmt body = op->body; + body = MergeNest(binder.asserts(), body); + body = MergeNest(binder.init_nest(), body); + body = this->VisitStmt(body); + // remove the binds + for (const Var& v : binder.defs()) { + var_remap_.erase(v.get()); + } + return body; + } + + struct RemapInfo { + Buffer target; + Array begins; + Array extents; + }; + + // The buffer entry in the flatten map + struct BufferEntry { + // The storage buffer + Buffer buffer; + // the bounds of realization, can be null, means everything + Region bounds; + // Whether the buffer is external + bool external{false}; + // Whether we are within the allocation scope of the buffer. 
+ bool in_scope{true}; + + // The buffer to which the storage buffer should be remapped. + std::unique_ptr remap{nullptr}; + + PrimExpr ElemOffset() const { + ICHECK(remap); + + Buffer copy = remap->target; + { + Array shape; + for (auto r : bounds) { + shape.push_back(r->extent); + } + copy.CopyOnWrite()->shape = std::move(shape); + } + + Buffer target_slice = copy.MakeSlice(remap->begins, remap->extents); + if (buffer->strides.size() == 0) { + ICHECK_EQ(target_slice->strides.size(), 0U) + << "Trying to bind compact buffer to strided one strides=" << target_slice->strides; + } else { + target_slice = target_slice.MakeStrideView(); + } + + return copy->ElemOffset(remap->begins); + } + }; + + // The buffer assignment map + // Variable remap + std::unordered_map var_remap_; + // Buffer map + std::unordered_map buf_map_; + // Analyzer for the variable bounds, used to simplify the bounds populator. We really need the + // analyzer from it. However + IRVisitorWithAnalyzer* bound_analyzer_; +}; + class StorageFlattener : public StmtExprMutator { public: explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, @@ -637,6 +995,10 @@ class StorageFlattener : public StmtExprMutator { << "and all buffer_dim_align annotations are removed. " << "Please run BufferStrideLegalize first."; + ICHECK_NE(op->attr_key, attr::buffer_bind_scope) + << "StorageFlattener assumes that all buffer binds have already been applied. 
" + << "Please run BufferBindUnwrapper first."; + if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); @@ -644,8 +1006,6 @@ class StorageFlattener : public StmtExprMutator { ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; body = AttrStmt(it->second.buffer->data, op->attr_key, op->value, std::move(body)); return body; - } else if (op->attr_key == attr::buffer_bind_scope) { - return HandleBufferBindScope(op); } return StmtExprMutator::VisitStmt_(op); } @@ -841,107 +1201,24 @@ class StorageFlattener : public StmtExprMutator { } PrimExpr VisitExpr_(const ProducerLoadNode* op) final { - LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc."; + LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc. " + << "Please run SchedulePostProcToPrimFunc first."; return PrimExpr(); } Stmt VisitStmt_(const ProducerStoreNode* op) final { - LOG(FATAL) << "Cannot handle Provide " - << " please run SchedulePostProcToPrimFunc first"; + LOG(FATAL) << "ProducerStore cannot appear in a valid TIR PrimFunc. " + << "Please run SchedulePostProcToPrimFunc first."; return Stmt(); } Stmt VisitStmt_(const ProducerRealizeNode* op) final { - LOG(FATAL) << "Cannot handle Realize " - << " please run SchedulePostProcToPrimFunc first"; + LOG(FATAL) << "ProducerRealize cannot appear in a valid TIR PrimFunc. " + << "Please run SchedulePostProcToPrimFunc first."; return Stmt(); } private: - // The specific tensor data layout is not determined before - // StorageFlatten pass. We use buffer_bind_scope - // to specify before hand we want to bind a subregion - // of tensor to a symbolic buffer, which get used in extern. 
- // - // Example: - // - // realize A in range [i*4, extent=10) { - // bind Ab to A in [i*4+1, extent=4) { - // call_func(Ab.ptr, Ab.shape[0]) - // } - // } - // - // After StorageFlatten - // - // alloc A[10] - // call(A + 1, 4) - // - // Buffer is a protocol to declare specific - // data layout and shape we expect. - // So this function need to check: - // - If the bind range is within the realize range - // - If we can match the requirement of buffer - // - Remap variables such as Ab.ptr to the actual value. - // - // Here are a few possible failure cases: - // - Buffer is declared to have constant shape, - // but we try to bind it to a different one. - // - Buffer is declared to be compact(no strides) - // but this binded region is a subregion of - // a matrix(tensor), which means it requires strides. - // - // We do support a few relaxed case, such as bindingx - // region with shape [1, 1, n, m] to buffer with shape [n, m] - Stmt HandleBufferBindScope(const AttrStmtNode* op) { - Array arr = Downcast>(op->node); - ICHECK_EQ(arr.size(), 2U); - const BufferNode* buffer = arr[0].as(); - const BufferNode* target = arr[1].as(); - const CallNode* tuple = op->value.as(); - ICHECK(buffer && target); - ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); - auto key = GetRef(target); - - auto it = buf_map_.find(key); - ICHECK(it != buf_map_.end()) << "Cannot find buffer of " << key; - const BufferEntry& be = it->second; - ICHECK(!be.released); - ICHECK_EQ(tuple->args.size(), be.buffer->shape.size() * 2); - Array begins, extents; - if (be.bounds.size() != 0) { - ICHECK_EQ(tuple->args.size(), be.bounds.size() * 2); - for (size_t i = 0; i < be.buffer->shape.size(); ++i) { - begins.push_back(tuple->args[2 * i] - be.bounds[i]->min); - extents.push_back(tuple->args[2 * i + 1]); - } - } else { - for (size_t i = 0; i < tuple->args.size(); i += 2) { - begins.push_back(tuple->args[i]); - auto new_extent = bound_analyzer_->Simplify(tuple->args[i + 1]); - 
extents.push_back(new_extent); - } - } - Buffer slice = be.buffer.MakeSlice(begins, extents); - if (buffer->strides.size() == 0) { - ICHECK_EQ(slice->strides.size(), 0U) - << "Trying to bind compact buffer to strided one strides=" << slice->strides; - } else { - slice = slice.MakeStrideView(); - } - // start binding - ArgBinder binder(&var_remap_); - binder.BindBuffer(Downcast(arr[0]), slice, buffer->name, true); - // Apply the remaps - Stmt body = MergeNest(binder.asserts(), op->body); - body = MergeNest(binder.init_nest(), body); - body = this->VisitStmt(body); - // remove the binds - for (const Var& v : binder.defs()) { - var_remap_.erase(v.get()); - } - return body; - } - // The buffer entry in the flatten map struct DimAlignInfo { int align_factor{0}; @@ -1009,6 +1286,40 @@ class StorageFlattener : public StmtExprMutator { bool create_bound_attributes_{false}; }; +// The specific tensor data layout is not determined before +// StorageFlatten pass. We use buffer_bind_scope +// to specify before hand we want to bind a subregion +// of tensor to a symbolic buffer, which get used in extern. +// +// Example: +// +// realize A in range [i*4, extent=10) { +// bind Ab to A in [i*4+1, extent=4) { +// call_func(Ab.ptr, Ab.shape[0]) +// } +// } +// +// After StorageFlatten +// +// alloc A[10] +// call(A + 1, 4) +// +// Buffer is a protocol to declare specific +// data layout and shape we expect. +// So this function need to check: +// - If the bind range is within the realize range +// - If we can match the requirement of buffer +// - Remap variables such as Ab.ptr to the actual value. +// +// Here are a few possible failure cases: +// - Buffer is declared to have constant shape, +// but we try to bind it to a different one. +// - Buffer is declared to be compact(no strides) +// but this binded region is a subregion of +// a matrix(tensor), which means it requires strides. 
+// +// We do support a few relaxed case, such as bindingx +// region with shape [1, 1, n, m] to buffer with shape [n, m] PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { // Only apply this pass to TIR from TE schedules Optional from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false)); @@ -1026,6 +1337,8 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body)); + fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); + fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer)(std::move(fptr->body)); From 1c729e6061b6947142b7a783a6973f87356ac305 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 23 Sep 2021 14:09:45 -0500 Subject: [PATCH 07/21] [TIR] Removed checks on buffer->shape.size() Even after BufferShapeLegalize, rank-zero tensors may have an empty shape. 
--- src/tir/transforms/storage_flatten.cc | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 6a476ed06323..102db7b674cb 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -288,7 +288,7 @@ class BufferStrideLegalize : public StmtExprMutator { : bound_analyzer_(bound_analyzer) { for (auto kv : extern_buffer_map) { Buffer buf = kv.second; - Buffer with_strides = WithStrides(buf, true); + Buffer with_strides = WithStrides(buf); { BufferEntry entry; entry.remap_to = with_strides; @@ -302,7 +302,7 @@ class BufferStrideLegalize : public StmtExprMutator { Map UpdatedExternBufferMap() const { return updated_extern_buffer_map_; } - Buffer WithStrides(Buffer buf, bool allow_scalar = false) { + Buffer WithStrides(Buffer buf) { auto it = buf_map_.find(buf); if (it != buf_map_.end()) { const BufferEntry& entry = it->second; @@ -317,17 +317,6 @@ class BufferStrideLegalize : public StmtExprMutator { Array shape = buf->shape; - if (shape.size() == 0) { - // This is only allowed for buffers that point to external - // buffers. These are treated as 1-d pointers of unknown size, - // with stride of 1. - ICHECK(allow_scalar) - << "Buffer " << buf << " does not have a valid shape. " - << "BufferStrideLegalize requires all internal buffers to have a valid shape. " - << "Please run BufferShapeLegalize first"; - return buf; - } - // Keeping this to have matched behavior to previous version. // There are many parts of the codebase that assume that a strided // array cannot be compact. @@ -1048,9 +1037,6 @@ class StorageFlattener : public StmtExprMutator { BufferEntry e; e.bounds = op->bounds; - ICHECK(op->buffer->shape.size()) << "StorageFlattener expects buffer shapes to be defined. 
" - << "Please run through BufferShapeLegalize first."; - ICHECK_EQ(op->buffer->shape.size(), op->bounds.size()) << "Inconsistent buffer shape and realization shape for " << op->buffer; From 40ab5a405c7e11417405f753194dcccbc1b78a01 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 23 Sep 2021 15:28:40 -0500 Subject: [PATCH 08/21] [TIR] Relaxed check on a bufferview's striding. Original refactoring requiring that a bufferview have no explicit striding, and instead take the striding from the buffer that it is viewing. Modified to allow bufferview to specify striding, so long as it is consistent with the viewed buffer's striding. This reproduces the behavior of StorageFlatten before the refactoring. --- src/tir/transforms/storage_flatten.cc | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 102db7b674cb..f3db04b119c7 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -874,16 +874,22 @@ class BufferBindUnwrapper : public StmtExprMutator { buf_map_[source.get()] = std::move(source_info); } - // Generate slice that represents the source's view into the - // target buffer. - ICHECK_EQ(source->strides.size(), 0) << "Buffer view cannot have strides defined."; - // Define remappings of any remaining Variables (e.g. Store/Load nodes). ArgBinder binder(&var_remap_); - binder.Bind(source->data, remap.target->data, source->name + ".data"); - binder.Bind(source->elem_offset, remap.target->ElemOffset(remap.begins), - source->name + ".elem_offset"); + // Define a view that represents the source's view into the target + // buffer. This Buffer object is only used to define the mapping + // to the target buffer, and never actually appears in the TIR + // graph. 
+ Buffer view = remap.target.MakeSlice(remap.begins, remap.extents); + if (source->strides.size() == 0) { + ICHECK_EQ(view->strides.size(), 0U) + << "Cannot bind a compact buffer to a strided buffer" << view->strides; + } else { + // Add explicit strides to the view, in order to bind to source.strides[i]. + view = view.MakeStrideView(); + } + binder.BindBuffer(source, view, source->name, true); // Apply the remaps Stmt body = op->body; From f3612631415780f457c5344460edaf63007f4c80 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 23 Sep 2021 16:23:31 -0500 Subject: [PATCH 09/21] [TIR] Fixed StorageFlatten test for shape_legalize. AttrStmtNodes that contain rewritten Buffers need to be rewritten as well. --- src/tir/transforms/storage_flatten.cc | 36 +++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index f3db04b119c7..ba1f4490d33e 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -179,30 +179,29 @@ class BufferShapeLegalize : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { - return HandleDoubleBuffer(op); + if (op->node->IsInstance()) { + // Visit body before checking internal_buf_map_, because we + // don't know if the BufferNode needs to be changed until we + // look in the body for a BufferRealizeNode with different + // extents. 
+ Stmt body = this->VisitStmt(op->body); + + Buffer buffer = Downcast(op->node); + auto it = internal_buf_map_.find(buffer); + if (it != internal_buf_map_.end()) { + buffer = it->second.remap_to; + return AttrStmt(it->second.remap_to, op->attr_key, op->value, body); + } + return AttrStmt(buffer, op->attr_key, op->value, body); + } else if (op->attr_key == attr::buffer_bind_scope) { return HandleBufferBindScope(op); - } else { - return StmtExprMutator::VisitStmt_(op); } - } - private: - Stmt HandleDoubleBuffer(const AttrStmtNode* op) { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op); - - Buffer buffer = Downcast(op->node); - auto it = internal_buf_map_.find(buffer); - if (it != internal_buf_map_.end()) { - return AttrStmt(it->second.remap_to->data, op->attr_key, op->value, op->body); - } else { - return stmt; - } + return StmtExprMutator::VisitStmt_(op); } + private: Stmt HandleBufferBindScope(const AttrStmtNode* op) { Array arr = Downcast>(op->node); ICHECK_EQ(arr.size(), 2U); @@ -358,6 +357,7 @@ class BufferStrideLegalize : public StmtExprMutator { } vinfo[dim].align_factor = tuple->args[1].as()->value; vinfo[dim].align_offset = tuple->args[2].as()->value; + return this->VisitStmt(op->body); } else if (op->attr_key == attr::buffer_bind_scope) { Array arr = Downcast>(op->node); From 5a79bc35f66361c8c9d26651546b26b082d1b5f8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Sep 2021 13:37:17 -0500 Subject: [PATCH 10/21] [TIR] Assigned storage scope The earlier stage of the refactor left a buffer's storage scope undefined if it's scope was not determined by the IterVar of a loop containing its allocation. Now, these are explicitly set to StorageScope::kGlobal, to match the previous behavior of StorageFlatten. 
--- src/tir/transforms/storage_flatten.cc | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index ba1f4490d33e..c8a997338554 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -453,7 +453,8 @@ class BufferStrideLegalize : public StmtExprMutator { * * For buffers that do not have an explicit storage scope defined, a * reasonable storage scope may be defined based on the thread scope - * that contains the buffer's allocation. + * that contains the buffer's allocation. All other buffers without a + * scope are assigned to global scope. */ class ThreadScopePropagate : public StmtExprMutator { public: @@ -498,10 +499,9 @@ class ThreadScopePropagate : public StmtExprMutator { Var old_var = op->buffer->data; // Don't remap buffers that already have an explicit scope, - // external buffers, or buffers outside of a thread scope. + // or external buffers. 
std::string str_scope = GetPtrStorageScope(old_var); - if ((str_scope.length() > 0) || external_buffers_.count(op->buffer) || - (curr_thread_scope_.size() == 0)) { + if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) { return StmtExprMutator::VisitStmt_(op); } @@ -509,7 +509,11 @@ class ThreadScopePropagate : public StmtExprMutator { << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes"; StorageScope skey; - skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); + if (curr_thread_scope_.size() == 0) { + skey.rank = StorageRank::kGlobal; + } else { + skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); + } auto ptr_type = old_var->type_annotation.as(); ICHECK(ptr_type); From 31c426924196685e4911705c0d3883da30881138 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Sep 2021 13:41:42 -0500 Subject: [PATCH 11/21] Updated ICHECK_EQ to CHECK_EQ for a test that depends on user-provided data. --- src/tir/transforms/storage_flatten.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index c8a997338554..286604d1527b 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -71,7 +71,7 @@ class BufferShapeLegalize : public StmtExprMutator { Stmt VisitStmt_(const BufferRealizeNode* op) final { // External buffers should not be changed. if (extern_buffers_.count(op->buffer)) { - ICHECK_EQ(op->buffer->shape.size(), op->bounds.size()) + CHECK_EQ(op->buffer->shape.size(), op->bounds.size()) << "External buffer realize has mismatched dimension"; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); From 7d3f1ae41a52777755f28c10b8d36f618b5381fb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Sep 2021 13:52:30 -0500 Subject: [PATCH 12/21] Added comments in storage_flatten.cc, indicating why buffer_bind_scope needs special handling. 
--- src/tir/transforms/storage_flatten.cc | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 286604d1527b..c0b8f060f834 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -202,6 +202,13 @@ class BufferShapeLegalize : public StmtExprMutator { } private: + // Any buffers that give views into a resized buffer should be + // updated, both to refer to the resized buffer and to have the view + // window updated. For example, suppose B1 is a 1-D buffer of size + // 100 which is only realized on the range (10,50), and buffer V1 is + // a view into B1[25:35]. When B1 is replaced with B2, a buffer of + // size 40 realized on the range (0,40), V1 must be replaced to be a + // view into B2[15:25]. Stmt HandleBufferBindScope(const AttrStmtNode* op) { Array arr = Downcast>(op->node); ICHECK_EQ(arr.size(), 2U); @@ -556,6 +563,10 @@ class ThreadScopePropagate : public StmtExprMutator { } private: + // If the rewritten buffers are part of a buffer_bind_scope, either + // as the buffer view or as the the buffer being viewed, then the + // buffer_bind_scope must be rewritten to refer to the updated + // buffers. Stmt HandleBufferBindScope(const AttrStmtNode* op) { Array arr = Downcast>(op->node); ICHECK_EQ(arr.size(), 2U); From 506c3914c42f6aa936cec0a3cf221db0dd036004 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 24 Sep 2021 14:03:49 -0500 Subject: [PATCH 13/21] Updated comment with a few examples of where compact buffers are assumed to have no strides defined. 
--- src/tir/transforms/storage_flatten.cc | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index c0b8f060f834..038af2cf1170 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -321,15 +321,20 @@ class BufferStrideLegalize : public StmtExprMutator { return buf; } - Array shape = buf->shape; - // Keeping this to have matched behavior to previous version. // There are many parts of the codebase that assume that a strided - // array cannot be compact. + // array cannot be compact. For example, ArgBinder::BindBuffer + // and tir.Specialize. if (dim_align_.count(buf) == 0) { return buf; } + // Can't define the strides for a buffer without a known shape. + Array shape = buf->shape; + if (shape.size() == 0) { + return buf; + } + std::vector rstrides; const std::vector& avec = dim_align_[buf]; int first_dim = 0; From 054cd720e5f8bbb6331ae8dbcb38b1e90015f633 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 27 Sep 2021 14:23:49 -0500 Subject: [PATCH 14/21] Updated following @csullivan's comments. --- src/tir/transforms/storage_flatten.cc | 62 ++------------------------- 1 file changed, 4 insertions(+), 58 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 038af2cf1170..f3f57a4514b7 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -808,40 +808,9 @@ class BufferBindUnwrapper : public StmtExprMutator { } private: - // The specific tensor data layout is not determined before - // StorageFlatten pass. We use buffer_bind_scope - // to specify before hand we want to bind a subregion - // of tensor to a symbolic buffer, which get used in extern. 
- // - // Example: - // - // realize A in range [i*4, extent=10) { - // bind Ab to A in [i*4+1, extent=4) { - // call_func(Ab.ptr, Ab.shape[0]) - // } - // } - // - // After StorageFlatten - // - // alloc A[10] - // call(A + 1, 4) - // - // Buffer is a protocol to declare specific - // data layout and shape we expect. - // So this function need to check: - // - If the bind range is within the realize range - // - If we can match the requirement of buffer - // - Remap variables such as Ab.ptr to the actual value. - // - // Here are a few possible failure cases: - // - Buffer is declared to have constant shape, - // but we try to bind it to a different one. - // - Buffer is declared to be compact(no strides) - // but this binded region is a subregion of - // a matrix(tensor), which means it requires strides. - // - // We do support a few relaxed case, such as bindingx - // region with shape [1, 1, n, m] to buffer with shape [n, m] + // Read the mapping from a buffer view to the actual buffer. This + // allows all later BufferStore/BufferLoad nodes to reference the + // actual buffer, rather than the buffer view. Stmt HandleBufferBindScope(const AttrStmtNode* op) { // Unpack information from Attribute node RemapInfo remap; @@ -942,29 +911,6 @@ class BufferBindUnwrapper : public StmtExprMutator { // The buffer to which the storage buffer should be remapped. 
std::unique_ptr remap{nullptr}; - - PrimExpr ElemOffset() const { - ICHECK(remap); - - Buffer copy = remap->target; - { - Array shape; - for (auto r : bounds) { - shape.push_back(r->extent); - } - copy.CopyOnWrite()->shape = std::move(shape); - } - - Buffer target_slice = copy.MakeSlice(remap->begins, remap->extents); - if (buffer->strides.size() == 0) { - ICHECK_EQ(target_slice->strides.size(), 0U) - << "Trying to bind compact buffer to strided one strides=" << target_slice->strides; - } else { - target_slice = target_slice.MakeStrideView(); - } - - return copy->ElemOffset(remap->begins); - } }; // The buffer assignment map @@ -1330,7 +1276,7 @@ class StorageFlattener : public StmtExprMutator { // but this binded region is a subregion of // a matrix(tensor), which means it requires strides. // -// We do support a few relaxed case, such as bindingx +// We do support a few relaxed case, such as binding a // region with shape [1, 1, n, m] to buffer with shape [n, m] PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { // Only apply this pass to TIR from TE schedules From acf1c3a1bd00e7a98a01b5b7a8e3abbb74c5a1da Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 27 Sep 2021 21:31:08 -0500 Subject: [PATCH 15/21] Added fuzzy mapping to the BufferShapeLegalize. Maintains earlier behavior of StorageFlatten, which allows buffer views to be mapped to higher dimension buffers, if the view extent is 1 in each extra dimension. 
--- src/tir/transforms/storage_flatten.cc | 77 ++++++++++++++++++++++++--- 1 file changed, 71 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index f3f57a4514b7..1d77568ef035 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -68,6 +68,15 @@ class BufferShapeLegalize : public StmtExprMutator { } } + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = var_remap_.find(op); + if (it != var_remap_.end()) { + return it->second; + } else { + return GetRef(op); + } + } + Stmt VisitStmt_(const BufferRealizeNode* op) final { // External buffers should not be changed. if (extern_buffers_.count(op->buffer)) { @@ -103,6 +112,12 @@ class BufferShapeLegalize : public StmtExprMutator { new_bounds.push_back({0, bound->extent}); } + if (op->buffer->shape.size()) { + ICHECK_EQ(op->buffer->shape.size(), realized_shape.size()) + << "Inconsistency between dimension of buffer " << op->buffer + << " and dimension of its realized bounds."; + } + Buffer key = op->buffer; Buffer buf = op->buffer; @@ -137,8 +152,15 @@ class BufferShapeLegalize : public StmtExprMutator { << "Inconsistent dimensions for buffer " << op->buffer->name; Array new_indices; - for (size_t i = 0; i < entry.realized_begins.size(); i++) { - new_indices.push_back(op->indices[i] - entry.realized_begins[i]); + + // Pad leading indices with zero, matching the "fuzzy_match" + // behavior from ArgBinder::BindBuffer. 
+ size_t diff = entry.realized_begins.size() - op->indices.size(); + for (size_t i = 0; i < diff; i++) { + new_indices.push_back(0); + } + for (size_t i = 0; i < op->indices.size(); i++) { + new_indices.push_back(op->indices[i] - entry.realized_begins[i + diff]); } BufferStore updated = GetRef(op); @@ -160,12 +182,20 @@ class BufferShapeLegalize : public StmtExprMutator { if (it != internal_buf_map_.end()) { const InternalBufferRemap& entry = it->second; ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer"; - ICHECK_EQ(entry.realized_begins.size(), op->indices.size()) + + ICHECK_GE(entry.realized_begins.size(), op->indices.size()) << "Inconsistent dimensions for buffer " << op->buffer->name; Array new_indices; - for (size_t i = 0; i < entry.realized_begins.size(); i++) { - new_indices.push_back(op->indices[i] - entry.realized_begins[i]); + + // Pad leading indices with zero, matching the "fuzzy_match" + // behavior from ArgBinder::BindBuffer. + size_t diff = entry.realized_begins.size() - op->indices.size(); + for (size_t i = 0; i < diff; i++) { + new_indices.push_back(0); + } + for (size_t i = 0; i < op->indices.size(); i++) { + new_indices.push_back(op->indices[i] - entry.realized_begins[i + diff]); } BufferLoad updated = GetRef(op); @@ -246,11 +276,32 @@ class BufferShapeLegalize : public StmtExprMutator { realized_begins.push_back(0); } + ICHECK_GE(realized_shape.size(), buffer->shape.size()) + << "Cannot bind " << buffer << " to a shape of lower dimension."; + Buffer key = buffer; auto write_ptr = buffer.CopyOnWrite(); write_ptr->shape = realized_shape; + // If a buffer has strides defined, and is being remapped into a + // shape with additional dimensions, then define dummy values for + // the strides. 
+ if ((buffer->strides.size() > 0) && (buffer->strides.size() != buffer->shape.size())) { + ICHECK_LT(buffer->strides.size(), realized_shape.size()) + << "Cannot bind the strides of " << buffer << " to a shape of lower dimension"; + + auto num_additional_strides = realized_shape.size() - buffer->strides.size(); + Array updated_strides; + for (size_t i = 0; i < num_additional_strides; i++) { + updated_strides.push_back(Var("stride", buffer->shape[i].dtype())); + } + for (auto stride : buffer->strides) { + updated_strides.push_back(stride); + } + write_ptr->strides = updated_strides; + } + { InternalBufferRemap remap; remap.realized_begins = realized_begins; @@ -259,13 +310,26 @@ class BufferShapeLegalize : public StmtExprMutator { internal_buf_map_[key] = remap; } + // Define remappings of any Variables referencing Buffer internals (e.g. Store/Load nodes). + ArgBinder binder(&var_remap_); + binder.BindBuffer(key, buffer, key->name, true); + Stmt stmt = AttrStmt(Array{buffer, target_remap.remap_to}, op->attr_key, Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span), this->VisitStmt(op->body)); + stmt = MergeNest(binder.asserts(), stmt); + stmt = MergeNest(binder.init_nest(), stmt); + + for (const Var& v : binder.defs()) { + var_remap_.erase(v.get()); + } + internal_buf_map_.at(key).in_scope = false; return stmt; } + std::unordered_map var_remap_; + std::unordered_set extern_buffers_; struct InternalBufferRemap { @@ -317,7 +381,8 @@ class BufferStrideLegalize : public StmtExprMutator { } if (buf->strides.size()) { - ICHECK_EQ(buf->strides.size(), buf->shape.size()); + ICHECK_EQ(buf->strides.size(), buf->shape.size()) + << "Buffer " << buf << " has inconsistent strides/shape."; return buf; } From 0b06e9881e342a59dbbdf0755c22c7d04f3bb04b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 28 Sep 2021 08:45:10 -0500 Subject: [PATCH 16/21] Updated BufferShapeLegalize, asserts need to be inside the buffer_bind_scope. 
--- src/tir/transforms/storage_flatten.cc | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 1d77568ef035..b197c773ca0a 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -314,11 +314,12 @@ class BufferShapeLegalize : public StmtExprMutator { ArgBinder binder(&var_remap_); binder.BindBuffer(key, buffer, key->name, true); + Stmt body = this->VisitStmt(op->body); + body = MergeNest(binder.asserts(), body); + body = MergeNest(binder.init_nest(), body); + Stmt stmt = AttrStmt(Array{buffer, target_remap.remap_to}, op->attr_key, - Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span), - this->VisitStmt(op->body)); - stmt = MergeNest(binder.asserts(), stmt); - stmt = MergeNest(binder.init_nest(), stmt); + Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span), body); for (const Var& v : binder.defs()) { var_remap_.erase(v.get()); From 87a5d48132a2116cd1817f788f426f6817805481 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 28 Sep 2021 15:42:19 -0500 Subject: [PATCH 17/21] Pulled all shape-dependent behavior into BufferShapeLegalize. Previously, BufferBindUnwrapper passed fuzzy_match=true to ArgBinder::BindBuffer, which could change the number of dimensions. Now, all buffer dimensions should be updated prior to BufferBindUnwrapper, and it is an error to have mismatched dimensions in BufferBindUnwrapper. 
--- src/tir/transforms/storage_flatten.cc | 58 +++++++++++++++++++-------- 1 file changed, 41 insertions(+), 17 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index b197c773ca0a..4e8e2a5714b6 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -64,7 +64,14 @@ class BufferShapeLegalize : public StmtExprMutator { IRVisitorWithAnalyzer* bound_analyzer) : bound_analyzer_(bound_analyzer) { for (auto kv : extern_buffer_map) { - extern_buffers_.insert(kv.second); + Buffer buf = kv.second; + extern_buffers_.insert(buf); + + InternalBufferRemap remap; + remap.remap_to = buf; + remap.realized_begins = Array(buf->shape.size(), 0); + remap.in_scope = true; + internal_buf_map_[buf] = remap; } } @@ -248,9 +255,8 @@ class BufferShapeLegalize : public StmtExprMutator { ICHECK(target.defined()); auto it = internal_buf_map_.find(target); - if (it == internal_buf_map_.end()) { - return StmtExprMutator::VisitStmt_(op); - } + ICHECK(it != internal_buf_map_.end()) + << "attr::buffer_bind_scope target " << target << " not in scope."; const InternalBufferRemap& target_remap = it->second; ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name @@ -262,7 +268,9 @@ class BufferShapeLegalize : public StmtExprMutator { Array new_tuple_args; Array realized_begins; Array realized_shape; - ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2); + ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2) + << "attr::buffer_bind_scope to define " << buffer << " as a view into " << target + << " does match dimensionality of " << target; for (size_t i = 0; i < target_remap.realized_begins.size(); i++) { PrimExpr parent_begin = tuple->args[2 * i]; PrimExpr view_extent = tuple->args[2 * i + 1]; @@ -276,32 +284,48 @@ class BufferShapeLegalize : public StmtExprMutator { realized_begins.push_back(0); } + // If a view is binding to a buffer of a higher dimensionality, 
+ // then the leading dimensions should be padded out with shape of + // 1. ICHECK_GE(realized_shape.size(), buffer->shape.size()) << "Cannot bind " << buffer << " to a shape of lower dimension."; - - Buffer key = buffer; - - auto write_ptr = buffer.CopyOnWrite(); - write_ptr->shape = realized_shape; + if (realized_shape.size() > buffer->shape.size()) { + size_t diff = realized_shape.size() - buffer->shape.size(); + Array padded_shape; + for (size_t i = 0; i < diff; i++) { + padded_shape.push_back(1); + } + for (auto dim : buffer->shape) { + padded_shape.push_back(dim); + } + realized_shape = std::move(padded_shape); + } // If a buffer has strides defined, and is being remapped into a // shape with additional dimensions, then define dummy values for // the strides. - if ((buffer->strides.size() > 0) && (buffer->strides.size() != buffer->shape.size())) { - ICHECK_LT(buffer->strides.size(), realized_shape.size()) + Array realized_strides = buffer->strides; + if ((realized_strides.size() > 0) && (realized_strides.size() != realized_shape.size())) { + ICHECK_GE(realized_shape.size(), realized_strides.size()) << "Cannot bind the strides of " << buffer << " to a shape of lower dimension"; + size_t diff = realized_shape.size() - buffer->strides.size(); - auto num_additional_strides = realized_shape.size() - buffer->strides.size(); Array updated_strides; - for (size_t i = 0; i < num_additional_strides; i++) { - updated_strides.push_back(Var("stride", buffer->shape[i].dtype())); + for (size_t i = 0; i < diff; i++) { + updated_strides.push_back(Var("stride", buffer->shape[0].dtype())); } for (auto stride : buffer->strides) { updated_strides.push_back(stride); } - write_ptr->strides = updated_strides; + realized_strides = updated_strides; } + Buffer key = buffer; + + auto write_ptr = buffer.CopyOnWrite(); + write_ptr->shape = realized_shape; + write_ptr->strides = realized_strides; + { InternalBufferRemap remap; remap.realized_begins = realized_begins; @@ -944,7 +968,7 @@ 
class BufferBindUnwrapper : public StmtExprMutator { // Add explicit strides to the view, in order to bind to source.strides[i]. view = view.MakeStrideView(); } - binder.BindBuffer(source, view, source->name, true); + binder.BindBuffer(source, view, source->name, false); // Apply the remaps Stmt body = op->body; From 78af07721e3f3406641b5e8e9ef7b544b23946aa Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Sep 2021 08:35:32 -0500 Subject: [PATCH 18/21] Added another pass to remove verifiable assert statements. ArgBinder::BindBuffer inserts these assert statements if they are not verifiable at the time of substitution. Previously, with one giant substitution, the assertions were verifiable at that time. After the refactor, with substitutions done in multiple stages for shape/stride/buffer_bind_scope, we need to clean up any assertions that are verifiable after all substitutions have occurred. --- src/tir/transforms/storage_flatten.cc | 34 +++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 4e8e2a5714b6..b5d8fd911337 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1334,6 +1334,38 @@ class StorageFlattener : public StmtExprMutator { bool create_bound_attributes_{false}; }; +/*! + * \brief Simplify assert statements. + * + * If an assert statement can be statically verified to be false, emit + * a failure at compile time. If an assert statement can be + * statically verified to be true, remove the assert statement. If + * neither case can be verified, keep the assert statement unmodified. 
+ */ +class AssertSimplifier : public StmtMutator { + public: + explicit AssertSimplifier(IRVisitorWithAnalyzer* bound_analyzer) + : bound_analyzer_(bound_analyzer) {} + + Stmt VisitStmt_(const AssertStmtNode* op) final { + Stmt stmt = StmtMutator::VisitStmt_(op); + op = stmt.as(); + + PrimExpr condition = bound_analyzer_->Simplify(op->condition); + if (is_zero(condition)) { + LOG(FATAL) << "Assert statement failed during static checking: " << op->message; + } + if (is_one(condition)) { + return op->body; + } + + return stmt; + } + + private: + IRVisitorWithAnalyzer* bound_analyzer_; +}; + // The specific tensor data layout is not determined before // StorageFlatten pass. We use buffer_bind_scope // to specify before hand we want to bind a subregion @@ -1390,6 +1422,8 @@ PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_at fptr->body = StorageFlattener(fptr->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer)(std::move(fptr->body)); + fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body)); + return func; } else { return func; From 3d3ec42437be465a22e3880b41debb7a56a06ba6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Sep 2021 09:53:55 -0500 Subject: [PATCH 19/21] Minor cleanup - Removed StorageFlattener::BufferEntry::RelIndex, behavior already handled by BufferShapeLegalize. - Improved comments and error messages. - Extracted duplicate behavior in BufferLoad/BufferStore handling in BufferShapeLegalize. 
--- src/tir/transforms/storage_flatten.cc | 218 +++++++++++++------------- 1 file changed, 111 insertions(+), 107 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index b5d8fd911337..983c9f341568 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -67,11 +67,11 @@ class BufferShapeLegalize : public StmtExprMutator { Buffer buf = kv.second; extern_buffers_.insert(buf); - InternalBufferRemap remap; + BufferEntry remap; remap.remap_to = buf; - remap.realized_begins = Array(buf->shape.size(), 0); + remap.index_offsets = Array(buf->shape.size(), 0); remap.in_scope = true; - internal_buf_map_[buf] = remap; + buf_map_[buf] = remap; } } @@ -85,7 +85,10 @@ class BufferShapeLegalize : public StmtExprMutator { } Stmt VisitStmt_(const BufferRealizeNode* op) final { - // External buffers should not be changed. + // BufferRealizeNode for an external buffer serves as an + // annotation of the external buffers, and should not be changed. + // Instead, verify that the bounds match the external + // buffer. if (extern_buffers_.count(op->buffer)) { CHECK_EQ(op->buffer->shape.size(), op->bounds.size()) << "External buffer realize has mismatched dimension"; @@ -110,12 +113,12 @@ class BufferShapeLegalize : public StmtExprMutator { // Compute the new buffer shape, new realization bounds, and the // offsets to be applied to buffer access. 
Array realized_shape; - Array realized_begins; + Array index_offsets; Array new_bounds; for (size_t i = 0; i < op->bounds.size(); i++) { const Range& bound = op->bounds[i]; realized_shape.push_back(bound->extent); - realized_begins.push_back(bound->min); + index_offsets.push_back(bound->min); new_bounds.push_back({0, bound->extent}); } @@ -132,16 +135,16 @@ class BufferShapeLegalize : public StmtExprMutator { write_ptr->shape = realized_shape; { - InternalBufferRemap remap; + BufferEntry remap; remap.remap_to = buf; - remap.realized_begins = realized_begins; + remap.index_offsets = index_offsets; remap.in_scope = true; - internal_buf_map_[key] = remap; + buf_map_[key] = remap; } Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span); - internal_buf_map_.at(key).in_scope = false; + buf_map_.at(key).in_scope = false; return stmt; } @@ -151,28 +154,14 @@ class BufferShapeLegalize : public StmtExprMutator { op = stmt.as(); ICHECK(op); - auto it = internal_buf_map_.find(op->buffer); - if (it != internal_buf_map_.end()) { - const InternalBufferRemap& entry = it->second; + auto it = buf_map_.find(op->buffer); + if (it != buf_map_.end()) { + const BufferEntry& entry = it->second; ICHECK(entry.in_scope) << "Cannot store to an out-of-scope buffer"; - ICHECK_EQ(entry.realized_begins.size(), op->indices.size()) - << "Inconsistent dimensions for buffer " << op->buffer->name; - - Array new_indices; - - // Pad leading indices with zero, matching the "fuzzy_match" - // behavior from ArgBinder::BindBuffer. 
- size_t diff = entry.realized_begins.size() - op->indices.size(); - for (size_t i = 0; i < diff; i++) { - new_indices.push_back(0); - } - for (size_t i = 0; i < op->indices.size(); i++) { - new_indices.push_back(op->indices[i] - entry.realized_begins[i + diff]); - } BufferStore updated = GetRef(op); auto write_ptr = updated.CopyOnWrite(); - write_ptr->indices = new_indices; + write_ptr->indices = update_indices(op->indices, entry.index_offsets); write_ptr->buffer = entry.remap_to; stmt = updated; } @@ -185,29 +174,14 @@ class BufferShapeLegalize : public StmtExprMutator { op = expr.as(); ICHECK(op); - auto it = internal_buf_map_.find(op->buffer); - if (it != internal_buf_map_.end()) { - const InternalBufferRemap& entry = it->second; + auto it = buf_map_.find(op->buffer); + if (it != buf_map_.end()) { + const BufferEntry& entry = it->second; ICHECK(entry.in_scope) << "Cannot read from an out-of-scope buffer"; - ICHECK_GE(entry.realized_begins.size(), op->indices.size()) - << "Inconsistent dimensions for buffer " << op->buffer->name; - - Array new_indices; - - // Pad leading indices with zero, matching the "fuzzy_match" - // behavior from ArgBinder::BindBuffer. 
- size_t diff = entry.realized_begins.size() - op->indices.size(); - for (size_t i = 0; i < diff; i++) { - new_indices.push_back(0); - } - for (size_t i = 0; i < op->indices.size(); i++) { - new_indices.push_back(op->indices[i] - entry.realized_begins[i + diff]); - } - BufferLoad updated = GetRef(op); auto write_ptr = updated.CopyOnWrite(); - write_ptr->indices = new_indices; + write_ptr->indices = update_indices(op->indices, entry.index_offsets); write_ptr->buffer = entry.remap_to; expr = updated; } @@ -224,8 +198,8 @@ class BufferShapeLegalize : public StmtExprMutator { Stmt body = this->VisitStmt(op->body); Buffer buffer = Downcast(op->node); - auto it = internal_buf_map_.find(buffer); - if (it != internal_buf_map_.end()) { + auto it = buf_map_.find(buffer); + if (it != buf_map_.end()) { buffer = it->second.remap_to; return AttrStmt(it->second.remap_to, op->attr_key, op->value, body); } @@ -254,10 +228,9 @@ class BufferShapeLegalize : public StmtExprMutator { Buffer target = Downcast(arr[1]); ICHECK(target.defined()); - auto it = internal_buf_map_.find(target); - ICHECK(it != internal_buf_map_.end()) - << "attr::buffer_bind_scope target " << target << " not in scope."; - const InternalBufferRemap& target_remap = it->second; + auto it = buf_map_.find(target); + ICHECK(it != buf_map_.end()) << "attr::buffer_bind_scope target " << target << " not in scope."; + const BufferEntry& target_remap = it->second; ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name << " to the out-of-scope buffer " << target_remap.remap_to->name; @@ -267,19 +240,19 @@ class BufferShapeLegalize : public StmtExprMutator { Array new_tuple_args; Array realized_begins; - Array realized_shape; - ICHECK_EQ(tuple->args.size(), target_remap.realized_begins.size() * 2) + Array view_shape; + ICHECK_EQ(tuple->args.size(), target_remap.index_offsets.size() * 2) << "attr::buffer_bind_scope to define " << buffer << " as a view into " << target << " does match dimensionality of " << target; - 
for (size_t i = 0; i < target_remap.realized_begins.size(); i++) { + for (size_t i = 0; i < target_remap.index_offsets.size(); i++) { PrimExpr parent_begin = tuple->args[2 * i]; PrimExpr view_extent = tuple->args[2 * i + 1]; // Offset the begin of the buffer view by the offset of the target buffer. - new_tuple_args.push_back(parent_begin - target_remap.realized_begins[i]); + new_tuple_args.push_back(parent_begin - target_remap.index_offsets[i]); // Keep the extent of the buffer view the same. new_tuple_args.push_back(view_extent); // Use the extent of the buffer view to define the buffer view's shape. - realized_shape.push_back(view_extent); + view_shape.push_back(view_extent); // Within the buffer view, indices start at 0. realized_begins.push_back(0); } @@ -287,10 +260,10 @@ class BufferShapeLegalize : public StmtExprMutator { // If a view is binding to a buffer of a higher dimensionality, // then the leading dimensions should be padded out with shape of // 1. - ICHECK_GE(realized_shape.size(), buffer->shape.size()) + ICHECK_GE(view_shape.size(), buffer->shape.size()) << "Cannot bind " << buffer << " to a shape of lower dimension."; - if (realized_shape.size() > buffer->shape.size()) { - size_t diff = realized_shape.size() - buffer->shape.size(); + if (view_shape.size() > buffer->shape.size()) { + size_t diff = view_shape.size() - buffer->shape.size(); Array padded_shape; for (size_t i = 0; i < diff; i++) { padded_shape.push_back(1); @@ -298,17 +271,17 @@ class BufferShapeLegalize : public StmtExprMutator { for (auto dim : buffer->shape) { padded_shape.push_back(dim); } - realized_shape = std::move(padded_shape); + view_shape = std::move(padded_shape); } // If a buffer has strides defined, and is being remapped into a // shape with additional dimensions, then define dummy values for // the strides. 
Array realized_strides = buffer->strides; - if ((realized_strides.size() > 0) && (realized_strides.size() != realized_shape.size())) { - ICHECK_GE(realized_shape.size(), realized_strides.size()) + if ((realized_strides.size() > 0) && (realized_strides.size() != view_shape.size())) { + ICHECK_GE(view_shape.size(), realized_strides.size()) << "Cannot bind the strides of " << buffer << " to a shape of lower dimension"; - size_t diff = realized_shape.size() - buffer->strides.size(); + size_t diff = view_shape.size() - buffer->strides.size(); Array updated_strides; for (size_t i = 0; i < diff; i++) { @@ -323,18 +296,20 @@ class BufferShapeLegalize : public StmtExprMutator { Buffer key = buffer; auto write_ptr = buffer.CopyOnWrite(); - write_ptr->shape = realized_shape; + write_ptr->shape = view_shape; write_ptr->strides = realized_strides; { - InternalBufferRemap remap; - remap.realized_begins = realized_begins; + BufferEntry remap; + remap.index_offsets = realized_begins; remap.remap_to = buffer; remap.in_scope = true; - internal_buf_map_[key] = remap; + buf_map_[key] = remap; } - // Define remappings of any Variables referencing Buffer internals (e.g. Store/Load nodes). + // Define remappings of any Variables referencing Buffer internals + // (e.g. Store/Load nodes). Passing fuzzy_match=true allows the + // remapped buffer to have a number of dimensions. ArgBinder binder(&var_remap_); binder.BindBuffer(key, buffer, key->name, true); @@ -349,21 +324,43 @@ class BufferShapeLegalize : public StmtExprMutator { var_remap_.erase(v.get()); } - internal_buf_map_.at(key).in_scope = false; + buf_map_.at(key).in_scope = false; return stmt; } + Array update_indices(const Array& indices, const Array& offsets) { + ICHECK_GE(offsets.size(), indices.size()) + << "Cannot bind buffer to a shape of lower dimension."; + + Array new_indices; + + // Pad leading indices with zero, matching the "fuzzy_match" + // behavior from ArgBinder::BindBuffer. 
+ size_t diff = offsets.size() - indices.size(); + for (size_t i = 0; i < diff; i++) { + new_indices.push_back(0); + } + + // Offset indices used to access buffers of a reduced size. + for (size_t i = 0; i < indices.size(); i++) { + PrimExpr offset = offsets[i + diff]; + new_indices.push_back(indices[i] - offset); + } + + return new_indices; + } + std::unordered_map var_remap_; std::unordered_set extern_buffers_; - struct InternalBufferRemap { + struct BufferEntry { Buffer remap_to; - Array realized_begins; + Array index_offsets; bool in_scope; }; - std::unordered_map internal_buf_map_; + std::unordered_map buf_map_; IRVisitorWithAnalyzer* bound_analyzer_; }; @@ -822,7 +819,7 @@ class BufferBindUnwrapper : public StmtExprMutator { auto it = buf_map_.find(op->buffer.get()); ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot read a buffer that is already out of scope"; + ICHECK(e.in_scope) << "Cannot read from buffer " << op->buffer << ", out of scope."; if (e.remap) { return BufferLoad(e.remap->target, @@ -839,7 +836,7 @@ class BufferBindUnwrapper : public StmtExprMutator { auto it = buf_map_.find(op->buffer.get()); ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << op->buffer; const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot write to a buffer that is already out of scope"; + ICHECK(e.in_scope) << "Cannot write to buffer" << op->buffer << ", out of scope."; if (e.remap) { return BufferStore(e.remap->target, op->value, @@ -921,10 +918,11 @@ class BufferBindUnwrapper : public StmtExprMutator { // Determine bounds in the target buffer auto it = buf_map_.find(remap.target.get()); - ICHECK(it != buf_map_.end()) << "Cannot find buffer " << remap.target << " @ " - << remap.target.get(); + ICHECK(it != buf_map_.end()) << "Cannot define " << source << " as a view into " << remap.target + << ", " << remap.target << " was not defined."; const 
BufferEntry& target_info = it->second; - ICHECK(target_info.in_scope) << "Cannot bind to a buffer that is out of scope"; + ICHECK(target_info.in_scope) << "Cannot define " << source << " as a view into " << remap.target + << ", " << remap.target << " is out of scope."; ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size()) << "Incorrect number of arguments in buffer_bind_scope attribute. " << "Expected (min_0, extent_0, min_1, extent_0, ..., min_N, extent_N)."; @@ -937,7 +935,10 @@ class BufferBindUnwrapper : public StmtExprMutator { remap.begins = std::move(mapped_begins); } - ICHECK(target_info.remap == nullptr) << "Indirect remapping not handled"; + ICHECK(target_info.remap == nullptr) + << "buffer_bind_scope defines " << source << " as a view into " << remap.target + << ", which is itself a buffer view. " + << "Indirect remapping not currently supported."; for (size_t i = 0; i < remap.begins.size(); i++) { remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i])); @@ -963,11 +964,17 @@ class BufferBindUnwrapper : public StmtExprMutator { Buffer view = remap.target.MakeSlice(remap.begins, remap.extents); if (source->strides.size() == 0) { ICHECK_EQ(view->strides.size(), 0U) - << "Cannot bind a compact buffer to a strided buffer" << view->strides; + << "Cannot bind a compact buffer " << source << " to a strided buffer " << view + << " with strides " << view->strides; } else { // Add explicit strides to the view, in order to bind to source.strides[i]. view = view.MakeStrideView(); } + + // Bind any variables that reference the view (e.g. elem_offset, + // strides, shape). Pass fuzzy_match=false, because all shape + // transformations should have been handled in + // BufferShapeLegalize. 
binder.BindBuffer(source, view, source->name, false); // Apply the remaps @@ -1072,9 +1079,9 @@ class StorageFlattener : public StmtExprMutator { ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - ICHECK(!e.released) << "Read a buffer that is already out of scope"; + ICHECK(e.in_scope) << "Cannot write to " << op->buffer << ", out of scope."; - Stmt body = e.buffer.vstore(e.RelIndex(op->indices), op->value); + Stmt body = e.buffer.vstore(op->indices, op->value); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } @@ -1092,16 +1099,28 @@ class StorageFlattener : public StmtExprMutator { const auto& key = op->buffer; if (buf_map_.count(key)) { - ICHECK(buf_map_.at(key).external); + ICHECK(buf_map_.at(key).external) + << "BufferRealize for internal buffer " << op->buffer << " appears multiple times."; return this->VisitStmt(op->body); } else { // create a buffer entry BufferEntry e; - e.bounds = op->bounds; ICHECK_EQ(op->buffer->shape.size(), op->bounds.size()) << "Inconsistent buffer shape and realization shape for " << op->buffer; + for (size_t i = 0; i < op->bounds.size(); i++) { + const auto& bound = op->bounds[i]; + const auto& dim_size = op->buffer->shape[i]; + ICHECK(is_zero(bound_analyzer_->Simplify(bound->min))) + << "Buffer " << op->buffer << " has realization bounds that do not start at zero. " + << "Please run BufferShapeLegalize first."; + ICHECK(is_one(bound_analyzer_->Simplify(bound->extent == dim_size))) + << "Buffer " << op->buffer + << " has realization extent that does not match its size. 
" + "Please run BufferShapeLegalize first."; + } + Array shape = op->buffer->shape; StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data)); @@ -1124,7 +1143,7 @@ class StorageFlattener : public StmtExprMutator { buf_map_[key] = e; Stmt body = this->VisitStmt(op->body); - buf_map_[key].released = true; + buf_map_[key].in_scope = false; Stmt ret; DataType storage_type = e.buffer->dtype; @@ -1186,12 +1205,12 @@ class StorageFlattener : public StmtExprMutator { auto it = buf_map_.find(key); ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - ICHECK(!e.released) << "Read a buffer that is already out of scope"; + ICHECK(e.in_scope) << "Cannot read to " << op->buffer << ", out of scope."; if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); } - return e.buffer.vload(e.RelIndex(op->indices), e.buffer->dtype); + return e.buffer.vload(op->indices, e.buffer->dtype); } Stmt VisitStmt_(const PrefetchNode* op) final { @@ -1204,7 +1223,7 @@ class StorageFlattener : public StmtExprMutator { ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << key; const BufferEntry& e = it->second; - ICHECK(!e.released) << "Read a buffer that is already out of scope"; + ICHECK(e.in_scope) << "Cannot prefetch " << op->buffer << ", out of scope."; ICHECK_EQ(e.buffer->shape.size(), op->bounds.size()) << "Prefetch dim should be the same as buffer dim"; @@ -1237,7 +1256,7 @@ class StorageFlattener : public StmtExprMutator { if (i < starts) { stmt = For(vars[i], 0, op->bounds[i]->extent, ForKind::kSerial, stmt); } else { - PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); + PrimExpr load = e.buffer.vload(args, e.buffer->dtype); PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load}); PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}); stmt = 
Evaluate(prefetch); @@ -1276,25 +1295,10 @@ class StorageFlattener : public StmtExprMutator { struct BufferEntry { // the buffer of storage Buffer buffer; - // the bounds of realization, can be null, means everything - Region bounds; // Whether the buffer is external bool external{false}; - // Whether we are out of allocation bounds and buffer get released. - bool released{false}; - // relative index - inline Array RelIndex(Array args) const { - if (bounds.size() != 0) { - Array index; - ICHECK_EQ(bounds.size(), args.size()); - for (size_t i = 0; i < bounds.size(); ++i) { - index.push_back(args[i] - bounds[i]->min); - } - return index; - } else { - return args; - } - } + // Whether the buffer is currently in scope. + bool in_scope{true}; }; bool ShapeIsValid(const Array& shape) { From ddfc56f8527da7b711f4aa5236482fd07837c9ad Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 29 Sep 2021 14:09:46 -0500 Subject: [PATCH 20/21] Updated to handle BufferRealizeNode with no defined bounds. --- src/tir/transforms/storage_flatten.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 983c9f341568..9316c8411bc5 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -329,6 +329,14 @@ class BufferShapeLegalize : public StmtExprMutator { } Array update_indices(const Array& indices, const Array& offsets) { + // offsets come from BufferRealizeNode::bounds, which is allowed + // to be empty to indicate realization of the full shape of the + // buffer. In that case, the indices do not need to be modified, + // but may need to be extended with leading zeroes. 
+    if (offsets.size() == 0) {
+      return indices;
+    }
+
     ICHECK_GE(offsets.size(), indices.size())
         << "Cannot bind buffer to a shape of lower dimension.";


From ee2dc22293952a12f44fd4e4df6ad31e474f8e2d Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Wed, 29 Sep 2021 14:35:06 -0500
Subject: [PATCH 21/21] Updated to be less aggressive when checking AssertStmt

An assert statement that is statically verified to be true can safely be
removed, but reporting a statically false assert statement as a
compile-time error would require control-flow analysis (CFA). Since we
only need the removal of statically true assert statements, we skip the
CFA here.
---
 src/tir/transforms/storage_flatten.cc | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 9316c8411bc5..6a3ce596c2fe 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1349,10 +1349,9 @@ class StorageFlattener : public StmtExprMutator {
 /*!
  * \brief Simplify assert statements.
  *
- * If an assert statement can be statically verified to be false, emit
- * a failure at compile time. If an assert statement can be
- * statically verified to be true, remove the assert statement. If
- * neither case can be verified, keep the assert statement unmodified.
+ * If an assert statement can be statically verified to be true,
+ * remove the assert statement. Otherwise, keep the assert statement
+ * unmodified.
  */
 class AssertSimplifier : public StmtMutator {
 public:
@@ -1364,9 +1363,6 @@ class AssertSimplifier : public StmtMutator {
     op = stmt.as<AssertStmtNode>();
     PrimExpr condition = bound_analyzer_->Simplify(op->condition);
-    if (is_zero(condition)) {
-      LOG(FATAL) << "Assert statement failed during static checking: " << op->message;
-    }
     if (is_one(condition)) {
       return op->body;
     }