Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hexagon] 2-Stage Pipeline; Lower Async TIR primitives to Hexagon User DMA #12785

Merged
merged 8 commits into from
Sep 20, 2022
Merged
10 changes: 10 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -720,6 +720,16 @@ TVM_DLL const Op& texture2d_load();
*/
TVM_DLL const Op& mem_copy();

/*!
 * \brief Initiate a non-blocking DMA copy from source to destination
 *
 * The call carries four arguments, in order: the queue ID, the destination
 * address, the source address and the size of the copy in bytes. The call
 * has an int32 result type.
 *
 *  int32 dma_copy(queue_id, dst, src, size)
 */
TVM_DLL const Op& dma_copy();

/*!
 * \brief Wait until the number of DMAs in flight is less than or equal to some maximum
 *
 * The call carries two arguments, in order: the queue ID and the maximum
 * number of DMAs that may still be in flight when the wait returns. The call
 * has an int32 result type.
 *
 *  int32 dma_wait(queue_id, inflight_count)
 */
TVM_DLL const Op& dma_wait();
adstraw marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Provide a true statement that can be used for simplifications
*
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,11 @@ TVM_DLL Pass TextureFlatten();
*/
TVM_DLL Pass LowerVtcmAlloc();

/*!
 * \brief Lower Async TIR primitives to DMA copy and wait builtins
 *
 * `async_commit_queue_scope` / `async_scope` regions that wrap a single
 * contiguous buffer-to-buffer copy loop are rewritten to `dma_copy`, and
 * `async_wait_queue_scope` / `async_wait_inflight_count` pairs are rewritten
 * to `dma_wait`.
 *
 * \return The pass.
 */
TVM_DLL Pass LowerAsyncDMA();

/*!
* \brief Implements a Common Subexpression Elimination (CSE) for TIR
* which introduces let-in bindings for duplicated sub-expressions.
Expand Down
12 changes: 8 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);

using runtime::PackedFunc;
using runtime::TVMArgs;
Expand Down Expand Up @@ -225,6 +225,11 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
}
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
pass_list.push_back(tir::transform::LowerVtcmAlloc());
bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

if (use_async_copy) {
pass_list.push_back(tir::transform::LowerAsyncDMA());
}
pass_list.push_back(tir::transform::UnrollLoop());

// Add user-defined phase-2 passes
Expand Down Expand Up @@ -543,10 +548,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

bool use_ptx_async_copy =
pass_ctx->GetConfig<Bool>("tir.use_ptx_async_copy", Bool(false)).value();
bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

if (use_ptx_async_copy) {
if (use_async_copy) {
mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
}

Expand Down
25 changes: 25 additions & 0 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include "../workspace_pool.h"
#include "hexagon_common.h"
#include "hexagon_user_dma.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -206,6 +207,30 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM
*rv = static_cast<int32_t>(0);
});

// Packed function behind the Hexagon `dma_copy` device API.
// Arguments: (queue_id, dst, src, size). Only queue 0 is accepted and the size must be
// positive. The copy is resubmitted for as long as `HexagonUserDMA::Copy` reports
// DMA_RETRY; the first non-retry status is handed back to the caller as an int32.
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
  int queue_id = args[0];
  ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA");
  void* dst = args[1];
  void* src = args[2];
  int size = args[3];
  ICHECK(size > 0);

  int status = DMA_RETRY;
  while (true) {
    status = HexagonUserDMA::Get().Copy(dst, src, size);
    if (status != DMA_RETRY) {
      break;
    }
  }
  *rv = static_cast<int32_t>(status);
});

// Packed function behind the Hexagon `dma_wait` device API.
// Arguments: (queue_id, inflight). Only queue 0 is accepted and the in-flight count
// must be non-negative. Forwards the count to `HexagonUserDMA::Wait` and always
// returns 0 as an int32.
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
  int queue_id = args[0];
  ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA");

  int inflight = args[1];
  ICHECK(inflight >= 0);

  HexagonUserDMA::Get().Wait(inflight);

  const int32_t result = 0;
  *rv = result;
});

TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
Expand Down
6 changes: 6 additions & 0 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,12 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Builtin that starts a DMA copy; registered with an opaque call-effect kind.
TIR_DEFINE_BUILTIN_FUNC(dma_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Builtin that waits on in-flight DMAs; registered with an opaque call-effect kind.
TIR_DEFINE_BUILTIN_FUNC(dma_wait).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(assume)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo))
.set_num_inputs(1);
Expand Down
194 changes: 194 additions & 0 deletions src/tir/transforms/lower_async_dma.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file lower_async_dma.cc
*/

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "ir_utils.h"

