Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR][Schedule] Support for specific consumer block targeting in cach…
Browse files Browse the repository at this point in the history
…e_read (apache#12505)

* Add optional consumer blocks to cache_read.

* remove comments

* Fully functional

* Add test for consumer targeting.

* Formatting.

* Add missing parameter comment.

* Fix comments

* Simplify type of consumer_blocks in python.

* Change how consumer_blocks is printed in python.
  • Loading branch information
Josh Fromm authored and xinetzone committed Nov 25, 2022
1 parent bee0cd0 commit 4432061
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 30 deletions.
4 changes: 3 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,12 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \param consumer_blocks An optional list of consumers of the cache to rewrite.
* \return The cache stage block.
*/
virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) = 0;
const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) = 0;
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block who writes the target buffer.
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,7 +1012,11 @@ def after_unroll(a: T.handle, b: T.handle) -> None:

@type_checked
def cache_read(
self, block: Union[BlockRV, str], read_buffer_index: int, storage_scope: str
self,
block: Union[BlockRV, str],
read_buffer_index: int,
storage_scope: str,
consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
) -> BlockRV:
"""Create a block that reads a buffer region into a read cache. It requires:
Expand All @@ -1031,6 +1035,10 @@ def cache_read(
storage_scope: str
The target storage scope.
consumer_blocks: Optional[List[Union[BlockRV, str]]]
An optional list of consumers that should read from the cache. If not specified,
all consumers will use the cache.
Returns
-------
cached_block : BlockRV
Expand Down Expand Up @@ -1079,9 +1087,14 @@ def after_cache_read(a: T.handle, b: T.handle) -> None:
B[vi, vj] = A_local[vi, vj] * 2.0
"""
if consumer_blocks is None:
consumer_blocks = []

# Convert any string block names into Block RVs.
consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks]
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope
self, block, read_buffer_index, storage_scope, consumer_blocks
)

@type_checked
Expand Down
11 changes: 9 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,10 +535,17 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
/******** Schedule: Insert cache stages ********/

BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) {
const String& storage_scope,
const Array<BlockRV> consumer_blocks) {
StmtSRef result{nullptr};
// Create a new array of SRefs from the consumer block list.
Array<StmtSRef> consumer_block_refs = {};
for (BlockRV block : consumer_blocks) {
consumer_block_refs.push_back(this->GetSRef(block));
}
TVM_TIR_SCHEDULE_BEGIN();
result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope);
result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope,
consumer_block_refs);
TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class ConcreteScheduleNode : public ScheduleNode {
void Bind(const LoopRV& loop_rv, const String& thread_axis) override;
void Unroll(const LoopRV& loop_rv) override;
/******** Schedule: Insert cache stages ********/
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,11 @@ TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref);
* \param block_sref The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \param consumer_blocks Array of blocks that consume the cache.
* \return The cache stage block.
*/
TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
const String& storage_scope);
const String& storage_scope, const Array<StmtSRef> consumer_blocks = {});
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block that writes the target buffer.
Expand Down
61 changes: 44 additions & 17 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ struct CacheStageInfo {
Stmt cache_stage;
/*! \brief The map used for ScheduleStateNode::Replace. */
Map<Block, Block> block_reuse;
/*! \brief A list of blocks that will consume the new cache. */
Array<StmtSRef> consumer_blocks;
};

