From d47e8a110c2bbfdb0802b3a0bb3433a6511a7b1d Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 7 Feb 2025 17:29:01 -0800 Subject: [PATCH] Reworked to do placement analysis on previous pass --- .../Partitioning/ReferencePartitioning.cpp | 35 +++----- .../Dialect/Stream/Transforms/BUILD.bazel | 1 + .../Dialect/Stream/Transforms/CMakeLists.txt | 1 + .../Stream/Transforms/ExecutionPlacement.cpp | 66 +++++++++++++++ .../Dialect/Stream/Transforms/Passes.cpp | 2 + .../Dialect/Stream/Transforms/Passes.td | 13 +++ .../Stream/Transforms/test/BUILD.bazel | 1 + .../Stream/Transforms/test/CMakeLists.txt | 1 + .../Transforms/test/execution_placement.mlir | 81 +++++++++++++++++++ .../Transforms/test/schedule_execution.mlir | 8 +- 10 files changed, 179 insertions(+), 30 deletions(-) create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/ExecutionPlacement.cpp create mode 100644 compiler/src/iree/compiler/Dialect/Stream/Transforms/test/execution_placement.mlir diff --git a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp index 46d17796c169b..5dd54b2e85a2b 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Analysis/Partitioning/ReferencePartitioning.cpp @@ -169,11 +169,18 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, // Synchronizing operations should join with their producers if the producer // is streamable. - if (dyn_cast(op)) { + if (dyn_cast(op) || + dyn_cast(op)) { auto producer = op.getOperand(0).getDefiningOp(); auto streamable = dyn_cast_or_null(producer); - if (streamable) { + + auto srcAffinity = dyn_cast(producer); + auto opAffinity = dyn_cast(op); + + if (streamable && srcAffinity && + IREE::Stream::AffinityAttr::canExecuteTogether( + opAffinity.getAffinityAttr(), srcAffinity.getAffinityAttr())) { if (!syncOps.contains(producer)) syncOps[producer] = llvm::SmallVector(); syncOps[producer].push_back(&op); @@ -181,30 +188,6 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config, } } - if (auto transfer = dyn_cast(op)) { - auto producer = op.getOperand(0).getDefiningOp(); - auto streamable = - dyn_cast_or_null(producer); - if (streamable) { - int transferCount = 0; - for (auto use : producer->getUsers()) { - if (isa(use)) { - transferCount++; - } - } - - if (transferCount < 2) { - if (!syncOps.contains(producer)) { - syncOps[producer] = llvm::SmallVector(); - syncOps[producer].push_back(&op); - continue; - } - } - - transfer.setAffinityAttr(transfer.getResultAffinityAttr()); - } - } - // Initialize op info for this op - whether streamable or not. We track // transitive hazards on each op. Note that thanks to the ordering of ops // in SSA form (_reversed here!_) we know that once we visit this op no diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel index e63fb544eb1a6..14b5227b92e0c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/BUILD.bazel @@ -24,6 +24,7 @@ iree_compiler_cc_library( "ElideTimepoints.cpp", "EmplaceAllocations.cpp", "EncodeTensors.cpp", + "ExecutionPlacement.cpp", "FoldUniformOperands.cpp", "FuseDispatchBindings.cpp", "LayoutSlices.cpp", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt index 0f1139c41737c..10b70da4690af 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/CMakeLists.txt @@ -25,6 +25,7 @@ iree_cc_library( "ElideTimepoints.cpp" "EmplaceAllocations.cpp" "EncodeTensors.cpp" + "ExecutionPlacement.cpp" "FoldUniformOperands.cpp" "FuseDispatchBindings.cpp" "LayoutSlices.cpp" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ExecutionPlacement.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ExecutionPlacement.cpp new file mode 100644 index 0000000000000..b459595aa633c --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ExecutionPlacement.cpp @@ -0,0 +1,66 @@ +// Copyright 2025 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h" +#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h" +#include "iree/compiler/Dialect/Stream/IR/StreamOps.h" +#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" +#include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define DEBUG_TYPE "iree-stream-execution-placement" + +namespace mlir::iree_compiler::IREE::Stream { + +#define GEN_PASS_DEF_EXECUTIONPLACEMENTPASS +#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc" + +namespace { + +struct ExecutionPlacementPass + : public IREE::Stream::impl::ExecutionPlacementPassBase< + ExecutionPlacementPass> { + void runOnOperation() override { + + getOperation()->walk([](IREE::Stream::AsyncTransferOp transfer) { + if (transfer.getExecAffinityAttr()) + return; + + auto operand = transfer.getSource(); + auto producer = operand.getDefiningOp(); + auto streamable = + dyn_cast_or_null(producer); + auto srcAffinity = dyn_cast(producer); + + bool hasOneUse = operand.hasOneUse(); + if (hasOneUse && streamable && srcAffinity) { + transfer.setExecAffinityAttr(srcAffinity.getAffinityAttr()); + } else { + transfer.setExecAffinityAttr(transfer.getResultAffinityAttr()); + } + }); + } +}; + +} // namespace +} // namespace mlir::iree_compiler::IREE::Stream diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp index 7935be7810e60..6a61636ac110c 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.cpp @@ -204,6 +204,8 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager, //---------------------------------------------------------------------------- FunctionLikeNest(passManager) + // Analyze and assign execution placement. + .addPass(IREE::Stream::createExecutionPlacementPass) // Combine async work into execution regions. .addPass(IREE::Stream::createScheduleExecutionPass) // Group concurrently executable work into waves. diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td index 775eadac008d1..de648953717aa 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td @@ -214,6 +214,19 @@ def RefineUsagePass : // Stream formation and scheduling //===----------------------------------------------------------------------===// +def ExecutionPlacementPass : + InterfacePass<"iree-stream-execution-placement", "mlir::CallableOpInterface"> { + let summary = "Runs an analysis and placement for stream executions."; + let description = [{ + Handles placement analysis for `stream.async` operators that have a preferable + placement. This is so that more complex analsysis can be separated from the + execution scheduling pass. + }]; + let dependentDialects = [ + "IREE::Stream::StreamDialect", + ]; +} + def ScheduleExecutionPass : InterfacePass<"iree-stream-schedule-execution", "mlir::CallableOpInterface"> { let summary = "Identifies and groups asynchronous operations into executable regions within function-like regions."; diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index 138ba0be66899..4083ad4224030 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -31,6 +31,7 @@ iree_lit_test_suite( "encode_host_tensors_encoding.mlir", "encode_host_tensors_packing.mlir", "encode_host_tensors_packing_i1_experimental_clopt.mlir", + "execution_placement.mlir", "fold_globals.mlir", "fold_uniform_operands.mlir", "fuse_dispatch_bindings.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index 4c4cb93d80ef3..91a87dcc84e9d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -29,6 +29,7 @@ iree_lit_test_suite( "encode_host_tensors_encoding.mlir" "encode_host_tensors_packing.mlir" "encode_host_tensors_packing_i1_experimental_clopt.mlir" + "execution_placement.mlir" "fold_globals.mlir" "fold_uniform_operands.mlir" "fuse_dispatch_bindings.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/execution_placement.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/execution_placement.mlir new file mode 100644 index 0000000000000..298ffea4b0e00 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/execution_placement.mlir @@ -0,0 +1,81 @@ +// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-stream-execution-placement))" %s | FileCheck %s + +// Tests partitioning multi-device execution with barriers and transfers. +// It validates that multi-stream commands are created and run in parallel. + +// CHECK-LABEL: util.func public @deviceMultiDeviceSync +util.func public @deviceMultiDeviceSync(%arg0: i1) -> !stream.resource { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c255_i32 = arith.constant 255 : i32 + + %0 = stream.async.splat %c255_i32 : i32 -> !stream.resource{%c128} + // CHECK: stream.async.dispatch + %1 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} + %3 = stream.async.barrier %1 : !stream.resource{%c128} + + // CHECK: stream.async.transfer + // CHECK-SAME: on(#hal.device.affinity<@device1>) + %4 = stream.async.transfer %1 : !stream.resource{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) !stream.resource{%c128} + // CHECK: stream.async.dispatch + %2 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch1[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} + %5 = stream.async.barrier %2 : !stream.resource{%c128} + + // CHECK: stream.async.transfer + // CHECK-SAME: on(#hal.device.affinity<@device0>) + %6 = stream.async.transfer %2 : !stream.resource{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource{%c128} + // CHECK: stream.async.dispatch + %7 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch2[%c1, %c1, %c1](%3[%c0 to %c128 for %c128], %6[%c0 to %c128 for %c128]) : (!stream.resource{%c128}, !stream.resource{%c128}) -> !stream.resource{%c128} + %8 = stream.async.barrier %7 : !stream.resource{%c128} + %9 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch2[%c1, %c1, %c1](%4[%c0 to %c128 for %c128], %5[%c0 to %c128 for %c128]) : (!stream.resource{%c128}, !stream.resource{%c128}) -> !stream.resource{%c128} + + // CHECK: stream.async.transfer + // CHECK-SAME: on(#hal.device.affinity<@device1>) + %10 = stream.async.transfer %9 : !stream.resource{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource{%c128} + // CHECK: stream.async.dispatch + %11 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch2[%c1, %c1, %c1](%8[%c0 to %c128 for %c128], %10[%c0 to %c128 for %c128]) : (!stream.resource{%c128}, !stream.resource{%c128}) -> !stream.resource{%c128} + util.return %11 : !stream.resource +} + +// ----- + +// This one simulates how to do multi-device synchronization between +// more than two devices. + +// CHECK-LABEL: @deviceTripleSync +util.func public @deviceTripleSync(%arg0: i1) -> (!stream.resource, !stream.resource, !stream.resource) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c255_i32 = arith.constant 255 : i32 + + %0 = stream.async.splat %c255_i32 : i32 -> !stream.resource{%c128} + %1 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} + %2 = stream.async.barrier %1 : !stream.resource{%c128} + + %3 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} + + // CHECK: stream.async.transfer + // CHECK-SAME: on(#hal.device.affinity<@device1>) + %4 = stream.async.transfer %3 : !stream.resource{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource{%c128} + %5 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} + + // CHECK: stream.async.transfer + // CHECK-SAME: on(#hal.device.affinity<@device2>) + %6 = stream.async.transfer %5 : !stream.resource{%c128} from(#hal.device.affinity<@device2>) -> to(#hal.device.affinity<@device0>) !stream.resource{%c128} + %7 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch2[%c1, %c1, %c1](%2[%c0 to %c128 for %c128], %4[%c0 to %c128 for %c128], %6[%c0 to %c128 for %c128]) : (!stream.resource{%c128}, !stream.resource{%c128}, !stream.resource{%c128}) -> !stream.resource{%c128} + %8 = stream.async.barrier %7 : !stream.resource{%c128} + %11 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch0[%c1, %c1, %c1](%8[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} + + // CHECK: stream.async.transfer + // CHECK-SAME: on(#hal.device.affinity<@device1>) + %9 = stream.async.transfer %7 : !stream.resource{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) !stream.resource{%c128} + %12 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%9[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} + + // CHECK: stream.async.transfer + // CHECK-SAME: on(#hal.device.affinity<@device2>) + %10 = stream.async.transfer %7 : !stream.resource{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device2>) !stream.resource{%c128} + %13 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%10[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} + util.return %11, %12, %13 : !stream.resource, !stream.resource, !stream.resource +} diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir index a0615b2480fd7..b1dfdca62561d 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_execution.mlir @@ -100,14 +100,14 @@ util.func public @deviceTripleSync(%arg0: i1) -> (!stream.resource, ! // CHECK: stream.async.dispatch // CHECK: stream.async.transfer %3 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} - %4 = stream.async.transfer %3 : !stream.resource{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource{%c128} + %4 = stream.async.transfer %3 : !stream.resource{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) on(#hal.device.affinity<@device1>) !stream.resource{%c128} // CHECK: stream.async.execute // CHECK: stream.async.splat // CHECK: stream.async.dispatch // CHECK: stream.async.transfer %5 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} - %6 = stream.async.transfer %5 : !stream.resource{%c128} from(#hal.device.affinity<@device2>) -> to(#hal.device.affinity<@device0>) !stream.resource{%c128} + %6 = stream.async.transfer %5 : !stream.resource{%c128} from(#hal.device.affinity<@device2>) -> to(#hal.device.affinity<@device0>) on(#hal.device.affinity<@device2>) !stream.resource{%c128} // CHECK: stream.async.execute // CHECK: stream.async.dispatch @@ -121,13 +121,13 @@ util.func public @deviceTripleSync(%arg0: i1) -> (!stream.resource, ! // CHECK: stream.async.execute // CHECK: stream.async.transfer // CHECK: stream.async.dispatch - %9 = stream.async.transfer %7 : !stream.resource{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) !stream.resource{%c128} + %9 = stream.async.transfer %7 : !stream.resource{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) on(#hal.device.affinity<@device1>) !stream.resource{%c128} %12 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%9[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} // CHECK: stream.async.execute // CHECK: stream.async.transfer // CHECK: stream.async.dispatch - %10 = stream.async.transfer %7 : !stream.resource{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device2>) !stream.resource{%c128} + %10 = stream.async.transfer %7 : !stream.resource{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device2>) on(#hal.device.affinity<@device2>) !stream.resource{%c128} %13 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%10[%c0 to %c128 for %c128]) : (!stream.resource{%c128}) -> !stream.resource{%c128} util.return %11, %12, %13 : !stream.resource, !stream.resource, !stream.resource