Skip to content

Commit

Permalink
Revert "[Hexagon] Add concept of DMA groups (apache#14254)"
Browse files Browse the repository at this point in the history
This reverts commit c6c89c3.
  • Loading branch information
LeiWang1999 committed Mar 24, 2023
1 parent 126cfb1 commit 9f5c141
Show file tree
Hide file tree
Showing 19 changed files with 199 additions and 488 deletions.
3 changes: 2 additions & 1 deletion include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc DisallowDynamicLoop();
/*!
* \brief Create a postprocessor that checks if all async mem copies are not strided.
* \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
* \return The postprocessor created
*/
TVM_DLL static Postproc DisallowAsyncStridedMemCopy();
TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
/*!
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
* actual vectorized cooperative fetching in loop bindings.
Expand Down
37 changes: 1 addition & 36 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,49 +727,14 @@ TVM_DLL const Op& texture2d_load();

/*!
* \brief Initiate a non-blocking DMA copy from source to destination
*
* The copy is launched immediately.
*
* If a `dma_start_group()` call is active, the copy will be added
* to the current group for tracking of in-flight group counts.
*
* If no `dma_start_group()` call is active, the copy will be tracked
* individually i.e. as a group with size 1.
*/
TVM_DLL const Op& dma_copy();

/*!
* \brief Wait until the number of DMA groups in flight is less than
* or equal to some maximum
*
* Calling `dma_wait()` while a group is active is unsupported.
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
*/
TVM_DLL const Op& dma_wait();

/*!
* \brief Start a group of DMA copies
*
* Any call to `dma_copy()` that occurs after `dma_start_group()` will
* be added to the current group for tracking of in-flight group counts.
*
* Only one DMA group may be active at a given time. Calling
* `dma_start_group()` while a group is active is unsupported.
*/
TVM_DLL const Op& dma_start_group();

/*!
* \brief End a group of DMA copies
*
* Track all calls to `dma_copy()` that occurred since the preceding
* `dma_start_group()` as a single group in-flight.
*
* Calling `dma_end_group()` without an active group is unsupported.
*
* Note: A group of DMA calls may be empty, and will still contribute
* to the count of in-flight groups used by `dma_wait()`.
*/
TVM_DLL const Op& dma_end_group();