/*! \brief Return the buffer region related to the given buffer */
Expand Down Expand Up @@ -525,7 +527,20 @@ class CacheReadRewriter : public StmtExprMutator {

Stmt VisitStmt_(const BlockNode* block) final {
Block old_stmt = GetRef<Block>(block);
// We don't mutate the block which generates info->read_buffer
// Check if this block is one of the specified consumers.
// If no consumer blocks are specified, all blocks should be considered consumers.
bool is_consumer = info_->consumer_blocks.empty();
// Otherwise check if this is one of the specified blocks.
for (StmtSRef consumer_sref : info_->consumer_blocks) {
const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_node, consumer_sref);
Block consumer_block = GetRef<Block>(consumer_node);
if (old_stmt.same_as(consumer_block)) {
is_consumer = true;
}
}
// Keep track of this block's consumer status; we'll use it below when rewriting buffer loads.
current_block_consumes = is_consumer;
// We don't mutate the block which generates info->read_buffer.
if (block != scope_sref_->stmt &&
GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) {
return std::move(old_stmt);
Expand All @@ -547,23 +562,26 @@ class CacheReadRewriter : public StmtExprMutator {
stmt = Block(n);
} else {
// Otherwise, update read regions and match_buffers
Array<BufferRegion> reads =
ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer);
Array<MatchBufferRegion> match_buffers =
ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer);
if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
n->reads = std::move(reads);
n->match_buffers = std::move(match_buffers);
stmt = Block(n);
// Only make this change if the block is one of the specified consumers.
if (is_consumer) {
Array<BufferRegion> reads =
ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer);
Array<MatchBufferRegion> match_buffers =
ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer);
if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
n->reads = std::move(reads);
n->match_buffers = std::move(match_buffers);
stmt = Block(n);
}
}
}
info_->block_reuse.Set(old_stmt, stmt);
return std::move(stmt);
}

PrimExpr VisitExpr_(const BufferLoadNode* load) final {
if (load->buffer.same_as(info_->read_buffer)) {
if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) {
ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
n->buffer = info_->write_buffer;
return PrimExpr(n);
Expand All @@ -588,6 +606,8 @@ class CacheReadRewriter : public StmtExprMutator {
const StmtSRef& scope_sref_;
/*! \brief The info for inserting cache stage */
CacheStageInfo* info_;
/*! \brief Whether the most recently visited block is a specified consumer. */
bool current_block_consumes;
};

/*! \brief Mutator for CacheWrite */
Expand Down Expand Up @@ -963,7 +983,7 @@ class ReIndexRewriter : public StmtExprMutator {
/******** Implementation ********/

StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
const String& storage_scope) {
const String& storage_scope, const Array<StmtSRef> consumer_blocks) {
/*!
* Check:
* - The index is in the array of block reading region
Expand Down Expand Up @@ -992,6 +1012,8 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
info.write_buffer = WithScope(read_buffer, storage_scope);
// Create the corresponding buffer allocation
info.alloc = info.write_buffer;
// Record which blocks should consume the cache.
info.consumer_blocks = consumer_blocks;

// Step 3. Update cache stage info.
BufferRegion cache_region{nullptr};
Expand Down Expand Up @@ -1170,21 +1192,26 @@ struct CacheReadTraits : public UnpackedInstTraits<CacheReadTraits> {
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

/*!
 * \brief Replay a traced CacheRead instruction against a schedule.
 * \param sch The schedule to apply the instruction to.
 * \param block The consumer block whose read buffer is cached.
 * \param consumer_blocks Blocks that should read from the cache; empty means all consumers.
 * \param read_buffer_index Index of the buffer in the block's read region.
 * \param storage_scope The target storage scope of the cache buffer.
 * \return The block random variable referring to the new cache stage.
 * \note The parameter order must match the instruction's recorded inputs/attrs:
 *       inputs are {block, consumer_blocks}, attrs are {read_buffer_index, storage_scope}.
 */
static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Array<BlockRV> consumer_blocks,
                                       Integer read_buffer_index, String storage_scope) {
  return sch->CacheRead(block, read_buffer_index->value, storage_scope, consumer_blocks);
}

/*!
 * \brief Render a traced CacheRead instruction as the equivalent Python schedule call.
 * \param outputs Names of the instruction's output random variables.
 * \param block Name of the consumer block argument.
 * \param consumer_blocks Names of the targeted consumer blocks (may be empty).
 * \param read_buffer_index Index of the buffer in the block's read region.
 * \param storage_scope The target storage scope of the cache buffer.
 * \return The Python source string, e.g. `X = sch.cache_read(block=..., ...)`.
 */
static String UnpackedAsPython(Array<String> outputs, String block, Array<String> consumer_blocks,
                               Integer read_buffer_index, String storage_scope) {
  PythonAPICall py("cache_read");
  py.Input("block", block);
  py.Input("read_buffer_index", read_buffer_index->value);
  py.Input("storage_scope", storage_scope);
  // Only emit the consumer_blocks keyword when it was explicitly provided, so
  // traces produced before this option existed print unchanged.
  if (!consumer_blocks.empty()) {
    py.Input("consumer_blocks", consumer_blocks);
  }
  py.SingleOutput(outputs);
  return py.Str();
}
Expand Down
3 changes: 3 additions & 0 deletions src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
<< "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey();
return GetRef<Var>(static_cast<const VarNode*>(dst));
}));
} else if (input->IsInstance<ArrayNode>()) {
// Recursively convert elements of the array into a new list of ObjectRefs.
result.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), rv_map));
} else {
ICHECK(false) << "TypeError: Cannot recognize the type of an input random variable: "
<< input->GetTypeKey();
Expand Down
8 changes: 5 additions & 3 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,14 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) {

/******** Schedule: Insert cache stages ********/
BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) {
BlockRV result = ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope);
const String& storage_scope,
const Array<BlockRV> consumer_blocks) {
BlockRV result =
ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope, consumer_blocks);