namespace tvm {
namespace tir {

class AsyncDMALowerer : public StmtExprMutator {
public:
AsyncDMALowerer() {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
adstraw marked this conversation as resolved.
Show resolved Hide resolved
// Convert this, for example:
// attr [0] "async_wait_queue_scope" = 0;
// attr [0] "async_wait_inflight_count" = 0;
//
// To this:
// @tir.dma_wait(
// 0, /* queue id */
// 0, /* in flight count */
// dtype=int32
// )
if (op->attr_key == tir::attr::async_wait_queue_scope) {
// get queue ID
auto queue_id_node = op->value.as<IntImmNode>();
ICHECK(queue_id_node);
int queue_id = queue_id_node->value;

// abort if we have not seen this queue ID in `copy` transform
if (queue_ids.find(queue_id) == queue_ids.end()) {
adstraw marked this conversation as resolved.
Show resolved Hide resolved
DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the "
"`async_wait_queue_scope` transform has not been previously observed in the "
"`async_commit_queue_scope` transform";
return StmtExprMutator::VisitStmt_(op);
}

auto async_wait = op->body.as<AttrStmtNode>();
if (!async_wait || async_wait->attr_key != tir::attr::async_wait_inflight_count) {
DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
"`async_wait_queue_scope` does not contain an `AttrStmtNode` with key "
"`async_wait_inflight_count`";
return StmtExprMutator::VisitStmt_(op);
}

auto call_dma_wait =
Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value}));

// concatenate the call with the body and return
return SeqStmt({call_dma_wait, async_wait->body});

// Convert this, for example:
// attr [0] "async_commit_queue_scope" = 0;
// attr [0] "async_scope" = 1;
// for (ax0: int32, 0, 128) {
// A_global[ax0] = A[ax0]
// }
//
// To this:
// @tir.dma_copy(
// 0, /* queue id */
// @tir.address_of(A_global[0], dtype=handle),
// @tir.address_of(A[0], dtype=handle),
// 128, /* size */
// dtype=int32
// )
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
// get queue ID
auto queue_id_node = op->value.as<IntImmNode>();
ICHECK(queue_id_node);
int queue_id = queue_id_node->value;

// save queue ID for inspection in `wait` transform
queue_ids.insert(queue_id);

// walk the graph to verify this is a mem copy ...
// 1) async_commit_queue_scope contains async_scope
auto async_scope = op->body.as<AttrStmtNode>();
if (!async_scope || async_scope->attr_key != tir::attr::async_scope) {
DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
"`async_commit_queue_scope` does not contain an `AttrStmtNode` with key "
"`async_scope`";
return StmtExprMutator::VisitStmt_(op);
}

// 2) async_scope contains single for loop
auto for_loop = async_scope->body.as<ForNode>();
if (!for_loop) {
DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key "
"`async_scope` does not contain a single `ForNode`";
return StmtExprMutator::VisitStmt_(op);
}

// 3) for loop contains buffer store with single index
auto bufferstorenode = for_loop->body.as<BufferStoreNode>();
if (!bufferstorenode || bufferstorenode->indices.size() != 1) {
DLOG(INFO)
<< "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a "
"single `BufferStoreNode` with a single index variable";
return StmtExprMutator::VisitStmt_(op);
}

// 4) buffer store value is a buffer load with single index
auto bufferloadnode = bufferstorenode->value.as<BufferLoadNode>();
if (!bufferloadnode || bufferloadnode->indices.size() != 1) {
adstraw marked this conversation as resolved.
Show resolved Hide resolved
DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a "
"single `BufferLoadNode` with a single index variable";
return StmtExprMutator::VisitStmt_(op);
}

// get store buffer; assert it exists and is contiguous given it uses a single index
auto bufferstore = bufferstorenode->buffer.as<BufferNode>();
ICHECK(bufferstore && bufferstore->strides.empty());

// get load buffer; assert it exists and is contiguous given it uses a single index
auto bufferload = bufferloadnode->buffer.as<BufferNode>();
ICHECK(bufferload && bufferload->strides.empty());

// we will be replacing the entire for loop including its index
// with a DMA copy instrinsic that spans the entire index space of the for loop
// so we will need to replace the for loop index with value zero in the buffer indices
// thus we eliminate the index from the expression so the DMA copy receives the buffer range
// base address
Map<Var, PrimExpr> loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}};

// map loop variable to zero for the store index & simplify
Array<PrimExpr> store_index = bufferstorenode->indices;
store_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