/*!
* \brief Provide a true statement that can be used for simplifications
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@

@register_object("meta_schedule.DisallowAsyncStridedMemCopy")
class DisallowAsyncStridedMemCopy(Postproc):
"""A postprocessor that disallows schedules that use async strided mem copies."""
"""A postprocessor that disallows schedules that use async strided mem copies.
def __init__(self) -> None:
Parameters
----------
merge_async_commit_queue_scope : bool
Whether or not to merge the async commit queue scope.
"""

def __init__(self, merge_async_commit_queue_scope=True) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member
merge_async_commit_queue_scope,
)
35 changes: 14 additions & 21 deletions python/tvm/tir/tensor_intrin/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,27 +47,20 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None:
T.writes(C[0:size])
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_copy_dltensor",
T.tvm_stack_make_array(
T.address_of(C[0], dtype="handle"),
T.tvm_stack_make_shape(size, dtype="handle"),
0,
1,
C.dtype,
0,
dtype="handle",
),
T.tvm_stack_make_array(
T.address_of(A[0], dtype="handle"),
T.tvm_stack_make_shape(size, dtype="handle"),
0,
1,
A.dtype,
0,
dtype="handle",
),
T.cast(size, dtype="int"),
False, # Do not use experimental bypass mode.
"device_api.hexagon.dma_copy",
-1, # Use QueueId of -1 to not interfere with async copies.
T.address_of(C[0], dtype="handle"),
T.address_of(A[0], dtype="handle"),
size,
0, # Do not use experimental bypass mode.
dtype="int32",
)
)
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_wait",
-1,
0, # Wait for the sync queue (-1) to have 0 messages.
dtype="int32",
)
)
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::VectorizeLoop(true));
pass_list.push_back(tir::transform::StorageRewrite());
transform::PassContext pass_ctx = transform::PassContext::Current();
pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
Bool(merge_async_commit_queue_scope));
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
Expand All @@ -166,12 +169,15 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
return Postproc(n);
}

bool merge_async_commit_queue_scope = true;

static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
};

Postproc Postproc::DisallowAsyncStridedMemCopy() {
Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>();
n->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
return Postproc(n);
}

Expand Down
19 changes: 3 additions & 16 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor")
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
uint32_t queue_id = static_cast<int>(args[0]);
int queue_id = args[0];
void* dst = args[1];
void* src = args[2];
uint32_t size = static_cast<int>(args[3]);
int size = args[3];
ICHECK(size > 0);
bool bypass_cache = args[4];

Expand All @@ -226,26 +226,13 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVM
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
uint32_t queue_id = static_cast<int>(args[0]);
int queue_id = args[0];
int inflight = args[1];
ICHECK(inflight >= 0);
HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group")
.set_body([](TVMArgs args, TVMRetValue* rv) {
uint32_t queue_id = static_cast<int>(args[0]);
HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group").set_body([](TVMArgs args, TVMRetValue* rv) {
uint32_t queue_id = static_cast<int>(args[0]);
HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
Expand Down
14 changes: 6 additions & 8 deletions src/runtime/hexagon/hexagon_user_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ unsigned int HexagonUserDMA::Init() {
return status;
}

int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t length,
bool bypass_cache) {
int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache) {
// length limited to 24 bits
if (length > DESC_LENGTH_MASK) {
return DMA_FAILURE;
Expand Down Expand Up @@ -104,15 +103,15 @@ int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t lengt
return DMA_SUCCESS;
}

void HexagonUserDMA::Wait(uint32_t queue_id, uint32_t max_dmas_in_flight) {
void HexagonUserDMA::Wait(int queue_id, uint32_t max_dmas_in_flight) {
// wait (forever) until actual DMAs in flight <= max DMAs in flight
while (DMAGroupsInFlight(queue_id) > max_dmas_in_flight) {
while (DMAsInFlight(queue_id) > max_dmas_in_flight) {
}
}

uint32_t HexagonUserDMA::Poll(uint32_t queue_id) { return DMAGroupsInFlight(queue_id); }
uint32_t HexagonUserDMA::Poll(int queue_id) { return DMAsInFlight(queue_id); }

uint32_t HexagonUserDMA::DMAGroupsInFlight(uint32_t queue_id) {
uint32_t HexagonUserDMA::DMAsInFlight(int queue_id) {
dmpoll(); // update DMA engine status
return descriptors_->InFlight(queue_id);
}
Expand All @@ -126,8 +125,7 @@ HexagonUserDMA::HexagonUserDMA() {
unsigned int done = dma_desc_get_done(dma_desc);
return (done != DESC_DONE_COMPLETE);
};
descriptors_ =
new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_QUEUES, MAX_DMA_DESCRIPTORS, desc_in_flight);
descriptors_ = new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_DESCRIPTORS, desc_in_flight);
}

HexagonUserDMA::~HexagonUserDMA() {
Expand Down
31 changes: 6 additions & 25 deletions src/runtime/hexagon/hexagon_user_dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ namespace hexagon {
#define DMA_FAILURE -1
#define DMA_RETRY 1
#define MAX_DMA_DESCRIPTORS 100
#define MAX_DMA_QUEUES 10
#define SYNC_DMA_QUEUE MAX_DMA_QUEUES - 1
#define SYNC_DMA_QUEUE -1

class HexagonUserDMA {
public:
Expand All @@ -48,50 +47,32 @@ class HexagonUserDMA {

/*!
* \brief Initiate DMA to copy memory from source to destination address
* \param queue_id The virtual DMA queue
* \param dst Destination address
* \param src Source address
* \param length Length in bytes to copy
* \returns Status: DMA_SUCCESS or DMA_FAILURE
*/
int Copy(uint32_t queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);
int Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
* \param queue_id The virtual DMA queue
* \param max_dmas_in_flight Maximum number of DMAs allowed to be in flight
* to satisfy the `Wait`, e.g. use `Wait(queue_id, 0)` to wait for all outstanding DMAs on that queue to complete
*/
void Wait(uint32_t queue_id, uint32_t max_dmas_in_flight);
void Wait(int queue_id, uint32_t max_dmas_in_flight);

/*!
* \brief Poll the number of DMAs in flight
* \param queue_id The virtual DMA queue
* \returns Number of DMAs in flight
*/
uint32_t Poll(uint32_t queue_id);

/*!
* \brief Start a group of DMA copies
* \param queue_id The virtual DMA queue
*/
void StartGroup(uint32_t queue_id) { descriptors_->StartGroup(queue_id); }

/*!
* \brief End a group of DMA copies
* \param queue_id The virtual DMA queue
*/
void EndGroup(uint32_t queue_id) { descriptors_->EndGroup(queue_id); }
uint32_t Poll(int queue_id);

private:
//! \brief Initializes the Hexagon User DMA engine
unsigned int Init();

/*!
* \brief Calculates and returns the number of DMAs in flight
* \param queue_id The virtual DMA queue
*/
uint32_t DMAGroupsInFlight(uint32_t queue_id);
//! \brief Calculates and returns the number of DMAs in flight
uint32_t DMAsInFlight(int queue_id);

//! \brief Tracks whether the very first DMA has been executed
bool first_dma_ = true;
Expand Down
Loading

0 comments on commit 9f5c141

Please sign in to comment.