static const InstructionKind& kind = InstructionKind::Get("CacheRead");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv},
/*inputs=*/{block_rv, consumer_blocks},
/*attrs=*/{Integer(read_buffer_index), storage_scope},
/*outputs=*/{result}));
return result;
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void Bind(const LoopRV& loop_rv, const String& thread_axis) final;
void Unroll(const LoopRV& loop_rv) final;
/******** Schedule: Insert cache stages ********/
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) final;
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) final;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) final;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,32 @@ def cache_read_multi_consumer() -> None:
C[vi] = A_global[vi]


@T.prim_func
def cache_read_multi_consumer_target() -> None:
    # Expected IR after `sch.cache_read(block_b, 0, "global", consumer_blocks=[block_c])`
    # on func_multi_consumer: only block "C" is redirected to read from the
    # cache buffer A_global, while block "B" keeps reading the original A.
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    C = T.alloc_buffer((128))
    # The cache buffer inserted by cache_read in "global" scope.
    A_global = T.alloc_buffer((128))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A[vi] = 1.0
        # Cache stage: copies A into A_global right after A is produced.
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A_global[vi] = A[vi]
        # Block "B" is not in consumer_blocks, so it still reads A directly.
        for j in T.grid(16):
            with T.block("B"):
                vi = T.axis.S(128, i * 16 + j)
                B[vi] = A[vi] + 1.0

    # Block "C" is the targeted consumer and therefore reads from the cache.
    for i in T.grid(128):
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = A_global[vi]


@T.prim_func
def continuous_cache_read(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
Expand Down Expand Up @@ -783,6 +809,22 @@ def test_cache_read_location(use_block_name):
tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)

# Test that targeting a specific consumer block works.
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
block_b = "B" if use_block_name else sch.get_block("B")
block_c = "C" if use_block_name else sch.get_block("C")
sch.cache_read(block_b, 0, "global", consumer_blocks=[block_c])
tvm.ir.assert_structural_equal(cache_read_multi_consumer_target, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)

# Also test setting multiple consumers yields same result as unspecified.
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
block_b = "B" if use_block_name else sch.get_block("B")
block_c = "C" if use_block_name else sch.get_block("C")
sch.cache_read(block_b, 0, "global", consumer_blocks=[block_b, block_c])
tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)


def test_continuous_cache_read(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
Expand Down

0 comments on commit 4432061

Please sign in to comment.