From 4d2a04cf16d456ea365d208a6048a72209d551dc Mon Sep 17 00:00:00 2001
From: eedalong
Date: Thu, 14 Mar 2024 17:22:37 +0800
Subject: [PATCH 1/2] support async collective op execution

---
 .../disc/transforms/mhlo_decomp_rewriters.cc  | 63 ++++++++++++--
 .../tests/mhlo_decomp_rewriter.mlir           |  4 +-
 ...omp_rewriter_with_async_collective_op.mlir | 15 ++++
 tao_compiler/mlir/ral/collective.cu.cc        | 83 ++++++++++++++++++-
 .../context/base/cuda/cuda_context_impl.cc    | 33 ++++++++
 .../ral/context/base/cuda/cuda_context_impl.h |  7 ++
 6 files changed, 194 insertions(+), 11 deletions(-)
 mode change 100644 => 100755 tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
 create mode 100755 tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter_with_async_collective_op.mlir

diff --git a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
index e5f6b46ab21..f101c9a3d55 100644
--- a/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
+++ b/tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -134,6 +134,21 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
   }
 }  // namespace
 namespace {
+
+bool IsAsyncCollective(Operation* op) {
+  if (llvm::isa<mhlo::AllReduceOp>(op)) {
+    if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_REDUCE")) {
+      return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
+    }
+  } else if (llvm::isa<mhlo::AllGatherOp>(op)) {
+    if (const char* env_p = std::getenv("ENABLE_ASYNC_ALL_GATHER")) {
+      return std::strcmp(env_p, "true") == 0 || std::strcmp(env_p, "True") == 0;
+    }
+  }
+
+  return false;
+}
+
 enum ReductionKind {
   ALL_REDUCE_SUM,
   ALL_REDUCE_PRODUCT,
@@ -192,6 +207,9 @@ struct CollectiveOpConverter : public OpRewritePattern<mhlo::AllReduceOp> {
     if (!reductionKind) {
       return failure();
     }
+
+    bool is_async = IsAsyncCollective(op.getOperation());
+
     for (int i = 0; i < op->getOperands().size(); ++i) {
       // no need call all_reduce op if no consumer
       if (op->getResult(i).getUsers().empty()) {
@@ -206,19 +224,48 @@
       op->setAttr("output_layouts", rewriter.getStringAttr("*"));
       op->setAttr("expected_input_layouts", rewriter.getStringAttr("*"));
       op->setAttr("expected_output_layouts", rewriter.getStringAttr("*"));
-      SmallVector<NamedAttribute> newAttrs;
-      newAttrs.push_back(
+
+      SmallVector<NamedAttribute> attrs;
+      attrs.push_back(
           NamedAttribute(rewriter.getStringAttr("reduction_kind"),
                          rewriter.getStringAttr(reductionKind.value())));
+      attrs.push_back(NamedAttribute(rewriter.getStringAttr("is_async"),
+                                     rewriter.getBoolAttr(is_async)));
+      auto customAttrs = DictionaryAttr::get(op->getContext(), attrs);
+      op->setAttr("custom_attrs", customAttrs);
 
-      auto newCustomAttrs = DictionaryAttr::get(op->getContext(), newAttrs);
-
-      op->setAttr("custom_attrs", newCustomAttrs);
-
-      auto newOutput = rewriter.create<mhlo_disc::CustomCallV2Op>(
+      auto reduce_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
           op->getLoc(), op->getResults()[i].getType(), op->getOperands()[i],
           op->getAttrs());
-      newOutputs.push_back(newOutput.getResult(0));
+
+      if (is_async) {
+        int64_t async_pair_token =
+            reinterpret_cast<int64_t>(reduce_op.getOperation());
+        attrs.push_back(
+            NamedAttribute(rewriter.getStringAttr("async_token_key"),
+                           rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                                   async_pair_token)));
+        auto newCustomAttrs =
+            DictionaryAttr::get(reduce_op->getContext(), attrs);
+        reduce_op->setAttr("custom_attrs", newCustomAttrs);
+      }
+
+      if (is_async) {
+        // Insert CollectiveDoneOp
+        auto collective_done_op = rewriter.create<mhlo_disc::CustomCallV2Op>(
+            reduce_op->getLoc(), reduce_op->getResults()[0].getType(),
+            reduce_op->getResults()[0], reduce_op->getAttrs());
+        collective_done_op->setAttr(
+            "call_target_name",
+            rewriter.getStringAttr("ral_async_collective_done"));
+
+        // Place collective_done_op right before the first consumer.
+        collective_done_op->moveBefore(*(op->getResult(i).user_begin()));
+
+        newOutputs.push_back(collective_done_op.getResult(0));
+      } else {
+        newOutputs.push_back(reduce_op.getResult(0));
+      }
     }
     rewriter.replaceOp(op, newOutputs);
     return success();
diff --git a/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
old mode 100644
new mode 100755
index 10d3c5b9df6..bf817f11021
--- a/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
+++ b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter.mlir
@@ -37,8 +37,8 @@ func.func @batch_norm_inference(%arg0: tensor, %arg1: tensor<128x
 }
 
 func.func @main(%arg0: tensor<f32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<f32>) {
-  // CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg1) {call_target_name = "ral_all_reduce", custom_attrs = {reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
-  // CHECK: %1 = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_all_reduce", custom_attrs = {reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
+  // CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg1) {call_target_name = "ral_all_reduce", custom_attrs = {is_async = false, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %1 = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_all_reduce", custom_attrs = {is_async = false, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
   %0:2 = "mhlo.all_reduce"(%arg1, %arg0) ({
   ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
     %1 = mhlo.add %arg2, %arg3 : tensor<f32>
diff --git a/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter_with_async_collective_op.mlir b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter_with_async_collective_op.mlir
new file mode 100755
index 00000000000..a96d5ed7e4a
--- /dev/null
+++ b/tao_compiler/mlir/disc/transforms/tests/mhlo_decomp_rewriter_with_async_collective_op.mlir
@@ -0,0 +1,15 @@
+// RUN: ENABLE_ASYNC_ALL_REDUCE=true disc-opt -disc-mhlo-decomp-rewriter -split-input-file %s -o - | FileCheck %s
+
+func.func @main(%arg0: tensor<f32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<f32>) {
+  // CHECK: %0 = "mhlo_disc.custom_call_v2"(%arg1) {call_target_name = "ral_all_reduce", custom_attrs = {async_token_key = {{.*}} : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %1 = "mhlo_disc.custom_call_v2"(%arg0) {call_target_name = "ral_all_reduce", custom_attrs = {async_token_key = {{.*}} : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
+  // CHECK: %2 = "mhlo_disc.custom_call_v2"(%0) {call_target_name = "ral_async_collective_done", custom_attrs = {async_token_key = {{.*}} : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>) -> tensor<4xf32>
+  // CHECK: %3 = "mhlo_disc.custom_call_v2"(%1) {call_target_name = "ral_async_collective_done", custom_attrs = {async_token_key = {{.*}} : i64, is_async = true, reduction_kind = "sum"}, device = "d", expected_input_layouts = "*", expected_output_layouts = "*", has_side_effect = false, input_layouts = "*", input_placements = "d", output_layouts = "*", output_placements = "d", replica_groups = dense<> : tensor<0x0xi64>} : (tensor<f32>) -> tensor<f32>
+  %0:2 = "mhlo.all_reduce"(%arg1, %arg0) ({
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    mhlo.return %1 : tensor<f32>
+  }) {replica_groups = dense<> : tensor<0x0xi64>} : (tensor<4xf32>, tensor<f32>) -> (tensor<4xf32>, tensor<f32>)
+  // CHECK: return %2, %3 : tensor<4xf32>, tensor<f32>
+  return %0#0, %0#1 : tensor<4xf32>, tensor<f32>
+}
diff --git a/tao_compiler/mlir/ral/collective.cu.cc b/tao_compiler/mlir/ral/collective.cu.cc
index 35f92258aa0..431253d7301 100644
--- a/tao_compiler/mlir/ral/collective.cu.cc
+++ b/tao_compiler/mlir/ral/collective.cu.cc
@@ -63,6 +63,9 @@ MemRefType<T, N> ral_all_reduce(ExecutionContext* ctx, void* stream_handle,
   auto& dictAttr = attr->as<DictPDLAttr>();
   std::string reductionKind =
       dictAttr.get("reduction_kind").template as<StrPDLAttr>().getValue();
+
+  bool isAsync = dictAttr.get("is_async").template as<BoolPDLAttr>().getValue();
+
   ncclDataType_t ncclDtype = ncclDataTypeMapper<T>::value;
   auto ncclReductionType = getNcclReductionType(reductionKind);
@@ -74,7 +77,7 @@
   auto gpu_driver = ctx->getDriver<tao::ral::gpu::GPUDriver>(
       tao::ral::gpu::GPUDriver::name());
   auto gpu_stream =
-      static_cast<cudaStream_t>(gpu_driver->asCUStream(ctx, stream_handle));
+      static_cast<BaseCudaExecutionContext*>(ctx)->getCommStream();
   auto nccl_comm =
       static_cast<BaseCudaExecutionContext*>(ctx)->getNcclComm();
   auto ptr = static_cast<T*>(gpu_driver->alloc(ctx, element_count * sizeof(T)));
@@ -87,9 +90,69 @@
   if (ncclResult != ncclSuccess) {
     ctx->signalError(Context::FAILURE, "fail to call ncclAllReduce\n");
   }
+
+  if (isAsync && gpu_stream) {
+    int64_t token_key =
+        dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
+    cudaEvent_t event;
+
+    auto event_status = cudaEventCreate(&event);
+    if (event_status != cudaSuccess) {
+      ctx->signalError(Context::FAILURE, "cudaEventCreate failed\n");
+    }
+
+    auto record_status = cudaEventRecord(event, gpu_stream);
+    if (record_status != cudaSuccess) {
+      cudaEventDestroy(event);
+      ctx->signalError(Context::FAILURE, "cudaEventRecord failed\n");
+    }
+
+    static_cast<BaseCudaExecutionContext*>(ctx)->addAsyncPairToken(
+        token_key, event);
+  }
+
   return output;
 }
 
+template <typename T, int N>
+MemRefType<T, N> ral_async_collective_done(ExecutionContext* ctx,
+                                           void* stream_handle,
+                                           MemRefType<T, N> input,
+                                           void* customAttrs) {
+  auto attr =
+      getOrParsePDLAttr(ctx, customAttrs, "simple_test_fused_add_mul_kernel");
+  if (!attr) {
+    ctx->signalError(
+        Context::FAILURE,
+        "fail to parse custom_attrs in ral_async_collective_done\n");
+  }
+
+  auto& dictAttr = attr->as<DictPDLAttr>();
+  int64_t token_key =
+      dictAttr.get("async_token_key").template as<IntPDLAttr>().getValue();
+  auto event =
+      static_cast<BaseCudaExecutionContext*>(ctx)->getAsyncPairToken(
+          token_key);
+  if (event) {
+    auto sync_status = cudaEventSynchronize(event);
+    if (sync_status != cudaSuccess) {
+      ctx->signalError(Context::FAILURE, "cudaEventSynchronize failed\n");
+    }
+    static_cast<BaseCudaExecutionContext*>(ctx)->removeAsyncPairToken(
+        token_key);
+    cudaEventDestroy(event);
+  }
+
+  // Increase ref count for input to prevent double free
+  auto it =
+      static_cast<BaseCudaExecutionContext*>(ctx)->device_ptr_map.find(
+          input.data);
+  ;
+  ++it->second;
+
+  return input;
+}
+
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
@@ -98,5 +161,23 @@ TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
 TAO_RAL_API("ral_all_reduce", "gpu", ral_all_reduce);
+
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+TAO_RAL_API("ral_async_collective_done", "gpu",
+            ral_async_collective_done);
+
 }  // namespace ral
 }  // namespace tao
diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
index 0af72f978f2..7c84d83639c 100644
--- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
+++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.cc
@@ -119,10 +119,13 @@ struct BaseCudaContextState : public tao::ral::Context::Resource {
 
   ncclComm_t nccl_comm = nullptr;
   GpuStreamHandle stream = nullptr;
+  GpuStreamHandle comm_stream = nullptr;
   // map blob ptr -> loaded module
   std::map<void*, GpuModuleHandle> blobs;
   // map <blob ptr, kernel name> -> callable kernel
   std::map<std::pair<void*, std::string>, GpuFunctionHandle> kernels;
+  // map int64 -> cudaEvent_t
+  std::map<int64_t, cudaEvent_t> async_pair_tokens;
   std::shared_ptr<Allocator> gpu_allocator;
   bool cache_workspace_mem_across_execution;
@@ -146,6 +149,7 @@ struct BaseCudaContextState : public tao::ral::Context::Resource {
                      "StreamSync");
 #else
     reportErrorIfAny(cuStreamSynchronize(stream), ctx, "StreamSync");
+    reportErrorIfAny(cuStreamSynchronize(comm_stream), ctx, "StreamSync");
 #endif
     for (const_buffer_t buffer : device_persistent_buffers) {
       gpu_allocator->dealloc(const_cast<buffer_t>(buffer));
@@ -173,6 +177,7 @@ std::unique_ptr<BaseContext> MakeBaseCudaContext(
   auto state = new BaseCudaContextState;
   state->stream = gpu_opt.stream;
   state->nccl_comm = gpu_opt.nccl_comm;
+  state->comm_stream = gpu_opt.comm_stream;
   if (gpu_opt.gpu_allocator != nullptr) {
     state->gpu_allocator = gpu_opt.gpu_allocator;
   } else {
@@ -206,6 +211,34 @@ ncclComm_t BaseCudaExecutionContext::getNcclComm() {
   return state->nccl_comm;
 }
 
+GpuStreamHandle BaseCudaExecutionContext::getCommStream() {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  return state->comm_stream;
+}
+
+cudaEvent_t BaseCudaExecutionContext::getAsyncPairToken(int64_t key) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  if (state->async_pair_tokens.find(key) != state->async_pair_tokens.end()) {
+    return state->async_pair_tokens[key];
+  }
+  return nullptr;
+}
+
+void BaseCudaExecutionContext::addAsyncPairToken(int64_t key,
+                                                 cudaEvent_t token) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  state->async_pair_tokens[key] = token;
+  return;
+}
+
+void BaseCudaExecutionContext::removeAsyncPairToken(int64_t key) {
+  auto* state = getResource<BaseCudaContextState>(kRalBaseCudaContextState);
+  if (state->async_pair_tokens.find(key) != state->async_pair_tokens.end()) {
+    state->async_pair_tokens.erase(key);
+  }
+  return;
+}
+
 void BaseCudaExecutionContext::setOutputDeleter(OutputBufferWrapper& output) {
   {
     if (synced) {
diff --git a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
index 314f5f7285a..c15cf1ba56e 100644
--- a/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
+++ b/tao_compiler/mlir/ral/context/base/cuda/cuda_context_impl.h
@@ -48,6 +48,7 @@ using GpuStreamHandle = CUstream;
 struct BaseCudaContextOption {
   ncclComm_t nccl_comm = nullptr;
   GpuStreamHandle stream = nullptr;
+  GpuStreamHandle comm_stream = nullptr;
   int device_ordinal = 0;
   bool use_stream_executor = true;
   bool cache_workspace_mem_across_execution = false;
@@ -64,6 +65,12 @@ struct BaseCudaExecutionContext
   ~BaseCudaExecutionContext();
 
   ncclComm_t getNcclComm();
+
+  GpuStreamHandle getCommStream();
+
+  cudaEvent_t getAsyncPairToken(int64_t key);
+  void addAsyncPairToken(int64_t key, cudaEvent_t token);
+  void removeAsyncPairToken(int64_t key);
   // We need to sync on the gpu stream before we fetch the first output.
   bool synced = false;
   // all buffer allocated by the gpu_allocator

From 6054c53dfc0bbe6124d97d778ca1cbcffe92f23c Mon Sep 17 00:00:00 2001
From: eedalong
Date: Thu, 23 May 2024 16:09:09 +0800
Subject: [PATCH 2/2] minor fixes

---
 .github/workflows/pytorch113_gpu.yml                          | 4 ++--
 tao_compiler/mlir/disc/transforms/disc_op_schedule.cc         | 4 ++--
 .../mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc      | 6 ++++--
 tao_compiler/mlir/disc/transforms/mhlo_placer.cc              | 2 +-
 4 files changed, 9 insertions(+), 7 deletions(-)
 mode change 100644 => 100755 .github/workflows/pytorch113_gpu.yml
 mode change 100644 => 100755 tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc

diff --git a/.github/workflows/pytorch113_gpu.yml b/.github/workflows/pytorch113_gpu.yml
old mode 100644
new mode 100755
index 10005f73866..8f8a4faec77
--- a/.github/workflows/pytorch113_gpu.yml
+++ b/.github/workflows/pytorch113_gpu.yml
@@ -15,8 +15,8 @@ jobs:
       cuda_version: cu116
       runner_tag: gpu-a10
       remote_runtime_docker: bladedisc:latest-runtime-torch1.13.1-cu116
-      develop_base_image: nvidia/cuda:11.6.0-cudnn8-devel-ubuntu20.04
-      runtime_base_image: nvidia/cuda:11.6.0-cudnn8-devel-ubuntu20.04
+      develop_base_image: nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04
+      runtime_base_image: nvidia/cuda:11.6.1-cudnn8-devel-ubuntu20.04
       extra_build_args: --build-arg PYTHON_VERSION=PYTHON3.8 --build-arg ENABLE_FIND_FASTEST_APT_SOURCE=OFF
       extra_envs: -e TORCH_BLADE_BUILD_TENSORRT_STATIC=OFF -e TORCH_BLADE_CI_BUILD_TORCH_VERSION=1.13.1+cu116
diff --git a/tao_compiler/mlir/disc/transforms/disc_op_schedule.cc b/tao_compiler/mlir/disc/transforms/disc_op_schedule.cc
index 36461586229..726af7ed2bd 100644
--- a/tao_compiler/mlir/disc/transforms/disc_op_schedule.cc
+++ b/tao_compiler/mlir/disc/transforms/disc_op_schedule.cc
@@ -354,7 +354,7 @@ class ScheduleGraph {
   explicit ScheduleGraph(std::vector& post_order_instructions,
                          LatencyEstimator* latency_estimator,
                          AsyncTracker* async_tracker) {
-    InitilizeGrpahTopology(post_order_instructions, latency_estimator,
+    InitilizeGraphTopology(post_order_instructions, latency_estimator,
                            async_tracker);
     InitializeGraphAnalysis(latency_estimator, async_tracker);
   }
@@ -497,7 +497,7 @@ class ScheduleGraph {
     }
   }
 
-  void InitilizeGrpahTopology(std::vector& post_order_instructions,
+  void InitilizeGraphTopology(std::vector& post_order_instructions,
                               LatencyEstimator* latency_estimator,
                               AsyncTracker* async_tracker) {
     original_order_ = post_order_instructions;
diff --git a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
old mode 100644
new mode 100755
index 73b1feaaa68..93789480e5c
--- a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
+++ b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
@@ -5712,9 +5712,11 @@ struct DiscLhloLegalizeRootsToParallelLoops
     // TODO: We should put even single nodes into a fusion by fusion pass
     // Revisit this and walk lmhlo::FusionOp only after the revision done.
     func.walk([&](lmhlo::LmhloOp op) {
-      // Skip the embedded ops in lmhlo.fusion or lmhlo.reduce/scatter
+      // Skip the embedded ops in lmhlo.fusion or lmhlo.reduce/scatter or
+      // lmhlo_disc.args_mutation
       lmhlo::LmhloOp parent = op->getParentOfType();
-      if (parent && !isa(op)) {
+      if (isa<lmhlo_disc::ArgsMutationOp>(op) ||
+          parent && !isa(op)) {
         return;
       }
       if (isFusionType(op) &&
diff --git a/tao_compiler/mlir/disc/transforms/mhlo_placer.cc b/tao_compiler/mlir/disc/transforms/mhlo_placer.cc
index 3ffb2b60f5f..ed823247601 100644
--- a/tao_compiler/mlir/disc/transforms/mhlo_placer.cc
+++ b/tao_compiler/mlir/disc/transforms/mhlo_placer.cc
@@ -418,7 +418,7 @@ void OpsPlacer::placeI32Ops() {
     if (isa(op)) return;
     if (isa(op)) {
+            mhlo::ReturnOp, mhlo_disc::ArgsMutationOp>(op)) {
       return;
     }
     // Skip the Op that is already placed on CPU