// map loop variable to zero for the load index & simplify
Array<PrimExpr> load_index = bufferloadnode->indices;
load_index.MutateByApply([&](PrimExpr expr) {
arith::Analyzer analyzer;
return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap));
});

return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
{queue_id,
Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(bufferstorenode->buffer, store_index)}),
Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(bufferloadnode->buffer, load_index)}),
for_loop->extent * bufferloadnode->dtype.bytes()}));
}
return StmtExprMutator::VisitStmt_(op);
}

private:
std::set<int> queue_ids;
};

namespace transform {

/*!
 * \brief Create the pass that runs AsyncDMALowerer over the body of each PrimFunc.
 * \return The pass, registered under the name "tir.LowerAsyncDMA" at opt level 0.
 */
Pass LowerAsyncDMA() {
  auto lower_func = [](PrimFunc func, IRModule /*mod*/, PassContext /*ctx*/) {
    auto* func_node = func.CopyOnWrite();
    AsyncDMALowerer lowerer;
    func_node->body = lowerer(std::move(func_node->body));
    return func;
  };
  return CreatePrimFuncPass(lower_func, /*opt_level=*/0, "tir.LowerAsyncDMA", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerAsyncDMA").set_body_typed(LowerAsyncDMA);
} // namespace transform

} // namespace tir
} // namespace tvm
30 changes: 30 additions & 0 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,10 @@ class BuiltinLower : public StmtExprMutator {
return make_zero(op->dtype);
} else if (op->op.same_as(builtin::mem_copy())) {
return MakeMemCopy(op);
} else if (op->op.same_as(builtin::dma_copy())) {
return MakeDMACopy(op);
} else if (op->op.same_as(builtin::dma_wait())) {
return MakeDMAWait(op);
} else {
return StmtExprMutator::VisitExpr_(op);
}
Expand All @@ -335,6 +339,32 @@ class BuiltinLower : public StmtExprMutator {
return VisitExpr(call_packed);
}

/*!
 * \brief Lower a `dma_copy` builtin call to a packed call into the device API.
 *
 * The target function name is "device_api.<device name>.dma_copy", where the device
 * name comes from the current device type.
 *
 * \param op The `dma_copy` call; its arguments are (queue_id, dst, src, size).
 * \return The lowered packed call.
 */
PrimExpr MakeDMACopy(const CallNode* op) {
  ICHECK_EQ(op->args.size(), 4U) << "dma_copy expects 4 arguments: queue_id, dst, src, size";
  PrimExpr queue_id = op->args[0];
  PrimExpr dst = op->args[1];
  PrimExpr src = op->args[2];
  PrimExpr size = op->args[3];

  // The device name is looked up from a constant device type; fail with a clear message
  // instead of dereferencing a null pointer when it is undefined or not a constant.
  const auto* device_type = device_type_.as<IntImmNode>();
  ICHECK(device_type) << "dma_copy lowering requires a constant device type in the current IR";

  std::string fdevapi_prefix =
      "device_api." + std::string(runtime::DeviceName(device_type->value));

  Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
                          {StringImm(fdevapi_prefix + ".dma_copy"), queue_id, dst, src, size});
  // visit the new call so that tvm_call_packed itself gets lowered
  return VisitExpr(call_packed);
}

/*!
 * \brief Lower a `dma_wait` builtin call to a packed call into the device API.
 *
 * The target function name is "device_api.<device name>.dma_wait", where the device
 * name comes from the current device type.
 *
 * \param op The `dma_wait` call; its arguments are (queue_id, inflight).
 * \return The lowered packed call.
 */
PrimExpr MakeDMAWait(const CallNode* op) {
  ICHECK_EQ(op->args.size(), 2U) << "dma_wait expects 2 arguments: queue_id, inflight";
  PrimExpr queue_id = op->args[0];
  PrimExpr inflight = op->args[1];

  // The device name is looked up from a constant device type; fail with a clear message
  // instead of dereferencing a null pointer when it is undefined or not a constant.
  const auto* device_type = device_type_.as<IntImmNode>();
  ICHECK(device_type) << "dma_wait lowering requires a constant device type in the current IR";

  std::string fdevapi_prefix =
      "device_api." + std::string(runtime::DeviceName(device_type->value));

  Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(),
                          {StringImm(fdevapi_prefix + ".dma_wait"), queue_id, inflight});
  // visit the new call so that tvm_call_packed itself gets lowered
  return VisitExpr(call_packed);
}

// call shape
PrimExpr MakeShape(const CallNode* op) {
// if args.size() == 0, it represents a scalar shape ()
Expand Down
Loading