Skip to content

Commit

Permalink
TIR Schedule primitive - decompose_padding (#12174)
Browse files Browse the repository at this point in the history
Co-authored-by: baoxinqi <wrongtest@intellif.com>
  • Loading branch information
2 people authored and baoxinqi committed Jul 27, 2022
1 parent 8fa4113 commit 2d6a91f
Show file tree
Hide file tree
Showing 12 changed files with 1,048 additions and 26 deletions.
9 changes: 9 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,15 @@ class ScheduleNode : public runtime::Object {
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) = 0;

/*!
* \brief Decompose a padding block into a block filling const pad values and a block
* writing in-bound values.
* \param block_rv The block that match the padding pattern.
* \param loop_rv The loop above which the const filling block is inserted before.
* \return The const pad value filling block.
*/
virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
78 changes: 78 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2650,6 +2650,84 @@ def after_set_axis_separators(
self, block, buffer_index, buffer_index_type_enum, axis_separators
)

########## Schedule: Padding decomposition #########
@type_checked
def decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV:
"""Decompose a block of padding computation pattern into two separate blocks.
a) The block which fill const pad values into full write region;
b) The block which fill in-bound values into region where pad predicate is true.
The pad value filling block is inserted right before the given loop.
The schedule primitive requires:
1) The input block is a complete block.
2) The input loop is the ancestor of the block.
3) The input block is a block which match padding pattern.
Parameters
----------
block : Union[BlockRV, str]
The padding block to be decomposed.
loop : LoopRV
The loop above which the pad value filling block is inserted before.
Returns
-------
pad_value_block : BlockRV
The block filling const pad values.
Examples
--------
Before decompose-padding, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
for i in range(140):
with T.block("block"):
vi = T.axis.remap("S", [i])
y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32")
Create the schedule and do decompose-padding with specified loop:
.. code-block:: python
sch = tir.Schedule(before_decompose, debug_mask="all")
block = sch.get_block("block")
sch.decompose_padding(block, sch.get_loops(block)[0])
print(sch.mod["main].script())
After applying decompose-padding, the IR becomes:
.. code-block:: python
@T.prim_func
def after_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
for i in T.serial(140):
with T.block("block_pad_const"):
vi = T.axis.spatial(140, i)
y[vi] = 0
for i in T.serial(128):
with T.block("block"):
vi = T.axis.spatial(128, i)
y[vi + 6] = x[vi]
"""
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleDecomposePadding( # type: ignore # pylint: disable=no-member
self, block, loop
)

@type_checked
def can_decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> bool:
"""Check whether the block match padding pattern and can be decomposed."""
return _ffi_api.CanDecomposePadding(self, block, loop) # type: ignore # pylint: disable=no-member

########## Schedule: Misc ##########

@type_checked
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,15 @@ void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_
this->state_->DebugVerify();
}

BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::DecomposePadding(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
TVM_TIR_SCHEDULE_END("decompose-padding", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class ConcreteScheduleNode : public ScheduleNode {
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) override;
/******** Schedule: Padding decomposition ********/
BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override;
/******** Schedule: Misc ********/
void EnterPostproc() override {}

Expand Down
32 changes: 31 additions & 1 deletion src/tir/schedule/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
*/
#ifndef TVM_TIR_SCHEDULE_ERROR_H_
#define TVM_TIR_SCHEDULE_ERROR_H_

#include <tvm/tir/schedule/state.h>

#include <string>
#include <utility>

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -52,6 +54,34 @@ class ScheduleError : public tvm::runtime::Error {
String RenderReport(const String& primitive) const;
};

class LoopPositionError : public ScheduleError {
public:
explicit LoopPositionError(IRModule mod, For loop, Block block, const std::string& primitive)
: mod_(std::move(mod)),
loop_(std::move(loop)),
block_(std::move(block)),
primitive_(primitive) {}

String FastErrorString() const final {
return "ScheduleError: " + primitive_ + " expect the loop to be an ancestor of block";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
os << "ScheduleError: The input loop {0} of " << primitive_
<< " is required to be be an ancestor of block {1}.";
return os.str();
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }

IRModule mod_;
For loop_;
Block block_;
std::string primitive_;
};

} // namespace tir
} // namespace tvm

Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,17 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
const IndexMap& index_map);

/******** Schedule: Padding decomposition ********/
/*!
* \brief Decompose a padding block into a block filling const pad values and a block
* writing in-bound values.
* \param block_sref The block sref that match the padding pattern.
* \param loop_sref The loop above which the const filling block is inserted before.
* \return The padding value filling block sref.
*/
TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref);

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
Loading

0 comments on commit 2d6a91f

Please sign in to comment.