diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index f297ca090482d..85fb9003e87fe 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -111,9 +111,10 @@ class Postproc : public runtime::ObjectRef { TVM_DLL static Postproc DisallowDynamicLoop(); /*! * \brief Create a postprocessor that checks if all async mem copies are not strided. + * \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope. * \return The postprocessor created */ - TVM_DLL static Postproc DisallowAsyncStridedMemCopy(); + TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true); /*! * \brief Create a postprocessor that rewrites the cooperative fetch annotation to * actual vectorized cooperative fetching in loop bindings. diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index e8bcc028fc58b..708abde2cd315 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -727,49 +727,14 @@ TVM_DLL const Op& texture2d_load(); /*! * \brief Initiate a non-blocking DMA copy from source to destination - * - * The copy is launched immediately. - * - * If a `dma_start_group()` call is active, the copy will be added - * to the current group for tracking of in-flight group counts. - * - * If no `dma_start_group()` call is active, the copy will be tracked - * individually i.e. as a group with size 1. */ TVM_DLL const Op& dma_copy(); /*! - * \brief Wait until the number of DMA groups in flight is less than - * or equal to some maximum - * - * Calling `dma_wait()` while a group is active is unsupported. + * \brief Wait until the number of DMAs in flight is less than or equal to some maximum */ TVM_DLL const Op& dma_wait(); -/*! - * \brief Start a group of DMA copies - * - * Any call to `dma_copy()` that occurs after `dma_start_group()` will - * be added to the current group for tracking of in-flight group counts. 
- * - * Only one DMA group may be active at a given time. Calling - * `dma_start_group()` while a group is active is unsupported. - */ -TVM_DLL const Op& dma_start_group(); - -/*! - * \brief End a group of DMA copies - * - * Track all calls to `dma_copy()` that occurred since the preceding - * `dma_start_group()` as a single group in-flight. - * - * Calling `dma_end_group()` without an active group is unsupported. - * - * Note: A group of DMA calls may be empty, and will still contribute - * to the count of in-flight groups used by `dma_wait()`. - */ -TVM_DLL const Op& dma_end_group(); - /*! * \brief Provide a true statement that can be used for simplifications * diff --git a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py index 0dcff9bf45a3e..7e0e00de2949b 100644 --- a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py +++ b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py @@ -23,9 +23,16 @@ @register_object("meta_schedule.DisallowAsyncStridedMemCopy") class DisallowAsyncStridedMemCopy(Postproc): - """A postprocessor that disallows schedules that use async strided mem copies.""" + """A postprocessor that disallows schedules that use async strided mem copies. - def __init__(self) -> None: + Parameters + ---------- + merge_async_commit_queue_scope : bool + Whether or not to merge the async commit queue scope. 
+ """ + + def __init__(self, merge_async_commit_queue_scope=True) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member + merge_async_commit_queue_scope, ) diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py index 22dd9a977c659..7a348f3f1a45f 100644 --- a/python/tvm/tir/tensor_intrin/hexagon.py +++ b/python/tvm/tir/tensor_intrin/hexagon.py @@ -47,27 +47,20 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None: T.writes(C[0:size]) T.evaluate( T.tvm_call_packed( - "device_api.hexagon.dma_copy_dltensor", - T.tvm_stack_make_array( - T.address_of(C[0], dtype="handle"), - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - C.dtype, - 0, - dtype="handle", - ), - T.tvm_stack_make_array( - T.address_of(A[0], dtype="handle"), - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - A.dtype, - 0, - dtype="handle", - ), - T.cast(size, dtype="int"), - False, # Do not use experimental bypass mode. + "device_api.hexagon.dma_copy", + -1, # Use QueueId of -1 to not interfere with async copies. + T.address_of(C[0], dtype="handle"), + T.address_of(A[0], dtype="handle"), + size, + 0, # Do not use experimental bypass mode. + dtype="int32", + ) + ) + T.evaluate( + T.tvm_call_packed( + "device_api.hexagon.dma_wait", + -1, + 0, # Wait for the sync queue (-1) to have 0 messages. 
dtype="int32", ) ) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 3458376848f13..1e126ebf7bc0c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -52,6 +52,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 2d1507c8994d9..952810a47aeef 100644 --- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -145,6 +145,9 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { pass_list.push_back(tir::transform::InjectDoubleBuffer()); pass_list.push_back(tir::transform::VectorizeLoop(true)); pass_list.push_back(tir::transform::StorageRewrite()); + transform::PassContext pass_ctx = transform::PassContext::Current(); + pass_ctx->config.Set("tir.merge_async_commit_queue_scope", + Bool(merge_async_commit_queue_scope)); tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", runtime::String(g_var->name_hint)); IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); @@ -166,12 +169,15 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { return Postproc(n); } + bool merge_async_commit_queue_scope = true; + static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy"; TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode); }; -Postproc Postproc::DisallowAsyncStridedMemCopy() { 
+Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) { ObjectPtr n = make_object(); + n->merge_async_commit_queue_scope = merge_async_commit_queue_scope; return Postproc(n); } diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 16e67aa9650f1..ee2a826b02ea1 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -210,10 +210,10 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor") }); TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) { - uint32_t queue_id = static_cast(args[0]); + int queue_id = args[0]; void* dst = args[1]; void* src = args[2]; - uint32_t size = static_cast(args[3]); + int size = args[3]; ICHECK(size > 0); bool bypass_cache = args[4]; @@ -226,26 +226,13 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVM }); TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) { - uint32_t queue_id = static_cast(args[0]); + int queue_id = args[0]; int inflight = args[1]; ICHECK(inflight >= 0); HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight); *rv = static_cast(0); }); -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group") - .set_body([](TVMArgs args, TVMRetValue* rv) { - uint32_t queue_id = static_cast(args[0]); - HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id); - *rv = static_cast(0); - }); - -TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group").set_body([](TVMArgs args, TVMRetValue* rv) { - uint32_t queue_id = static_cast(args[0]); - HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id); - *rv = static_cast(0); -}); - TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) { int32_t device_type = args[0]; int32_t device_id = args[1]; diff --git a/src/runtime/hexagon/hexagon_user_dma.cc b/src/runtime/hexagon/hexagon_user_dma.cc index 
11214a46e8099..c30fd645bbd0a 100644 --- a/src/runtime/hexagon/hexagon_user_dma.cc +++ b/src/runtime/hexagon/hexagon_user_dma.cc @@ -32,8 +32,7 @@ unsigned int HexagonUserDMA::Init() { return status; } -int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t length, - bool bypass_cache) { +int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache) { // length limited to 24 bits if (length > DESC_LENGTH_MASK) { return DMA_FAILURE; @@ -104,15 +103,15 @@ int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t lengt return DMA_SUCCESS; } -void HexagonUserDMA::Wait(uint32_t queue_id, uint32_t max_dmas_in_flight) { +void HexagonUserDMA::Wait(int queue_id, uint32_t max_dmas_in_flight) { // wait (forever) until max DMAs in flight <= actual DMAs in flight - while (DMAGroupsInFlight(queue_id) > max_dmas_in_flight) { + while (DMAsInFlight(queue_id) > max_dmas_in_flight) { } } -uint32_t HexagonUserDMA::Poll(uint32_t queue_id) { return DMAGroupsInFlight(queue_id); } +uint32_t HexagonUserDMA::Poll(int queue_id) { return DMAsInFlight(queue_id); } -uint32_t HexagonUserDMA::DMAGroupsInFlight(uint32_t queue_id) { +uint32_t HexagonUserDMA::DMAsInFlight(int queue_id) { dmpoll(); // update DMA engine status return descriptors_->InFlight(queue_id); } @@ -126,8 +125,7 @@ HexagonUserDMA::HexagonUserDMA() { unsigned int done = dma_desc_get_done(dma_desc); return (done != DESC_DONE_COMPLETE); }; - descriptors_ = - new QueuedRingBuffer(MAX_DMA_QUEUES, MAX_DMA_DESCRIPTORS, desc_in_flight); + descriptors_ = new QueuedRingBuffer(MAX_DMA_DESCRIPTORS, desc_in_flight); } HexagonUserDMA::~HexagonUserDMA() { diff --git a/src/runtime/hexagon/hexagon_user_dma.h b/src/runtime/hexagon/hexagon_user_dma.h index 70590cbe4faa9..9397a16e3f036 100644 --- a/src/runtime/hexagon/hexagon_user_dma.h +++ b/src/runtime/hexagon/hexagon_user_dma.h @@ -34,8 +34,7 @@ namespace hexagon { #define DMA_FAILURE -1 #define DMA_RETRY 1 #define 
MAX_DMA_DESCRIPTORS 100 -#define MAX_DMA_QUEUES 10 -#define SYNC_DMA_QUEUE MAX_DMA_QUEUES - 1 +#define SYNC_DMA_QUEUE -1 class HexagonUserDMA { public: @@ -48,50 +47,32 @@ class HexagonUserDMA { /*! * \brief Initiate DMA to copy memory from source to destination address - * \param queue_id The virtual DMA queue * \param dst Destination address * \param src Source address * \param length Length in bytes to copy * \returns Status: DMA_SUCCESS or DMA_FAILURE */ - int Copy(uint32_t queue_id, void* dst, void* src, uint32_t length, bool bypass_cache); + int Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache); /*! * \brief Wait until the number of DMAs in flight is less than or equal to some maximum - * \param queue_id The virtual DMA queue * \param max_dmas_in_flight Maximum number of DMAs allowed to be in flight * to satisfy the `Wait` e.g. use `Wait(0)` to wait on "all" outstanding DMAs to complete */ - void Wait(uint32_t queue_id, uint32_t max_dmas_in_flight); + void Wait(int queue_id, uint32_t max_dmas_in_flight); /*! * \brief Poll the number of DMAs in flight - * \param queue_id The virtual DMA queue * \returns Number of DMAs in flight */ - uint32_t Poll(uint32_t queue_id); - - /*! - * \brief Start a group of DMA copies - * \param queue_id The virtual DMA queue - */ - void StartGroup(uint32_t queue_id) { descriptors_->StartGroup(queue_id); } - - /*! - * \brief End a group of DMA copies - * \param queue_id The virtual DMA queue - */ - void EndGroup(uint32_t queue_id) { descriptors_->EndGroup(queue_id); } + uint32_t Poll(int queue_id); private: //! \brief Initializes the Hexagon User DMA engine unsigned int Init(); - /*! - * \brief Calculates and returns the number of DMAs in flight - * \param queue_id The virtual DMA queue - */ - uint32_t DMAGroupsInFlight(uint32_t queue_id); + //! \brief Calculates and returns the number of DMAs in flight + uint32_t DMAsInFlight(int queue_id); //! 
\brief Tracks whether the very first DMA has been executed bool first_dma_ = true; diff --git a/src/runtime/hexagon/ring_buffer.h b/src/runtime/hexagon/ring_buffer.h index 91adad6a65e57..4294ded8f52a1 100644 --- a/src/runtime/hexagon/ring_buffer.h +++ b/src/runtime/hexagon/ring_buffer.h @@ -21,8 +21,6 @@ #define TVM_RUNTIME_HEXAGON_RING_BUFFER_H_ #include -#include -#include #include #include "hexagon_common.h" @@ -96,33 +94,17 @@ class RingBuffer { template class QueuedRingBuffer : RingBuffer { public: - QueuedRingBuffer(uint32_t max_queues, uint32_t ring_buff_size, std::function in_flight) - : RingBuffer(ring_buff_size, in_flight), max_queues_(max_queues) { - queue_descriptors_.resize(max_queues_); - } + QueuedRingBuffer(uint32_t ring_buff_size, std::function in_flight) + : RingBuffer(ring_buff_size, in_flight) {} //! \brief Returns pointer to next T; add the queue ID for tracking - T* Next(uint32_t queue_id) { - CHECK_LT(queue_id, max_queues_); + T* Next(int queue_id) { queue_ids_.push_back(queue_id); - queue_descriptor* d = &queue_descriptors_[queue_id]; - if (d->group_started) { - // if we have a group started just update then pending count - d->pending_in_group++; - } else { - // else create group with size one - d->groups.push(1); - d->pending_total++; - } return RingBuffer::Next(); } - //! \brief Returns the number of groups of Ts in flight for a given queue ID - uint32_t InFlight(uint32_t queue_id) { - CHECK_LT(queue_id, max_queues_); - queue_descriptor* d = &queue_descriptors_[queue_id]; - CHECK(!d->group_started); - + //! 
\brief Returns the number of Ts in flight for a given queue ID + uint32_t InFlight(int queue_id) { uint32_t in_flight = 0; // look at the queue IDs for the RingBuffer entries in flight for (size_t i = queue_ids_.size() - RingBuffer::InFlight(); i < queue_ids_.size(); ++i) { @@ -131,57 +113,11 @@ class QueuedRingBuffer : RingBuffer { in_flight++; } } - - // calculate number of groups in flight - while (!d->groups.empty() && d->pending_total - d->groups.front() >= in_flight) { - d->pending_total -= d->groups.front(); - d->groups.pop(); - } - - // return the number of groups in flight - return d->groups.size(); - } - - //! \brief Start a group of Ts, if not called the deafault group size is one - void StartGroup(uint32_t queue_id) { - CHECK_LT(queue_id, max_queues_); - queue_descriptor* d = &queue_descriptors_[queue_id]; - CHECK(!d->group_started); - - // start group - d->group_started = true; - d->pending_in_group = 0; - } - - //! \brief End a group of Ts - void EndGroup(uint32_t queue_id) { - CHECK_LT(queue_id, max_queues_); - queue_descriptor* d = &queue_descriptors_[queue_id]; - CHECK(d->group_started); - CHECK(d->pending_in_group); - - // create group - if (d->pending_in_group) { - d->groups.emplace(d->pending_in_group); - } - d->pending_total += d->pending_in_group; - - // end group - d->group_started = false; - d->pending_in_group = 0; + return in_flight; } private: - struct queue_descriptor { - uint32_t pending_total = 0; - uint32_t pending_in_group = 0; - bool group_started = false; - std::queue groups; - }; - - const int max_queues_; std::vector queue_ids_; - std::vector queue_descriptors_; }; } // namespace hexagon diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index c855904284501..e240b7b701ba1 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -335,12 +335,6 @@ TIR_DEFINE_BUILTIN_FUNC(dma_copy).set_attr("TCallEffectKind", TIR_DEFINE_BUILTIN_FUNC(dma_wait).set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); 
-TIR_DEFINE_BUILTIN_FUNC(dma_start_group) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(dma_end_group) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); - TIR_DEFINE_BUILTIN_FUNC(assume) .set_attr("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo)) .set_num_inputs(1); diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 7e6739a512cbd..51523a37399b4 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -309,9 +309,10 @@ class PipelineRewriter : public StmtExprMutator { const Array pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) { + const Map preserved_annotations, bool merge_async_commit_queue_scope) { PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, - pipeline_info, fragment_info, preserved_annotations); + pipeline_info, fragment_info, preserved_annotations, + merge_async_commit_queue_scope); return rewriter.BuildPipeline(); } @@ -321,7 +322,8 @@ class PipelineRewriter : public StmtExprMutator { const Array& pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) + const Map preserved_annotations, + bool merge_async_commit_queue_scope) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), double_buffers_(double_buffers), @@ -329,7 +331,8 @@ class PipelineRewriter : public StmtExprMutator { pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info), fragment_info_(fragment_info), - preserved_annotations_(preserved_annotations) {} + preserved_annotations_(preserved_annotations), + merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {} Stmt BuildPipeline() { // Step 1: Analyze accesses to the buffers in the 
pipeline and compute the number of versions @@ -763,7 +766,7 @@ class PipelineRewriter : public StmtExprMutator { group_bodies.push_back(new_blocks[i].block->body); } - if (group_bodies.size() > 1) { + if (merge_async_commit_queue_scope_ && group_bodies.size() > 1) { auto merged_bodies = SeqStmt(group_bodies); group_bodies.clear(); group_bodies.push_back(merged_bodies); @@ -850,7 +853,8 @@ class PipelineRewriter : public StmtExprMutator { auto& local_state = async_states_local[stage]; int commit_group_id = -1; - if (local_state.commit_groups.empty() || local_state.consumed) { + if (local_state.commit_groups.empty() || local_state.consumed || + !merge_async_commit_queue_scope_) { // consumed == true means there is already a consumer stage waiting for an // eariler async operation of this stage. In such cases, we make multiple commit_queue // for this stage. @@ -950,6 +954,7 @@ class PipelineRewriter : public StmtExprMutator { Array ordered_stmts_; std::map async_states; Map preserved_annotations_; + bool merge_async_commit_queue_scope_ = true; }; /*! @@ -988,8 +993,8 @@ void BuildDependencyGraph( class PipelineInjector : private StmtExprMutator { public: - static Stmt Inject(const PrimFunc& func) { - PipelineInjector injector; + static Stmt Inject(const PrimFunc& func, bool merge_async_commit_queue_scope) { + PipelineInjector injector(merge_async_commit_queue_scope); for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; injector.buffer_data_to_buffer_.Set(buffer->data, buffer); @@ -999,7 +1004,8 @@ class PipelineInjector : private StmtExprMutator { } private: - PipelineInjector() {} + explicit PipelineInjector(bool merge_async_commit_queue_scope) + : merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {} /*! * \brief Check the pipeline satisfies the following conditions: @@ -1134,9 +1140,9 @@ class PipelineInjector : private StmtExprMutator { ValidatePipelineBody(pipeline_info, original_order); // Step 4: Rewrite the pipeline body. 
- Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, - pipeline_allocs, GetRef(op), pipeline_info, - fragment_info_, preserved_annotations); + Stmt pipeline = PipelineRewriter::Rewrite( + buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef(op), pipeline_info, + fragment_info_, preserved_annotations, merge_async_commit_queue_scope_); if (const auto* realize = op->body.as()) { const auto& block = realize->block; @@ -1205,6 +1211,7 @@ class PipelineInjector : private StmtExprMutator { Map buffer_data_to_buffer_; std::unordered_map fragment_info_; std::unordered_set double_buffers; + bool merge_async_commit_queue_scope_ = true; }; } // namespace software_pipeline @@ -1218,7 +1225,9 @@ namespace transform { Pass InjectSoftwarePipeline() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* fptr = f.CopyOnWrite(); - fptr->body = software_pipeline::PipelineInjector::Inject(f); + bool merge_async_commit_queue_scope = + ctx->GetConfig("tir.merge_async_commit_queue_scope", Bool(true)).value(); + fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope); fptr->body = ConvertSSA(std::move(fptr->body)); return f; }; diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index d899b6ec70abf..57cff7985f1c6 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -22,63 +22,26 @@ */ #include -#include #include -#include -#include -#include #include #include -#include - -#include "../../arith/ir_mutator_with_analyzer.h" #include "ir_utils.h" namespace tvm { namespace tir { -class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { +class AsyncDMALowerer : public StmtExprMutator { public: - explicit AsyncDMALowerer(bool dma_bypass_cache, arith::Analyzer* analyzer) - : IRMutatorWithAnalyzer(analyzer), dma_bypass_cache_(dma_bypass_cache) {} - - Stmt VisitStmt_(const ForNode* loop) final { - // if for loop is not 
within async_commit_queue_scope - if (!async_queue_id_.has_value()) { - return arith::IRMutatorWithAnalyzer::VisitStmt_(loop); - } + explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {} - // if for loop is not a memcpy of a contiguous region - std::optional mem_copy = IdentifyMemCpy(GetRef(loop), analyzer_); - if (!mem_copy.has_value() || mem_copy->dest->region.size() != 1 || - mem_copy->source->region.size() != 1) { - LOG(FATAL) << "Unable to lower async dma due to non contiguous memory access"; - } - - // now that we are about to perform the `copy` transform - // save queue ID for inspection in `wait` transform - // and, increment the number of DMA copies in the group - queue_ids_.insert(async_queue_id_.value()); - dmas_in_group_++; - - tvm::PrimExpr src_min = mem_copy->source->region[0]->min; - tvm::PrimExpr dst_min = mem_copy->dest->region[0]->min; - tvm::PrimExpr dst_extent = mem_copy->dest->region[0]->extent; - - auto src = BufferLoad(mem_copy->source->buffer, {src_min}); - auto dst = BufferLoad(mem_copy->dest->buffer, {dst_min}); - return Evaluate( - Call(DataType::Int(32), builtin::dma_copy(), - {async_queue_id_.value(), Call(DataType::Handle(), builtin::address_of(), {dst}), - Call(DataType::Handle(), builtin::address_of(), {src}), - dst_extent * src->dtype.bytes(), dma_bypass_cache_})); + // Create member statement to track a mapping from iter var to iter range + Stmt VisitStmt_(const ForNode* op) final { + input_iters.Set(op->loop_var, Range(op->min, op->extent)); + return StmtExprMutator::VisitStmt_(op); } Stmt VisitStmt_(const AttrStmtNode* op) final { - // populate analyzer knowledge of loop iterators - auto previsit = arith::IRMutatorWithAnalyzer::VisitStmt_(op); - // Convert this, for example: // attr [0] "async_wait_queue_scope" = 0; // attr [0] "async_wait_inflight_count" = 0; @@ -100,7 +63,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { DLOG(INFO) << "AsyncDMALowerer exiting because the queue 
ID observed in the " "`async_wait_queue_scope` transform has not been previously observed in the " "`async_commit_queue_scope` transform"; - return previsit; + return StmtExprMutator::VisitStmt_(op); } auto async_wait = op->body.as(); @@ -108,13 +71,14 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " "`async_wait_queue_scope` does not contain an `AttrStmtNode` with key " "`async_wait_inflight_count`"; - return previsit; + return StmtExprMutator::VisitStmt_(op); } + auto call_dma_wait = Evaluate(Call(DataType::Int(32), builtin::dma_wait(), {queue_id, async_wait->value})); // concatenate the call with the body and return - return SeqStmt({call_dma_wait, arith::IRMutatorWithAnalyzer::VisitStmt(async_wait->body)}); + return SeqStmt({call_dma_wait, StmtExprMutator::VisitStmt(async_wait->body)}); // Convert this, for example: // attr [0] "async_commit_queue_scope" = 0; @@ -135,27 +99,112 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { // get queue ID auto queue_id_node = op->value.as(); ICHECK(queue_id_node); - async_queue_id_ = queue_id_node->value; - auto result = arith::IRMutatorWithAnalyzer::VisitStmt_(op); - if (dmas_in_group_ > 1) { - auto call_dma_start_group = Evaluate( - Call(DataType::Int(32), builtin::dma_start_group(), {async_queue_id_.value()})); - auto call_dma_end_group = - Evaluate(Call(DataType::Int(32), builtin::dma_end_group(), {async_queue_id_.value()})); - result = SeqStmt({call_dma_start_group, result, call_dma_end_group}); + int queue_id = queue_id_node->value; + + // walk the graph to verify this is a mem copy ... 
+ // 1) async_commit_queue_scope contains async_scope + auto async_scope = op->body.as(); + if (!async_scope || async_scope->attr_key != tir::attr::async_scope) { + DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " + "`async_commit_queue_scope` does not contain an `AttrStmtNode` with key " + "`async_scope`"; + return StmtExprMutator::VisitStmt_(op); + } + + // 2) async_scope contains single for loop + auto for_loop = async_scope->body.as(); + if (!for_loop) { + DLOG(INFO) << "AsyncDMALowerer exiting because the body of the `AttrStmtNode` with key " + "`async_scope` does not contain a single `ForNode`"; + return StmtExprMutator::VisitStmt_(op); + } + + // Add the current loop to the input iters mapping. + input_iters.Set(for_loop->loop_var, Range(for_loop->min, for_loop->extent)); + + // 3) for loop contains buffer store with single index + auto bufferstorenode = for_loop->body.as(); + if (!bufferstorenode || bufferstorenode->indices.size() != 1) { + DLOG(INFO) + << "AsyncDMALowerer exiting because the body of the `ForNode` does not contain a " + "single `BufferStoreNode` with a single index variable"; + return StmtExprMutator::VisitStmt_(op); + } + + // 4) buffer store value is a buffer load with single index + auto bufferloadnode = bufferstorenode->value.as(); + if (!bufferloadnode || bufferloadnode->indices.size() != 1) { + DLOG(INFO) << "AsyncDMALowerer exiting because the value of the `BufferStoreNode` is not a " + "single `BufferLoadNode` with a single index variable"; + return StmtExprMutator::VisitStmt_(op); + } + + // get store buffer; assert it exists and is contiguous given it uses a single index + auto bufferstore = bufferstorenode->buffer.as(); + ICHECK(bufferstore && bufferstore->strides.empty()); + + // get load buffer; assert it exists and is contiguous given it uses a single index + auto bufferload = bufferloadnode->buffer.as(); + ICHECK(bufferload && bufferload->strides.empty()); + + // we will be replacing 
the entire for loop including its index + // with a DMA copy intrinsic that spans the entire index space of the for loop + // so we will need to replace the for loop index with value zero in the buffer indices + // thus we eliminate the index from the expression so the DMA copy receives the buffer range + // base address + Map loop_var_remap = {{for_loop->loop_var, IntImm(DataType::Int(32), 0)}}; + + // map loop variable to zero for the store index & simplify + Array store_index = bufferstorenode->indices; + + // Use DetectIterMap to detect whether store index is non-contiguous. + arith::Analyzer analyzer; + auto store_iter_map = DetectIterMap(store_index, input_iters, 1, + arith::IterMapLevel::Surjective, &analyzer, false); + if (!store_iter_map->errors.empty()) { + LOG(FATAL) + << "Unable to lower async dma for non contiguous memory access with store index: " + << store_index; + } + + store_index.MutateByApply([&](PrimExpr expr) { + arith::Analyzer analyzer; + return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); + }); + + // map loop variable to zero for the load index & simplify + Array load_index = bufferloadnode->indices; + + // Use DetectIterMap to detect whether load index is non-contiguous. 
+ auto load_iter_map = DetectIterMap(load_index, input_iters, 1, + arith::IterMapLevel::Surjective, &analyzer, false); + if (!load_iter_map->errors.empty()) { + LOG(FATAL) << "Unable to lower async dma for non contiguous memory access with load index: " + << load_index; } - async_queue_id_ = std::nullopt; - dmas_in_group_ = 0; - return result; + load_index.MutateByApply([&](PrimExpr expr) { + arith::Analyzer analyzer; + return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); + }); + + // now that we are about to perform the `copy` transform + // save queue ID for inspection in `wait` transform + queue_ids_.insert(queue_id); + + return Evaluate(Call(DataType::Int(32), builtin::dma_copy(), + {queue_id, + Call(DataType::Handle(), builtin::address_of(), + {BufferLoad(bufferstorenode->buffer, store_index)}), + Call(DataType::Handle(), builtin::address_of(), + {BufferLoad(bufferloadnode->buffer, load_index)}), + for_loop->extent * bufferloadnode->dtype.bytes(), dma_bypass_cache_})); } - return arith::IRMutatorWithAnalyzer::VisitStmt_(op); + return StmtExprMutator::VisitStmt_(op); } private: - int dmas_in_group_ = 0; std::set queue_ids_; - std::optional async_queue_id_ = std::nullopt; bool dma_bypass_cache_; Map input_iters = Map(); }; @@ -165,10 +214,9 @@ namespace transform { Pass LowerAsyncDMA() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto fptr = f.CopyOnWrite(); - arith::Analyzer analyzer; bool dma_bypass_cache = ctx->GetConfig("tir.experimental_dma_bypass_cache", Bool(false)).value(); - fptr->body = AsyncDMALowerer(dma_bypass_cache, &analyzer)(std::move(fptr->body)); + fptr->body = AsyncDMALowerer(dma_bypass_cache)(std::move(fptr->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {}); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index d8df2cc55a0bc..7f047874234fc 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ 
b/src/tir/transforms/lower_tvm_builtin.cc @@ -320,10 +320,6 @@ class BuiltinLower : public StmtExprMutator { return MakeDMACopy(op); } else if (op->op.same_as(builtin::dma_wait())) { return MakeDMAWait(op); - } else if (op->op.same_as(builtin::dma_start_group())) { - return MakeDMAStartGroup(op); - } else if (op->op.same_as(builtin::dma_end_group())) { - return MakeDMAEndGroup(op); } else { return StmtExprMutator::VisitExpr_(op); } @@ -357,28 +353,6 @@ class BuiltinLower : public StmtExprMutator { return VisitExpr(call_packed); } - PrimExpr MakeDMAStartGroup(const CallNode* op) { - PrimExpr queue_id = op->args[0]; - - std::string fdevapi_prefix = - "device_api." + std::string(runtime::DeviceName(device_type_.as()->value)); - - Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(), - {StringImm(fdevapi_prefix + ".dma_start_group"), queue_id}); - return VisitExpr(call_packed); - } - - PrimExpr MakeDMAEndGroup(const CallNode* op) { - PrimExpr queue_id = op->args[0]; - - std::string fdevapi_prefix = - "device_api." 
+ std::string(runtime::DeviceName(device_type_.as()->value)); - - Call call_packed = Call(DataType::Int(32), builtin::tvm_call_packed(), - {StringImm(fdevapi_prefix + ".dma_end_group"), queue_id}); - return VisitExpr(call_packed); - } - // call shape PrimExpr MakeShape(const CallNode* op) { // if args.size() == 0, it represents a scalar shape () diff --git a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc index 9697006bf1aa6..e4ffe3a0de9c6 100644 --- a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc @@ -47,7 +47,7 @@ class HexagonUserDMATest : public ::testing::Test { public: HexagonUserDMA* user_dma; int ret = 0; - uint32_t queue_id = 0; + int queue_id = 0; void* src = nullptr; void* dst = nullptr; char* src_char = nullptr; diff --git a/tests/cpp-runtime/hexagon/ring_buffer_tests.cc b/tests/cpp-runtime/hexagon/ring_buffer_tests.cc index a3abf82b863fc..8cf363bae0b39 100644 --- a/tests/cpp-runtime/hexagon/ring_buffer_tests.cc +++ b/tests/cpp-runtime/hexagon/ring_buffer_tests.cc @@ -40,7 +40,7 @@ class RingBufferTest : public ::testing::Test { int finished = 42; int inflight = 43; - uint32_t size = 8; + uint32_t size = 4; uint32_t half = size / 2; RingBuffer* ring_buff = nullptr; }; @@ -160,11 +160,11 @@ TEST_F(RingBufferTest, half_in_flight) { // mark it inflight and check *ptr = inflight; - ASSERT_EQ(ring_buff->InFlight(), half + 1); + ASSERT_EQ(ring_buff->InFlight(), 3); // mark it finished and check also blocked *ptr = finished; - ASSERT_EQ(ring_buff->InFlight(), half + 1); + ASSERT_EQ(ring_buff->InFlight(), 3); } TEST_F(RingBufferTest, half_in_flight_blocked) { @@ -190,23 +190,13 @@ TEST_F(RingBufferTest, half_in_flight_blocked) { } class QueuedRingBufferTest : public RingBufferTest { - void SetUp() override { - queued_ring_buff = new QueuedRingBuffer(MAX_QUEUES, size, in_flight); - } + void SetUp() override { queued_ring_buff = new 
QueuedRingBuffer(size, in_flight); } void TearDown() override { delete queued_ring_buff; } public: - int MAX_QUEUES = 2; QueuedRingBuffer* queued_ring_buff = nullptr; }; -TEST_F(QueuedRingBufferTest, invalid_queue) { - ASSERT_THROW(queued_ring_buff->Next(MAX_QUEUES), InternalError); - ASSERT_THROW(queued_ring_buff->InFlight(MAX_QUEUES), InternalError); - ASSERT_THROW(queued_ring_buff->StartGroup(MAX_QUEUES), InternalError); - ASSERT_THROW(queued_ring_buff->EndGroup(MAX_QUEUES), InternalError); -} - TEST_F(QueuedRingBufferTest, two_queues) { int* q0 = queued_ring_buff->Next(0); *q0 = inflight; @@ -226,188 +216,3 @@ TEST_F(QueuedRingBufferTest, two_queues) { ASSERT_EQ(queued_ring_buff->InFlight(0), 0); ASSERT_EQ(queued_ring_buff->InFlight(1), 0); } - -TEST_F(QueuedRingBufferTest, group_end_before_group_start) { - ASSERT_THROW(queued_ring_buff->EndGroup(0), InternalError); -} - -TEST_F(QueuedRingBufferTest, group_restart) { - queued_ring_buff->StartGroup(0); - ASSERT_THROW(queued_ring_buff->StartGroup(0), InternalError); -} - -TEST_F(QueuedRingBufferTest, zero_size_group) { - queued_ring_buff->StartGroup(0); - ASSERT_THROW(queued_ring_buff->EndGroup(0), InternalError); -} - -TEST_F(QueuedRingBufferTest, in_flight_before_group_end) { - queued_ring_buff->StartGroup(0); - ASSERT_THROW(queued_ring_buff->InFlight(0), InternalError); -} - -TEST_F(QueuedRingBufferTest, group_of_one) { - queued_ring_buff->StartGroup(0); - int* g0_0 = queued_ring_buff->Next(0); - *g0_0 = inflight; - queued_ring_buff->EndGroup(0); - - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - *g0_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 0); -} - -TEST_F(QueuedRingBufferTest, group_of_two) { - queued_ring_buff->StartGroup(0); - int* g0_0 = queued_ring_buff->Next(0); - *g0_0 = inflight; - int* g0_1 = queued_ring_buff->Next(0); - *g0_1 = inflight; - queued_ring_buff->EndGroup(0); - - // neither done => group in flight - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - - // half done => group in 
flight - *g0_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - - // both done => group finished - *g0_1 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 0); -} - -TEST_F(QueuedRingBufferTest, group_of_three) { - queued_ring_buff->StartGroup(0); - int* g0_0 = queued_ring_buff->Next(0); - *g0_0 = inflight; - int* g0_1 = queued_ring_buff->Next(0); - *g0_1 = inflight; - int* g0_2 = queued_ring_buff->Next(0); - *g0_2 = inflight; - queued_ring_buff->EndGroup(0); - - // neither done => group in flight - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - - // 1/3 done => group in flight - *g0_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - - // 2/3 done => group in flight - *g0_1 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - - // all done => group finished - *g0_2 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 0); -} - -TEST_F(QueuedRingBufferTest, two_groups_of_two) { - queued_ring_buff->StartGroup(0); - int* g0_0 = queued_ring_buff->Next(0); - *g0_0 = inflight; - int* g0_1 = queued_ring_buff->Next(0); - *g0_1 = inflight; - queued_ring_buff->EndGroup(0); - - queued_ring_buff->StartGroup(0); - int* g1_0 = queued_ring_buff->Next(0); - *g1_0 = inflight; - int* g1_1 = queued_ring_buff->Next(0); - *g1_1 = inflight; - queued_ring_buff->EndGroup(0); - - // two groups in flight - ASSERT_EQ(queued_ring_buff->InFlight(0), 2); - - // group 0 half done => two groups in flight - *g0_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 2); - - // group 0 done => one group in flight - *g0_1 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - - // group 1 half done => one group in flight - *g1_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - - // group 1 done => zero groups in flight - *g1_1 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 0); -} - -TEST_F(QueuedRingBufferTest, two_queues_two_groups_of_two) { - queued_ring_buff->StartGroup(0); - int* q0g0_0 = queued_ring_buff->Next(0); - *q0g0_0 = 
inflight; - int* q0g0_1 = queued_ring_buff->Next(0); - *q0g0_1 = inflight; - queued_ring_buff->EndGroup(0); - - queued_ring_buff->StartGroup(1); - int* q1g0_0 = queued_ring_buff->Next(1); - *q1g0_0 = inflight; - int* q1g0_1 = queued_ring_buff->Next(1); - *q1g0_1 = inflight; - queued_ring_buff->EndGroup(1); - - queued_ring_buff->StartGroup(0); - int* q0g1_0 = queued_ring_buff->Next(0); - *q0g1_0 = inflight; - int* q0g1_1 = queued_ring_buff->Next(0); - *q0g1_1 = inflight; - queued_ring_buff->EndGroup(0); - - queued_ring_buff->StartGroup(1); - int* q1g1_0 = queued_ring_buff->Next(1); - *q1g1_0 = inflight; - int* q1g1_1 = queued_ring_buff->Next(1); - *q1g1_1 = inflight; - queued_ring_buff->EndGroup(1); - - // two queues with two groups in flight each - ASSERT_EQ(queued_ring_buff->InFlight(0), 2); - ASSERT_EQ(queued_ring_buff->InFlight(1), 2); - - // queue 0 group 0 half done => no change - *q0g0_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 2); - ASSERT_EQ(queued_ring_buff->InFlight(1), 2); - - // queue 0 group 0 done => queue 0 with one group in flight - *q0g0_1 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - ASSERT_EQ(queued_ring_buff->InFlight(1), 2); - - // queue 1 group 0 half done => no change - *q1g0_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - ASSERT_EQ(queued_ring_buff->InFlight(1), 2); - - // queue 1 group 0 done => queue 1 with one group in flight - *q1g0_1 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - ASSERT_EQ(queued_ring_buff->InFlight(1), 1); - - // queue 0 group 1 half done => no change - *q0g1_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 1); - ASSERT_EQ(queued_ring_buff->InFlight(1), 1); - - // queue 0 group 1 done => queue 0 with zero groups in flight - *q0g1_1 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 0); - ASSERT_EQ(queued_ring_buff->InFlight(1), 1); - - // queue 1 group 1 half done => no change - *q1g1_0 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 0); - 
ASSERT_EQ(queued_ring_buff->InFlight(1), 1); - - // queue 1 group 1 done => queue 1 with zero groups in flight - *q1g1_1 = finished; - ASSERT_EQ(queued_ring_buff->InFlight(0), 0); - ASSERT_EQ(queued_ring_buff->InFlight(1), 0); -} diff --git a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py index 030e47ac581df..d985d21209369 100644 --- a/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py +++ b/tests/python/contrib/test_hexagon/metaschedule_e2e/test_resnet50_int8.py @@ -521,6 +521,7 @@ def test_async_dma_resnet50(hexagon_launcher): pass_config = { "tir.use_async_copy": 1, + "tir.merge_async_commit_queue_scope": False, "relay.backend.use_meta_schedule": True, "relay.backend.tir_converter": "default", } diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 04e7595abf37b..bc8edca7d8442 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -268,6 +268,7 @@ def evaluate( c_data, expected_output=None, use_async_copy=0, + merge_async_commit_queue_scope=False, ): """Evaluate function.""" target_hexagon = tvm.target.hexagon("v68", link_params=True) @@ -275,6 +276,7 @@ def evaluate( config={ "tir.use_async_copy": use_async_copy, "tir.experimental_dma_bypass_cache": 1, + "tir.merge_async_commit_queue_scope": merge_async_commit_queue_scope, } ): func_tir = tvm.build( @@ -483,6 +485,7 @@ def test_loading_vtcm_for_vrmpy( np.zeros(expected_output.shape, "int32"), expected_output, use_async_copy=1, + merge_async_commit_queue_scope=False, ) sch = get_fake_conv_vtcm_schedule(size_a, size_w) @@ -883,13 +886,14 @@ def test_non_contiguous(): """Test Non Contiguous memory lowering.""" sch = tvm.tir.Schedule(conv2d_async_non_contig) target_hexagon = tvm.target.hexagon("v68", link_params=True) - 
err_rgx = r"Unable to lower async dma due to non contiguous memory access" + err_rgx = r"Unable to lower async dma for non contiguous memory access with load index: " # Currently we do not support non contiguous memory access being lowered to # async dma so we throw an error. with pytest.raises(tvm.TVMError, match=err_rgx): with tvm.transform.PassContext( config={ "tir.use_async_copy": 1, + "tir.merge_async_commit_queue_scope": 0, } ): tvm.build( diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 498e29e407b44..7c010f363fe1e 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -181,6 +181,7 @@ def test_async_software_pipeline( config={ "tir.use_async_copy": 1, "tir.experimental_dma_bypass_cache": 1, + "tir.merge_async_commit_queue_scope": False, } ): # tvm.lower(schedule.mod["main"]).show()