Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Schedule] DecomposePadding #12174

Merged
merged 1 commit into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,15 @@ class ScheduleNode : public runtime::Object {
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) = 0;

/*!
* \brief Decompose a padding block into a block filling const pad values and a block
* writing in-bound values.
* \param block_rv The block that matches the padding pattern.
* \param loop_rv The loop right before which the const pad value filling block is inserted.
* \return The const pad value filling block.
*/
virtual BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
78 changes: 78 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2639,6 +2639,84 @@ def after_set_axis_separators(
self, block, buffer_index, buffer_index_type_enum, axis_separators
)

########## Schedule: Padding decomposition ##########
@type_checked
def decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> BlockRV:
"""Decompose a block of padding computation pattern into two separate blocks.

a) The block which fills const pad values into the full write region;

b) The block which fills in-bound values into the region where the pad predicate is true.

The pad value filling block is inserted right before the given loop.

The schedule primitive requires:

1) The input block is a complete block.

2) The input loop is the ancestor of the block.

3) The input block is a block which matches the padding pattern.

Parameters
----------
block : Union[BlockRV, str]
The padding block to be decomposed.
loop : LoopRV
The loop right before which the pad value filling block is inserted.

Returns
-------
pad_value_block : BlockRV
The block filling const pad values.

Examples
--------
Before decompose-padding, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
for i in range(140):
with T.block("block"):
vi = T.axis.remap("S", [i])
y[vi] = T.if_then_else(vi >= 6 and vi < 134, x[vi - 6], 0, dtype="int32")

Create the schedule and do decompose-padding with specified loop:

.. code-block:: python

sch = tir.Schedule(before_decompose, debug_mask="all")
block = sch.get_block("block")
sch.decompose_padding(block, sch.get_loops(block)[0])
print(sch.mod["main"].script())

After applying decompose-padding, the IR becomes:

.. code-block:: python

@T.prim_func
def after_decompose(x: T.Buffer[128, "int32"], y: T.Buffer[140, "int32"]):
for i in T.serial(140):
with T.block("block_pad_const"):
vi = T.axis.spatial(140, i)
y[vi] = 0
for i in T.serial(128):
with T.block("block"):
vi = T.axis.spatial(128, i)
y[vi + 6] = x[vi]
"""
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleDecomposePadding( # type: ignore # pylint: disable=no-member
self, block, loop
)

@type_checked
def can_decompose_padding(self, block: Union[BlockRV, str], loop: LoopRV) -> bool:
    """Check whether the block matches the padding pattern and can be decomposed.

    Parameters
    ----------
    block : Union[BlockRV, str]
        The candidate padding block.
    loop : LoopRV
        The loop that is forwarded together with the block to the check.

    Returns
    -------
    can_decompose : bool
        The result reported by the underlying ``CanDecomposePadding`` check.
    """
    # The signature accepts the same ``Union[BlockRV, str]`` as ``decompose_padding``,
    # so pass the argument through the same normalization before handing it to the FFI.
    block = self._normalize_block_arg(block)
    return _ffi_api.CanDecomposePadding(self, block, loop)  # type: ignore # pylint: disable=no-member

########## Schedule: Misc ##########

@type_checked
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,15 @@ void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_
this->state_->DebugVerify();
}

/*!
* \brief Apply the decompose-padding primitive on the schedule state held by this schedule.
*
* Resolves the block and loop random variables to srefs, forwards them to
* tir::DecomposePadding, and wraps the returned sref in a newly created BlockRV.
*
* \param block_rv The random variable of the block to decompose.
* \param loop_rv The random variable of the loop passed to tir::DecomposePadding.
* \return A new BlockRV bound to the sref returned by tir::DecomposePadding.
*/
BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
// NOTE(review): `result` is declared ahead of the BEGIN/END macro pair and read after it;
// presumably the macros open and close a scope (e.g. try/catch) -- confirm against their
// definitions before moving this declaration.
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::DecomposePadding(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
// The primitive name and the configured error render level are handed to the END macro.
TVM_TIR_SCHEDULE_END("decompose-padding", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class ConcreteScheduleNode : public ScheduleNode {
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) override;
/******** Schedule: Padding decomposition ********/
/*! \brief Override of ScheduleNode::DecomposePadding; returns a BlockRV. */
BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override;
/******** Schedule: Misc ********/
void EnterPostproc() override {}

Expand Down
32 changes: 31 additions & 1 deletion src/tir/schedule/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
*/
#ifndef TVM_TIR_SCHEDULE_ERROR_H_
#define TVM_TIR_SCHEDULE_ERROR_H_

#include <tvm/tir/schedule/state.h>

#include <string>
#include <utility>

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -52,6 +54,34 @@ class ScheduleError : public tvm::runtime::Error {
String RenderReport(const String& primitive) const;
};

/*!
 * \brief Schedule error reported when the loop given to a primitive is not an ancestor of the
 * given block.
 */
class LoopPositionError : public ScheduleError {
 public:
  /*!
   * \brief Constructor
   * \param mod The IRModule returned by mod()
   * \param loop The loop reported as the first location of interest
   * \param block The block reported as the second location of interest
   * \param primitive The name of the schedule primitive, embedded in the error messages
   */
  explicit LoopPositionError(IRModule mod, For loop, Block block, const std::string& primitive)
      : mod_(std::move(mod)),
        loop_(std::move(loop)),
        block_(std::move(block)),
        primitive_(primitive) {}

  /*! \brief The short, single-line form of the error message. */
  String FastErrorString() const final {
    return "ScheduleError: " + primitive_ + " expect the loop to be an ancestor of block";
  }

  /*!
   * \brief The detailed message template.
   * \note `{0}` and `{1}` presumably refer to the entries of LocationsOfInterest() in order
   * (loop, then block) -- confirm against ScheduleError::RenderReport.
   */
  String DetailRenderTemplate() const final {
    // Plain concatenation, as in FastErrorString(), so this header does not need <sstream>.
    return "ScheduleError: The input loop {0} of " + primitive_ +
           " is required to be an ancestor of block {1}.";
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }

  /*! \brief The module handed to the constructor */
  IRModule mod_;
  /*! \brief The loop handed to the constructor */
  For loop_;
  /*! \brief The block handed to the constructor */
  Block block_;
  /*! \brief The primitive name handed to the constructor */
  std::string primitive_;
};

} // namespace tir
} // namespace tvm

Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,17 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
const IndexMap& index_map);

/******** Schedule: Padding decomposition ********/
/*!
* \brief Decompose a padding block into a block filling const pad values and a block
* writing in-bound values.
* \param self The schedule state.
* \param block_sref The block sref that matches the padding pattern.
* \param loop_sref The loop right before which the const pad value filling block is inserted.
* \return The padding value filling block sref.
*/
TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref);

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
Loading