diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index e2083778431e..e5d2c440e57b 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -264,6 +264,21 @@ class ScheduleNode : public runtime::Object { * \return The rfactor block */ virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; + /******** Schedule: Block annotation ********/ + /*! + * \brief Set alignment requirement for specific dimension such that + * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for + * more friendly memory access pattern. For example, we can set alignment to be factor=2, + * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared + * memory. + * \param block_rv The producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param axis The dimension to be specified for alignment + * \param factor The factor multiple of alignment + * \param offset The required offset factor + */ + virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) = 0; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 4bbb5b9b1582..e8415d2bd522 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -710,6 +710,79 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None: """ return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member + ######## Schedule: Block annotation ######## + + def storage_align( # pylint: disable=too-many-arguments + self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int + ) -> None: + """Set alignment requirement for specific dimension such that + stride[axis] == k * factor + 
offset for some k. This is useful to set memory layout for more + friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1 + to avoid bank conflict for thread access on higher dimension in GPU shared memory. + + Parameters + ---------- + block : BlockRV + The producer block of the buffer. + buffer_index : int + The index of the buffer in block's write region. + axis : int + The dimension to be specified for alignment. + factor : int + The factor multiple of alignment. + offset : int + The required offset factor. + + Examples + -------- + + Before storage_align, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_storage_align(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + Create the schedule and do storage_align: + + .. code-block:: python + + sch = tir.Schedule(before_storage_align) + sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1) + print(tvm.script.asscript(sch.mod["main"])) + + After applying storage_align, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_storage_align(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.alloc_buffer((128, 128)) + C = tir.match_buffer(c, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]}) + B[vi, vj] = A[vi, vj] * 2.0 + with tir.block([128, 128], "C") as [vi, vj]: + C[vi, vj] = B[vi, vj] + 1.0 + + After lowering passes, buffer B will have strides as [129, 1]. + + Note + ---- + Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`. 
+ """ + _ffi_api.ScheduleStorageAlign( # type: ignore # pylint: disable=no-member + self, block, buffer_index, axis, factor, offset + ) + ########## Schedule: Blockize & Tensorize ########## ########## Schedule: Annotation ########## diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9baf4b5245ea..370aa01a33c0 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -202,6 +202,19 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self */ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); +/******** Block-buffer relation ********/ + +/*! + * \brief Get the BlockRealize of the single child block of the block or loop specified by + * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks + * \param self The schedule state + * \param block The queried block + * \param n The index of the queried buffer + * \return The buffer of the n-th write region of the block. + * \throw ScheduleError If the buffer index is out of bound. + */ +Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n); + /******** Commutative Reducer ********/ /*! 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 3ee98ec5b7d2..8d1913fdee86 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -527,6 +527,45 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr } } +/******** Block-buffer relation ********/ + +Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) { + class WriteBufferIndexOutOfRangeError : public ScheduleError { + public: + explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index) + : mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {} + + String FastErrorString() const final { + return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " + "range [0, num_write_regions) where `num_write_regions` is the number of buffer " + "regions written by the block."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + size_t num_writes = block_->writes.size(); + os << "The block {0} has " << num_writes + << " write regions, so `buffer_index` is required to be in [0, " << num_writes + << "). However, the input `buffer_index` is " << buffer_index_ + << ", which is out of the expected range"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + private: + IRModule mod_; + Block block_; + int buffer_index_; + }; + + if (n < 0 || static_cast(n) >= block->writes.size()) { + throw WriteBufferIndexOutOfRangeError(self->mod, block, n); + } + return block->writes[n]->buffer; +} + /******** Pattern Matcher ********/ /*! 
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 610628c6d88a..688ea8059c0e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -362,6 +362,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { } /******** Schedule: loop binding/annotation ********/ +/******** Schedule: block annotation ********/ + +void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, + int factor, int offset) { + TVM_TIR_SCHEDULE_BEGIN(); + tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset); + TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_); + this->state_->DebugVerify(); +} + /******** Schedule: cache read/write ********/ /******** Schedule: reduction ********/ diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index ec0dd079243b..cfdd9c8452f7 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -88,6 +88,9 @@ class ConcreteScheduleNode : public ScheduleNode { void ReverseComputeInline(const BlockRV& block) override; /******** Schedule: Reduction ********/ BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; + /******** Schedule: Block annotation ********/ + void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) override; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 22e25f1c54a7..4b9c76947bb1 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -104,6 +104,26 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref * \return The sref of the rfactor block */ TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int 
factor_axis); +/******** Schedule: Block annotation ********/ +/*! + * \brief Set alignment requirement for specific dimension such that + * stride[axis] == k * factor + offset for some k. This is useful to set memory layout for + * more friendly memory access pattern. For example, we can set alignment to be factor=2, + * offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared + * memory. + * \param block_sref The producer block of the buffer + * \param buffer_index The index of the buffer in block's write region + * \param axis The dimension to be specified for alignment + * \param factor The factor multiple of alignment + * \param offset The required offset factor + */ +TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, + int axis, int factor, int offset); + +/******** Annotation types for StorageAlign ********/ +using StorageAlignTuple = Array; // (buffer_idx, axis, factor, offset) +using StorageAlignAnnotation = Array; // unordered array of StorageAlignTuple + /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc new file mode 100644 index 000000000000..937bc7c3802f --- /dev/null +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../transform.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +class StorageAlignAxisOutOfRangeError : public ScheduleError { + public: + explicit StorageAlignAxisOutOfRangeError(IRModule mod, Buffer buffer, int axis) + : mod_(std::move(mod)), buffer_(std::move(buffer)), axis_(axis) {} + + String FastErrorString() const final { + return "ScheduleError: The input `axis` is out of range. It is required to be in range " + "[-ndim, ndim) where `ndim` is the number of dimensions of the buffer to set " + "storage alignment."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + int ndim = static_cast(buffer_->shape.size()); + os << "The buffer to set storage alignment of, " << buffer_->name << ", has " << ndim + << " dimension(s), so `axis` is required to be in [" << -(ndim) << ", " << ndim + << ") for storage_align. However, the input `axis` is " << axis_ + << ", which is out of the expected range."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int axis) { + int ndim = static_cast(buffer->shape.size()); + if (axis < -ndim || axis >= ndim) { + throw StorageAlignAxisOutOfRangeError(mod, buffer, axis); + } + // If axis is negative, convert it to a non-negative one. + if (axis < 0) { + axis += ndim; + } + return axis; + } + + private: + IRModule mod_; + Buffer buffer_; + int axis_; +}; + +/*! 
+ * \brief Find the defining site of the buffer in the given block and its ancestors + * \param block_sref The block sref + * \param buffer The buffer + * \return The defining site of the buffer and whether the buffer is allocated (otherwise the + * buffer is from match_buffer). + */ +std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_sref, + const Buffer& buffer) { + // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or + // match_buffers. + const StmtSRefNode* defining_site_sref = block_sref.get(); + while (defining_site_sref != nullptr) { + const auto* block = defining_site_sref->StmtAs(); + // If this sref is not a block sref, skip it. + if (block == nullptr) { + defining_site_sref = defining_site_sref->parent; + continue; + } + // Try to find the buffer in `allloc_buffers` + for (const Buffer& alloc_buffer : block->alloc_buffers) { + if (buffer.same_as(alloc_buffer)) { + return {GetRef(defining_site_sref), true}; + } + } + // We do not allow the buffer being defined in `match_buffer`. + for (const MatchBufferRegion match_buffer : block->match_buffers) { + if (buffer.same_as(match_buffer)) { + return {GetRef(defining_site_sref), false}; + } + } + defining_site_sref = defining_site_sref->parent; + } + // If we cannot find the defining site block, it means that the buffer must be in the function's + // buffer_map, which isn't an intermediate buffer. + return {NullOpt, false}; +} + +class NonAllocatedBufferError : public ScheduleError { + public: + explicit NonAllocatedBufferError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} + + String FastErrorString() const final { + return "ScheduleError: The input buffer is not allocated by a block. This means the buffer is " + " either a function parameter or defined in `match_buffer` of a block."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The input buffer " << buffer_->name + << " is not allocated by a block. 
This means the buffer is either a function parameter or " + "defined in `match_buffer` of a block."; + return os.str(); + } + + static void CheckBufferAllocated(const IRModule& mod, const StmtSRef& block_sref, + const Buffer& buffer) { + Optional defining_site_sref; + bool is_alloc; + std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, buffer); + if (!defining_site_sref || !is_alloc) { + throw NonAllocatedBufferError(mod, buffer); + } + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + Buffer buffer_; +}; + +class StorageAlignInvalidFactorError : public ScheduleError { + public: + explicit StorageAlignInvalidFactorError(IRModule mod, int factor) + : mod_(std::move(mod)), factor_(factor) {} + + String FastErrorString() const final { + return "ScheduleError: The input `factor` of storage_align is expected to be a positive " + "number."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The input `factor` of storage_align is expected to be a positive number. 
However, the " + "input `factor` is " + << factor_ << ", which is out of the expected range."; + return os.str(); + } + + static void Check(const IRModule& mod, int factor) { + if (factor <= 0) { + throw StorageAlignInvalidFactorError(mod, factor); + } + } + + Array LocationsOfInterest() const final { return {}; } + IRModule mod() const final { return mod_; } + + private: + IRModule mod_; + int factor_; +}; + +class StorageAlignInvalidAnnotationError : public ScheduleError { + public: + explicit StorageAlignInvalidAnnotationError(IRModule mod, Block block) + : mod_(std::move(mod)), block_(std::move(block)) {} + + String FastErrorString() const final { + return "ScheduleError: The block annotation for storage align is expected to be an array of " + "4-integer-tuples (buffer_index, axis, factor, offset)."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The block annotation for storage align is expected to be an array of 4-integer-tuples " + "(buffer_index, axis, factor, offset). However, the block annotation with key " + << attr::buffer_dim_align << " of the block {0} is " + << block_->annotations.at(attr::buffer_dim_align) << ", which is unexpected."; + return os.str(); + } + + static StorageAlignAnnotation CheckAndGetAnnotation(const IRModule& mod, const Block& block) { + // Get existing annotation value. 
+ auto it = block->annotations.find(attr::buffer_dim_align); + if (it != block->annotations.end()) { + if (!IsValidAnnotation(block, (*it).second)) { + throw StorageAlignInvalidAnnotationError(mod, block); + } + return Downcast((*it).second); + } + + // Create new annotation value + StorageAlignAnnotation storage_align_annotation; + return storage_align_annotation; + } + + Array LocationsOfInterest() const final { return {block_}; } + IRModule mod() const final { return mod_; } + + private: + static bool IsValidAnnotation(const Block& block, const ObjectRef& anno_value) { + if (!anno_value->IsInstance()) { + return false; + } + auto storage_align_annotations = Downcast>(anno_value); + for (const ObjectRef& storage_align_annotation : storage_align_annotations) { + if (!storage_align_annotation->IsInstance()) { + return false; + } + auto storage_align_tuple = Downcast>(storage_align_annotation); + // Check if the annotation is a 4-tuple. + if (storage_align_tuple.size() != 4) { + return false; + } + for (const ObjectRef& tuple_element : storage_align_tuple) { + if (!tuple_element->IsInstance()) { + return false; + } + } + } + return true; + } + + IRModule mod_; + Block block_; +}; + +void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis, + int factor, int offset) { + const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref); + Buffer buffer = GetNthWriteBuffer(self, GetRef(block_ptr), buffer_index); + StorageAlignInvalidFactorError::Check(self->mod, factor); + axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis); + NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer); + + // Step 1: Get existing or create new annotation value. 
+ StorageAlignAnnotation storage_align_annotation = + StorageAlignInvalidAnnotationError::CheckAndGetAnnotation(self->mod, + GetRef(block_ptr)); + + // Step 2: Update the annotation value + // Array> buffer_storage_align = storage_align_annotation[buffer_index]; + bool found = false; + StorageAlignTuple new_storage_align_tuple{Integer(buffer_index), Integer(axis), Integer(factor), + Integer(offset)}; + for (size_t j = 0; j < storage_align_annotation.size(); ++j) { + const auto& storage_align_tuple = storage_align_annotation[j]; + ICHECK(storage_align_tuple.size() == 4); + if (storage_align_tuple[0] == buffer_index && storage_align_tuple[1] == axis) { + storage_align_annotation.Set(j, std::move(new_storage_align_tuple)); + found = true; + break; + } + } + if (!found) { + storage_align_annotation.push_back(std::move(new_storage_align_tuple)); + } + + // Step 3: Replace the block with the new annotation + Block new_block = WithAnnotation(block_ptr, attr::buffer_dim_align, storage_align_annotation); + self->Replace(block_sref, new_block, {{GetRef(block_ptr), new_block}}); +} + +/******** Instruction Registration ********/ + +struct StorageAlignTraits : public UnpackedInstTraits { + static constexpr const char* kName = "StorageAlign"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 4; + static constexpr size_t kNumDecisions = 0; + + static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, + Integer axis, Integer factor, Integer offset) { + return sch->StorageAlign(block_rv, buffer_index->value, axis->value, factor->value, + offset->value); + } + + static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, + Integer axis, Integer factor, Integer offset) { + PythonAPICall py("storage_align"); + py.Input("block", block_rv); + py.Input("buffer_index", buffer_index); + py.Input("axis", axis); + py.Input("factor", factor); + 
py.Input("offset", offset); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 3232a3344ee7..d6dc0b446e16 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -135,6 +135,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline") /******** (FFI) Reduction ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor") .set_body_method(&ScheduleNode::RFactor); +/******** (FFI) Block annotation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign") + .set_body_method(&ScheduleNode::StorageAlign); /******** (FFI) Blockize & Tensorize ********/ /******** (FFI) Annotation ********/ /******** (FFI) Misc ********/ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d664d7f6ce98..e0ffdc7b019f 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -137,6 +137,19 @@ BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { return result; } +/******** Schedule: Block annotation ********/ + +void TracedScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, + int factor, int offset) { + ConcreteScheduleNode::StorageAlign(block_rv, buffer_index, axis, factor, offset); + static const InstructionKind& kind = InstructionKind::Get("StorageAlign"); + trace_->Append(/*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Integer(axis), Integer(factor), Integer(offset)}, + /*outputs=*/{})); +} + /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index b4518cbba8b5..4650c44ba8c3 100644 --- a/src/tir/schedule/traced_schedule.h +++ 
b/src/tir/schedule/traced_schedule.h @@ -61,6 +61,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { void ReverseComputeInline(const BlockRV& block_rv) final; /******** Schedule: Reduction ********/ BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final; + /******** Schedule: Block annotation ********/ + void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor, + int offset) final; /******** Schedule: Blockize & Tensorize ********/ /******** Schedule: Annotation ********/ /******** Schedule: Misc ********/ diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc new file mode 100644 index 000000000000..f27e0f6d62eb --- /dev/null +++ b/src/tir/schedule/transform.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "./transform.h" + +namespace tvm { +namespace tir { + +/******** Annotation ********/ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) { + Map annotations = block->annotations; + annotations.Set(attr_key, attr_value); + ObjectPtr new_block = make_object(*block); + new_block->annotations = std::move(annotations); + return Block(new_block); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h new file mode 100644 index 000000000000..53483829a303 --- /dev/null +++ b/src/tir/schedule/transform.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_TIR_SCHEDULE_TRANSFORM_H_ +#define TVM_TIR_SCHEDULE_TRANSFORM_H_ + +#include + +namespace tvm { +namespace tir { + +/******** Annotation ********/ + +/*! 
+ * \brief Create a new block with the given annotation added + * \param block The block with original annotation + * \param attr_key The annotation key to be added + * \param attr_value The annotation value to be added + * \return A new block with the given annotation as its last annotation + */ +Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value); + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SCHEDULE_TRANSFORM_H_ diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index b1a4fd45ef0d..961ea1721fa1 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -303,18 +303,61 @@ class BufferAccessRegionCollector : public StmtExprVisitor { support::Arena arena_; }; +/*! \brief Collect storage alignment information from block annotations. */ +class StorageAlignCollector : public StmtVisitor { + public: + static std::unordered_map Collect( + const PrimFunc& f) { + StorageAlignCollector collector; + collector(f->body); + return std::move(collector.storage_align_); + } + + private: + void VisitStmt_(const BlockNode* op) final { + auto it = op->annotations.find(attr::buffer_dim_align); + if (it != op->annotations.end()) { + auto storage_align_annotation = Downcast((*it).second); + for (const auto& storage_align_tuple : storage_align_annotation) { + int buffer_index = storage_align_tuple[0]->value; + const Buffer& buffer = op->writes[buffer_index]->buffer; + storage_align_[buffer].push_back(storage_align_tuple); + } + } + StmtVisitor::VisitStmt_(op); + } + + /*! \brief The map from Buffer to its storage alignment information. */ + std::unordered_map storage_align_; +}; + /*! \brief Reallocate the buffers with minimal region. 
*/ class BufferCompactor : public StmtExprMutator { public: static Stmt Compact( const PrimFunc& f, - const std::unordered_map& regions) { + const std::unordered_map& regions, + const std::unordered_map& + storage_align) { std::unordered_map buffer_info; for (const auto& kv : regions) { const Buffer& buffer = kv.first; Region region = kv.second; - buffer_info.emplace(buffer, BufferAllocInfo(std::move(region))); + BufferAllocInfo buffer_alloc_info(std::move(region)); + auto it = storage_align.find(buffer); + if (it != storage_align.end()) { + std::vector dim_aligns(buffer->shape.size()); + for (const StorageAlignTuple& dim_align : (*it).second) { + ICHECK(dim_align.size() == 4); + int dim = dim_align[1]->value; + int factor = dim_align[2]->value; + int offset = dim_align[3]->value; + dim_aligns.at(dim) = {factor, offset}; + } + buffer_alloc_info.dim_aligns = std::move(dim_aligns); + } + buffer_info.emplace(buffer, std::move(buffer_alloc_info)); } BufferCompactor compactor(std::move(buffer_info)); Stmt stmt = compactor(f->body); @@ -322,9 +365,19 @@ class BufferCompactor : public StmtExprMutator { } private: + /*! \brief The storage alignment for a dimension */ + struct DimAlignInfo { + /*! \brief The factor of the alignment */ + int align_factor{0}; + /*! \brief The offset of the alignment */ + int align_offset{0}; + }; + struct BufferAllocInfo { /*! \brief The buffer access region. */ Region region; + /*! \brief The storage alignment information. */ + std::vector dim_aligns; /*! * \brief The reallocated buffer with minimal size. * \note The value if NullOpt if the buffer do not need reallocate (e.g parameter buffer). 
@@ -380,8 +433,25 @@ class BufferCompactor : public StmtExprMutator { for (const Range& range : info.region) { shape.push_back(range->extent); } + Array<PrimExpr> strides; + if (info.dim_aligns.size()) { + ICHECK(info.dim_aligns.size() == shape.size()); + strides.resize(shape.size()); + PrimExpr stride = make_const(shape[0].dtype(), 1); + for (size_t i = shape.size(); i != 0; --i) { + size_t dim = i - 1; + if (info.dim_aligns[dim].align_factor != 0) { + PrimExpr factor = make_const(stride.dtype(), info.dim_aligns[dim].align_factor); + PrimExpr offset = make_const(stride.dtype(), info.dim_aligns[dim].align_offset); + stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); + } + strides.Set(dim, stride); + stride = stride * shape[dim]; + } + } ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get()); n->shape = std::move(shape); + n->strides = std::move(strides); info.new_buffer = Buffer(std::move(n)); result.push_back(info.new_buffer); } @@ -458,7 +528,9 @@ PrimFunc CompactBufferAllocation(PrimFunc f) { PrimFuncNode* fptr = f.CopyOnWrite(); std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> region = BufferAccessRegionCollector::Collect(f); - fptr->body = BufferCompactor::Compact(f, region); + std::unordered_map<Buffer, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual> + storage_align = StorageAlignCollector::Collect(f); + fptr->body = BufferCompactor::Compact(f, region, storage_align); return f; } else { return f; diff --git a/tests/python/unittest/test_tir_schedule_storage_align.py b/tests/python/unittest/test_tir_schedule_storage_align.py new file mode 100644 index 000000000000..a0a069347f95 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_storage_align.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name + +@tvm.script.tir +def element_wise(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +@tvm.script.tir +def element_wise_storage_align(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, 
offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + tir.block_attr({"buffer_dim_align":[[0, 0, 128, 127]]}) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +@tvm.script.tir +def element_wise_invalid_annotation(a: ty.handle, c: ty.handle) -> None: + C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1) + A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1) + # body + with tir.block([], "root"): + tir.reads([]) + tir.writes([]) + B = tir.alloc_buffer([128, 128], elem_offset=0, align=128, offset_factor=1) + for i0 in tir.serial(0, 128): + for ax1 in tir.serial(0, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.block_attr({"buffer_dim_align": [0]}) + tir.bind(vi, i0) + tir.bind(vj, ax1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = (A[vi, vj]*tir.float32(2)) + for i1 in tir.serial(0, 128): + with tir.block([128, 128], "C") as [vi_1, vj_1]: + tir.bind(vi_1, i0) + tir.bind(vj_1, i1) + tir.reads([B[vi_1, vj_1]]) + tir.writes([C[vi_1, vj_1]]) + C[vi_1, vj_1] = (B[vi_1, vj_1] + tir.float32(1)) + + +def test_storage_align(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + s.storage_align(B, 0, axis=0, factor=128, offset=127) + tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_storage_align_update(): + func = element_wise + s = 
tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + s.storage_align(B, 0, axis=0, factor=128, offset=0) + s.storage_align(B, 0, axis=0, factor=128, offset=127) + tvm.ir.assert_structural_equal(element_wise_storage_align, s.mod["main"]) + verify_trace_roundtrip(sch=s, mod=func) + + +def test_storage_align_invalid_factor1(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=0, factor=0, offset=127) + + +def test_storage_align_invalid_factor2(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=0, factor=-1, offset=127) + + +def test_storage_align_invalid_buffer(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + C = s.get_block("C") + with pytest.raises(tir.ScheduleError): + s.storage_align(C, 0, axis=0, factor=128, offset=127) + + +def test_storage_align_invalid_buffer_index(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 2, axis=0, factor=128, offset=127) + + +def test_storage_align_invalid_axis(): + func = element_wise + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=2, factor=128, offset=127) + + +def test_storage_align_invalid_annotation(): + func = element_wise_invalid_annotation + s = tir.Schedule(func, debug_mask='all') + B = s.get_block("B") + with pytest.raises(tir.ScheduleError): + s.storage_align(B, 0, axis=2, factor=128, offset=127) + + +if __name__ == "__main__": + test_storage_align() + test_storage_align_update() + test_storage_align_invalid_factor1() + test_storage_align_invalid_factor2() + test_storage_align_invalid_buffer() + test_storage_align_invalid_buffer_index() + test_storage_align_invalid_axis() + 
test_storage_align_invalid_annotation() diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index fb53b420f4ce..15da022e67d6 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -339,6 +339,50 @@ def compacted_match_buffer_func(a: ty.handle, c: ty.handle) -> None: C1[()] = B2[()] * 2.0 +@tvm.script.tir +def storage_align_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((16, 16), "float32") + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i, j]) + tir.writes(B[i, j]) + tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) + B[i, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(B[i, j]) + tir.writes(C[i, j]) + C[i, j] = B[i, j] * 2.0 + + +@tvm.script.tir +def compacted_storage_align_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((1, 16), strides=(31, 1), dtypes="float32") + for j in range(0, 16): + with tir.block() as []: + tir.reads(A[i, j]) + tir.writes(B[0, j]) + tir.block_attr({"buffer_dim_align": [[0, 0, 16, 15]]}) + B[0, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block() as []: + tir.reads(B[0, j]) + tir.writes(C[i, j]) + C[i, j] = B[0, j] * 2.0 + + def test_elementwise(): _check(elementwise_func, compacted_elementwise_func) @@ -380,6 +424,10 @@ def test_lower_te(): tvm.ir.assert_structural_equal(mod, orig_mod) # CompactBufferAllocation should do nothing on TE +def test_storage_align(): + 
_check(storage_align_func, compacted_storage_align_func) + + if __name__ == "__main__": test_elementwise() test_unschedulable_block() @@ -389,3 +437,4 @@ def test_lower_te(): test_symbolic() test_complex() test_match_buffer() + test_storage_align()