Skip to content

Commit

Permalink
Reworked to do placement analysis on previous pass
Browse files Browse the repository at this point in the history
  • Loading branch information
rsuderman committed Feb 8, 2025
1 parent 8aeb390 commit d47e8a1
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,42 +169,25 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,

// Synchronizing operations should join with their producers if the producer
// is streamable.
if (dyn_cast<IREE::Stream::AsyncBarrierOp>(op)) {
if (dyn_cast<IREE::Stream::AsyncBarrierOp>(op) ||
dyn_cast<IREE::Stream::AsyncTransferOp>(op)) {
auto producer = op.getOperand(0).getDefiningOp();
auto streamable =
dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(producer);
if (streamable) {

auto srcAffinity = dyn_cast<IREE::Stream::AffinityOpInterface>(producer);
auto opAffinity = dyn_cast<IREE::Stream::AffinityOpInterface>(op);

if (streamable && srcAffinity &&
IREE::Stream::AffinityAttr::canExecuteTogether(
opAffinity.getAffinityAttr(), srcAffinity.getAffinityAttr())) {
if (!syncOps.contains(producer))
syncOps[producer] = llvm::SmallVector<Operation *>();
syncOps[producer].push_back(&op);
continue;
}
}

if (auto transfer = dyn_cast<IREE::Stream::AsyncTransferOp>(op)) {
auto producer = op.getOperand(0).getDefiningOp();
auto streamable =
dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(producer);
if (streamable) {
int transferCount = 0;
for (auto use : producer->getUsers()) {
if (isa<IREE::Stream::AsyncTransferOp>(use)) {
transferCount++;
}
}

if (transferCount < 2) {
if (!syncOps.contains(producer)) {
syncOps[producer] = llvm::SmallVector<Operation *>();
syncOps[producer].push_back(&op);
continue;
}
}

transfer.setAffinityAttr(transfer.getResultAffinityAttr());
}
}

// Initialize op info for this op - whether streamable or not. We track
// transitive hazards on each op. Note that thanks to the ordering of ops
// in SSA form (_reversed here!_) we know that once we visit this op no
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_compiler_cc_library(
"ElideTimepoints.cpp",
"EmplaceAllocations.cpp",
"EncodeTensors.cpp",
"ExecutionPlacement.cpp",
"FoldUniformOperands.cpp",
"FuseDispatchBindings.cpp",
"LayoutSlices.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_cc_library(
"ElideTimepoints.cpp"
"EmplaceAllocations.cpp"
"EncodeTensors.cpp"
"ExecutionPlacement.cpp"
"FoldUniformOperands.cpp"
"FuseDispatchBindings.cpp"
"LayoutSlices.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-stream-execution-placement"

namespace mlir::iree_compiler::IREE::Stream {

#define GEN_PASS_DEF_EXECUTIONPLACEMENTPASS
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc"

namespace {

// Assigns an execution affinity to every `stream.async.transfer` that does not
// already carry one, so that later passes can consume an explicit placement
// instead of re-deriving it.
//
// Placement rule:
//   * If the transferred value has exactly one use and is produced by a
//     streamable op that exposes an affinity, the transfer executes where its
//     producer does.
//   * Otherwise the transfer executes on its result affinity.
struct ExecutionPlacementPass
    : public IREE::Stream::impl::ExecutionPlacementPassBase<
          ExecutionPlacementPass> {
  void runOnOperation() override {
    getOperation()->walk([](IREE::Stream::AsyncTransferOp transfer) {
      // Respect any placement that has already been assigned.
      if (transfer.getExecAffinityAttr())
        return;

      auto operand = transfer.getSource();
      // `producer` is null when the source is a block argument, so both
      // interface lookups below must tolerate a null operation. A plain
      // `dyn_cast` on a null pointer asserts (or dereferences null in release
      // builds).
      auto producer = operand.getDefiningOp();
      auto streamable =
          dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(producer);
      auto srcAffinity =
          dyn_cast_or_null<IREE::Stream::AffinityOpInterface>(producer);

      bool hasOneUse = operand.hasOneUse();
      if (hasOneUse && streamable && srcAffinity) {
        // Sole consumer of a streamable producer: run alongside the producer.
        transfer.setExecAffinityAttr(srcAffinity.getAffinityAttr());
      } else {
        // No usable producer, or the value has other uses: run on the result
        // affinity.
        transfer.setExecAffinityAttr(transfer.getResultAffinityAttr());
      }
    });
  }
};

} // namespace
} // namespace mlir::iree_compiler::IREE::Stream
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
//----------------------------------------------------------------------------

FunctionLikeNest(passManager)
// Analyze and assign execution placement.
.addPass(IREE::Stream::createExecutionPlacementPass)
// Combine async work into execution regions.
.addPass(IREE::Stream::createScheduleExecutionPass)
// Group concurrently executable work into waves.
Expand Down
13 changes: 13 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,19 @@ def RefineUsagePass :
// Stream formation and scheduling
//===----------------------------------------------------------------------===//

// Pass that assigns execution placement ahead of execution scheduling.
def ExecutionPlacementPass :
    InterfacePass<"iree-stream-execution-placement", "mlir::CallableOpInterface"> {
  let summary = "Runs an analysis and placement for stream executions.";
  let description = [{
    Handles placement analysis for `stream.async` operators that have a preferable
    placement. This is so that more complex analysis can be separated from the
    execution scheduling pass.
  }];
  let dependentDialects = [
    "IREE::Stream::StreamDialect",
  ];
}

def ScheduleExecutionPass :
InterfacePass<"iree-stream-schedule-execution", "mlir::CallableOpInterface"> {
let summary = "Identifies and groups asynchronous operations into executable regions within function-like regions.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ iree_lit_test_suite(
"encode_host_tensors_encoding.mlir",
"encode_host_tensors_packing.mlir",
"encode_host_tensors_packing_i1_experimental_clopt.mlir",
"execution_placement.mlir",
"fold_globals.mlir",
"fold_uniform_operands.mlir",
"fuse_dispatch_bindings.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_lit_test_suite(
"encode_host_tensors_encoding.mlir"
"encode_host_tensors_packing.mlir"
"encode_host_tensors_packing_i1_experimental_clopt.mlir"
"execution_placement.mlir"
"fold_globals.mlir"
"fold_uniform_operands.mlir"
"fuse_dispatch_bindings.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(util.func(iree-stream-execution-placement))" %s | FileCheck %s

// Tests execution placement for multi-device programs with barriers and transfers.
// It validates that every stream.async.transfer is assigned an execution affinity.

// CHECK-LABEL: util.func public @deviceMultiDeviceSync
util.func public @deviceMultiDeviceSync(%arg0: i1) -> !stream.resource<transient> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%c255_i32 = arith.constant 255 : i32

%0 = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%c128}
// CHECK: stream.async.dispatch
%1 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%3 = stream.async.barrier %1 : !stream.resource<transient>{%c128}

// %1 is used by both the barrier above and this transfer, so the transfer is
// expected to be placed on its result affinity (@device1).
// CHECK: stream.async.transfer
// CHECK-SAME: on(#hal.device.affinity<@device1>)
%4 = stream.async.transfer %1 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) !stream.resource<transient>{%c128}
// CHECK: stream.async.dispatch
%2 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch1[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%5 = stream.async.barrier %2 : !stream.resource<transient>{%c128}

// %2 is used by both the barrier above and this transfer, so the transfer is
// expected to be placed on its result affinity (@device0).
// CHECK: stream.async.transfer
// CHECK-SAME: on(#hal.device.affinity<@device0>)
%6 = stream.async.transfer %2 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource<transient>{%c128}
// CHECK: stream.async.dispatch
%7 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch2[%c1, %c1, %c1](%3[%c0 to %c128 for %c128], %6[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}, !stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%8 = stream.async.barrier %7 : !stream.resource<transient>{%c128}
%9 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch2[%c1, %c1, %c1](%4[%c0 to %c128 for %c128], %5[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}, !stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

// %9 has this transfer as its only use, so the transfer is expected to be
// placed on the affinity of the producing dispatch (@device1).
// CHECK: stream.async.transfer
// CHECK-SAME: on(#hal.device.affinity<@device1>)
%10 = stream.async.transfer %9 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource<transient>{%c128}
// CHECK: stream.async.dispatch
%11 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch2[%c1, %c1, %c1](%8[%c0 to %c128 for %c128], %10[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}, !stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
util.return %11 : !stream.resource<transient>
}

// -----

// This one simulates how to do multi-device synchronization between
// more than two devices.

// CHECK-LABEL: @deviceTripleSync
util.func public @deviceTripleSync(%arg0: i1) -> (!stream.resource<transient>, !stream.resource<transient>, !stream.resource<transient>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%c255_i32 = arith.constant 255 : i32

%0 = stream.async.splat %c255_i32 : i32 -> !stream.resource<transient>{%c128}
%1 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%2 = stream.async.barrier %1 : !stream.resource<transient>{%c128}

%3 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

// %3 has this transfer as its only use, so the transfer is expected to be
// placed on the affinity of the producing dispatch (@device1).
// CHECK: stream.async.transfer
// CHECK-SAME: on(#hal.device.affinity<@device1>)
%4 = stream.async.transfer %3 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource<transient>{%c128}
%5 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

// %5 has this transfer as its only use, so the transfer is expected to be
// placed on the affinity of the producing dispatch (@device2).
// CHECK: stream.async.transfer
// CHECK-SAME: on(#hal.device.affinity<@device2>)
%6 = stream.async.transfer %5 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device2>) -> to(#hal.device.affinity<@device0>) !stream.resource<transient>{%c128}
%7 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch2[%c1, %c1, %c1](%2[%c0 to %c128 for %c128], %4[%c0 to %c128 for %c128], %6[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}, !stream.resource<transient>{%c128}, !stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%8 = stream.async.barrier %7 : !stream.resource<transient>{%c128}
%11 = stream.async.dispatch on(#hal.device.affinity<@device0>) @ex::@dispatch0[%c1, %c1, %c1](%8[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

// %7 is used by the barrier above and by two transfers, so each transfer is
// expected to be placed on its own result affinity (@device1 here).
// CHECK: stream.async.transfer
// CHECK-SAME: on(#hal.device.affinity<@device1>)
%9 = stream.async.transfer %7 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) !stream.resource<transient>{%c128}
%12 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%9[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

// Second transfer of the multiply-used %7: expected on its result affinity
// (@device2).
// CHECK: stream.async.transfer
// CHECK-SAME: on(#hal.device.affinity<@device2>)
%10 = stream.async.transfer %7 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device2>) !stream.resource<transient>{%c128}
%13 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%10[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
util.return %11, %12, %13 : !stream.resource<transient>, !stream.resource<transient>, !stream.resource<transient>
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ util.func public @deviceTripleSync(%arg0: i1) -> (!stream.resource<transient>, !
// CHECK: stream.async.dispatch
// CHECK: stream.async.transfer
%3 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%4 = stream.async.transfer %3 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) !stream.resource<transient>{%c128}
%4 = stream.async.transfer %3 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device1>) -> to(#hal.device.affinity<@device0>) on(#hal.device.affinity<@device1>) !stream.resource<transient>{%c128}

// CHECK: stream.async.execute
// CHECK: stream.async.splat
// CHECK: stream.async.dispatch
// CHECK: stream.async.transfer
%5 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%0[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}
%6 = stream.async.transfer %5 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device2>) -> to(#hal.device.affinity<@device0>) !stream.resource<transient>{%c128}
%6 = stream.async.transfer %5 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device2>) -> to(#hal.device.affinity<@device0>) on(#hal.device.affinity<@device2>) !stream.resource<transient>{%c128}

// CHECK: stream.async.execute
// CHECK: stream.async.dispatch
Expand All @@ -121,13 +121,13 @@ util.func public @deviceTripleSync(%arg0: i1) -> (!stream.resource<transient>, !
// CHECK: stream.async.execute
// CHECK: stream.async.transfer
// CHECK: stream.async.dispatch
%9 = stream.async.transfer %7 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) !stream.resource<transient>{%c128}
%9 = stream.async.transfer %7 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device1>) on(#hal.device.affinity<@device1>) !stream.resource<transient>{%c128}
%12 = stream.async.dispatch on(#hal.device.affinity<@device1>) @ex::@dispatch0[%c1, %c1, %c1](%9[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

// CHECK: stream.async.execute
// CHECK: stream.async.transfer
// CHECK: stream.async.dispatch
%10 = stream.async.transfer %7 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device2>) !stream.resource<transient>{%c128}
%10 = stream.async.transfer %7 : !stream.resource<transient>{%c128} from(#hal.device.affinity<@device0>) -> to(#hal.device.affinity<@device2>) on(#hal.device.affinity<@device2>) !stream.resource<transient>{%c128}
%13 = stream.async.dispatch on(#hal.device.affinity<@device2>) @ex::@dispatch0[%c1, %c1, %c1](%10[%c0 to %c128 for %c128]) : (!stream.resource<transient>{%c128}) -> !stream.resource<transient>{%c128}

util.return %11, %12, %13 : !stream.resource<transient>, !stream.resource<transient>, !stream.resource<transient>
Expand Down

0 comments on commit d47e8a1

Please sign in to comment.